# SPDX-License-Identifier: MIT
import torch
import triton
import triton.language as tl


# Thin JIT wrapper around ``tl.exp`` so other kernels can call a local ``exp``.
@triton.jit
def exp(x):
    return tl.exp(x)


# Exponential that only lets non-positive inputs through: wherever ``x > 0``
# the argument is replaced by ``-inf`` before exponentiation, so the result
# at those positions is ``exp(-inf)``.
@triton.jit
def safe_exp(x):
    return exp(tl.where(x <= 0, x, float("-inf")))


def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Enumerate every chunk of every sequence as ``(sequence_id, chunk_id)``.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries; consecutive
            differences are taken as the individual sequence lengths.
        chunk_size: Number of positions per chunk.

    Returns:
        A tensor with one row per chunk and two columns: the index of the
        sequence the chunk belongs to (its position in ``cu_seqlens``) and
        the chunk's index within that sequence. The result is converted to
        the dtype and device of ``cu_seqlens``.

    Raises:
        RuntimeError: Propagated from ``torch.cat`` when ``cu_seqlens``
            describes no sequences at all (fewer than two boundaries).
    """
    device = cu_seqlens.device
    # Number of chunks needed by each sequence (ceil division).
    chunk_counts = triton.cdiv(cu_seqlens[1:] - cu_seqlens[:-1], chunk_size)
    # Chunk index within its own sequence: 0..n-1 for each sequence, concatenated.
    indices = torch.cat([
        torch.arange(n, device=device)
        for n in chunk_counts.tolist()
    ])
    # Sequence id for each chunk. Repeating each sequence's position by its
    # chunk count keeps the ids aligned with ``cu_seqlens`` even when some
    # sequence is empty; counting the zeros in ``indices`` would skip empty
    # sequences and shift the ids of every sequence after them.
    seq_ids = torch.repeat_interleave(
        torch.arange(chunk_counts.numel(), device=device),
        chunk_counts,
    )
    return torch.stack([seq_ids, indices], 1).to(cu_seqlens)


def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Return the cumulative chunk count in front of each sequence.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries.
        chunk_size: Number of positions per chunk.

    Returns:
        A tensor with one more entry than there are sequences: it starts at
        0 and entry ``i + 1`` is the total number of chunks in sequences
        ``0..i``.
    """
    lens = cu_seqlens[1:] - cu_seqlens[:-1]
    return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(lens, chunk_size)]).cumsum(-1)


def _is_nvidia_hopper() -> bool:
    """Heuristically report whether CUDA device 0 should be treated as Hopper.

    Returns ``False`` when CUDA is unavailable. Otherwise the check passes if
    the device name contains ``"NVIDIA H"`` or the major compute capability
    is 9 or higher (so any device reporting a higher major also passes).
    """
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability(0)
    name = torch.cuda.get_device_name(0)
    return ("NVIDIA H" in name) or (major >= 9)


# Evaluated once at import time.
is_nvidia_hopper = _is_nvidia_hopper()


def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """Check whether a device has at least the shared memory expected for ``arch``.

    Args:
        arch: Architecture name, matched case-insensitively against
            ``"ADA"``, ``"AMPERE"``, ``"HOPPER"`` and ``"NONE"``. Any other
            value falls back to the ``"NONE"`` threshold.
        tensor_idx: Index handed to the Triton driver's
            ``get_device_properties`` (presumably the device index — confirm
            against callers).

    Returns:
        ``True`` if the reported ``max_shared_mem`` is at least the threshold
        for ``arch``; ``False`` if it is smaller, CUDA is unavailable, or the
        device properties cannot be queried.
    """
    # Keep behavior simple and local for tests/bench.
    if not torch.cuda.is_available():
        return False
    try:
        max_shared_mem = triton.runtime.driver.active.utils.get_device_properties(tensor_idx)["max_shared_mem"]
    except Exception:
        # Deliberate best-effort probe: any driver/query failure means
        # "cannot confirm enough shared memory".
        return False
    # Same thresholds used by sglang utils.
    thresholds = {
        "ADA": 101376,
        "AMPERE": 166912,
        "HOPPER": 232448,
        "NONE": 102400,
    }
    return max_shared_mem >= thresholds.get(arch.upper(), thresholds["NONE"])