# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import os from collections import namedtuple from collections.abc import Callable from typing import Any, Union import torch from vllm.triton_utils import tl, triton def _matmul_launch_metadata( grid: Callable[..., Any], kernel: Any, args: dict[str, Any] ) -> dict[str, Any]: ret = {} m, n, k = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" if "tiles_per_update" in args: ret["name"] = ( f"{kernel.name} [M={m}, N={n}, K={k}, " f"tiles_per_update={args['tiles_per_update']:02}]" ) if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n) return ret @triton.jit def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (tile_id % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m return pid_m, pid_n @triton.jit(launch_metadata=_matmul_launch_metadata) def matmul_kernel_persistent( a_ptr, b_ptr, c_ptr, # bias_ptr, M, N, K, # stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, # BLOCK_SIZE_N: tl.constexpr, # BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # NUM_SMS: tl.constexpr, # A_LARGE: tl.constexpr, B_LARGE: tl.constexpr, C_LARGE: tl.constexpr, HAS_BIAS: tl.constexpr, ): start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n tile_id_c = start_pid - NUM_SMS offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) num_pid_in_group = GROUP_SIZE_M * num_pid_n for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): pid_m, pid_n = _compute_pid( tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS ) start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) if A_LARGE: offs_am = offs_am.to(tl.int64) if B_LARGE: offs_bn = offs_bn.to(tl.int64) offs_am = tl.where(offs_am < M, offs_am, 0) offs_bn = tl.where(offs_bn < N, offs_bn, 0) offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for ki in range(k_tiles): if A_LARGE or B_LARGE: offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) else: offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + ( offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak ) b_ptrs = b_ptr + ( offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn ) a = tl.load( a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0 ) b = tl.load( b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0 ) accumulator = tl.dot(a, b, accumulator) tile_id_c += NUM_SMS pid_m, pid_n = _compute_pid( tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS ) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if C_LARGE: offs_cm = offs_cm.to(tl.int64) offs_cn = offs_cn.to(tl.int64) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) if HAS_BIAS: bias_ptrs = bias_ptr + offs_cn bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32) accumulator += bias if c_ptr.dtype.element_ty == tl.float8e4nv: c = accumulator.to(tl.float8e4nv) else: c = accumulator.to(tl.float16) tl.store(c_ptrs, c, mask=c_mask) def matmul_persistent( a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None ): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.dtype == b.dtype, "Incompatible dtypes" assert bias is None or bias.dim() == 1, ( "Currently assuming bias is 1D, let Horace know if you run into this" ) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count M, K = a.shape K, N = b.shape dtype = a.dtype # Allocates output. c = torch.empty((M, N), device=a.device, dtype=dtype) # 1D launch kernel where each block gets its own program. def grid(META): return ( min( NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ), ) configs = { torch.bfloat16: { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, "num_warps": 8, }, torch.float16: { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_stages": 3, "num_warps": 8, }, torch.float32: { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_stages": 3, "num_warps": 8, }, } # print(a.device, b.device, c.device) matmul_kernel_persistent[grid]( a, b, c, # bias, M, N, K, # a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # NUM_SMS=NUM_SMS, # A_LARGE=a.numel() > 2**31, B_LARGE=b.numel() > 2**31, C_LARGE=c.numel() > 2**31, HAS_BIAS=bias is not None, **configs[dtype], ) return c @triton.jit def _log_softmax_kernel( input_ptr, output_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): """ Compute log_softmax along the last dimension of a 2D tensor. Each block handles one row of the input tensor. """ # Get the row index for this block row_idx = tl.program_id(0).to(tl.int64) # Compute base pointers for input and output rows row_start_ptr = input_ptr + row_idx * input_row_stride output_row_start_ptr = output_ptr + row_idx * output_row_stride # Step 1: Find maximum value in the row for numerical stability max_val = -float("inf") for col_offset in range(0, n_cols, BLOCK_SIZE): col_idx = col_offset + tl.arange(0, BLOCK_SIZE) mask = col_idx < n_cols # Load values vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf")) # Update maximum max_val = tl.max(tl.maximum(vals, max_val)) # Step 2: Compute sum of exp(x - max_val) sum_exp = 0.0 for col_offset in range(0, n_cols, BLOCK_SIZE): col_idx = col_offset + tl.arange(0, BLOCK_SIZE) mask = col_idx < n_cols # Load values vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) # Compute exp(x - max_val) and accumulate exp_vals = tl.exp(vals - max_val) sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0)) # Compute log(sum_exp) log_sum_exp = tl.log(sum_exp) # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp for col_offset in range(0, n_cols, BLOCK_SIZE): col_idx = col_offset + tl.arange(0, BLOCK_SIZE) mask = col_idx < n_cols # Load values vals = tl.load(row_start_ptr + col_idx, mask=mask) # Compute log_softmax output = vals - max_val - log_sum_exp # Store results tl.store(output_row_start_ptr + col_idx, output, mask=mask) def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor: """ Compute log_softmax using Triton kernel. Args: input: Input tensor dim: Dimension along which to compute log_softmax (only -1 or last dim supported) >> Stashed changes Returns: Tensor with log_softmax applied along the specified dimension """ if dim != -1 and dim != input.ndim - 1: raise ValueError( "This implementation only supports log_softmax along the last dimension" ) # Flatten all dimensions except the last one original_shape = input.shape input_2d = input.reshape(-1, input.shape[-1]) input_2d = input_2d.contiguous() n_rows, n_cols = input_2d.shape # Allocate output tensor output = torch.empty_like(input_2d) # Choose block size based on the number of columns BLOCK_SIZE = 1024 # Launch kernel with one block per row grid = (n_rows,) _log_softmax_kernel[grid]( input_2d, output, input_2d.stride(0), output.stride(0), n_cols, BLOCK_SIZE=BLOCK_SIZE, ) # Reshape output back to original shape return output.reshape(original_shape) @triton.jit def mean_kernel( input_ptr, output_ptr, input_stride0, input_stride1, input_stride2, output_stride0, output_stride1, M, # size before reduction dim N, # size of reduction dim K, # size after reduction dim BLOCK_SIZE: tl.constexpr, ): """ Kernel for computing mean along a single dimension. Input is viewed as (M, N, K) where N is the dimension being reduced. """ # Program ID gives us which output element we're computing pid = tl.program_id(0) # Compute output indices m_idx = pid // K k_idx = pid % K # Bounds check if m_idx >= M or k_idx >= K: return # Accumulate sum across reduction dimension acc = 0.0 for n_start in range(0, N, BLOCK_SIZE): n_offsets = n_start + tl.arange(0, BLOCK_SIZE) mask = n_offsets < N # Calculate input indices input_idx = ( m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2 ) # Load and accumulate vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0) acc += tl.sum(vals) # Compute mean and store mean_val = acc / N output_idx = m_idx * output_stride0 + k_idx * output_stride1 tl.store(output_ptr + output_idx, mean_val) def mean_dim( input: torch.Tensor, dim: int, keepdim: bool = False, dtype: Union[torch.dtype, None] = None, ) -> torch.Tensor: """ Triton implementation of torch.mean with single dimension reduction. Args: input: Input tensor dim: Single dimension along which to compute mean keepdim: Whether to keep the reduced dimension dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs) Returns: Tensor with mean values along specified dimension """ # Validate inputs assert input.is_cuda, "Input must be a CUDA tensor" assert -input.ndim <= dim < input.ndim, ( f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" ) # Handle negative dim if dim < 0: dim = dim + input.ndim # Handle dtype if dtype is None: if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: dtype = torch.float32 else: dtype = input.dtype # Convert input to appropriate dtype if needed if input.dtype != dtype: input = input.to(dtype) # Get input shape and strides shape = list(input.shape) # Calculate dimensions for kernel M = 1 for i in range(dim): M *= shape[i] N = shape[dim] K = 1 for i in range(dim + 1, len(shape)): K *= shape[i] # Reshape input to 3D view (M, N, K) input_3d = input.reshape(M, N, K) # Create output shape if keepdim: output_shape = shape.copy() output_shape[dim] = 1 else: output_shape = shape[:dim] + shape[dim + 1 :] # Create output tensor output = torch.empty(output_shape, dtype=dtype, device=input.device) # Reshape output for kernel output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K) # Launch kernel grid = (M * K,) BLOCK_SIZE = 1024 mean_kernel[grid]( input_3d, output_2d, input_3d.stride(0), input_3d.stride(1), input_3d.stride(2), output_2d.stride(0), output_2d.stride(1) if output_2d.ndim > 1 else 0, M, N, K, BLOCK_SIZE, ) return output def mm_batch_invariant(a, b): return matmul_persistent(a, b) def addmm_batch_invariant(bias, a, b): return matmul_persistent(a, b, bias=bias) def _log_softmax_batch_invariant(input, dim, _half_to_float): assert not _half_to_float, "not implemented" return log_softmax(input, dim=dim) def mean_batch_invariant( input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None ): assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" result = input.to(torch.float32) # Sort dimensions to reduce from largest to smallest to handle shifting dims # during iterative reduction. sorted_dims = sorted([d % input.ndim for d in dim], reverse=True) # Iteratively apply a deterministic mean. for d in sorted_dims: result = mean_dim(result, dim=d, keepdim=True) if not keepdim: # Squeeze the reduced dimensions. for d in sorted_dims: result = result.squeeze(d) return result _batch_invariant_MODE = False _batch_invariant_LIB = None def is_batch_invariant_mode_enabled(): return _batch_invariant_MODE def enable_batch_invariant_mode(): global _batch_invariant_MODE, _batch_invariant_LIB if _batch_invariant_MODE: return _batch_invariant_MODE = True _batch_invariant_LIB = torch.library.Library("aten", "IMPL") _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") _batch_invariant_LIB.impl( "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA" ) _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") def disable_batch_invariant_mode(): global _batch_invariant_MODE, _batch_invariant_LIB if _batch_invariant_LIB is not None: _batch_invariant_LIB._destroy() _batch_invariant_MODE = False _batch_invariant_LIB = None @contextlib.contextmanager def set_batch_invariant_mode(enabled: bool = True): global _batch_invariant_MODE, _batch_invariant_LIB old_data = (_batch_invariant_MODE, _batch_invariant_LIB) if enabled: enable_batch_invariant_mode() else: disable_batch_invariant_mode() yield if _batch_invariant_LIB is not None: _batch_invariant_LIB._destroy() _batch_invariant_MODE, _batch_invariant_LIB = old_data AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"]) def get_batch_invariant_attention_block_size() -> AttentionBlockSize: return AttentionBlockSize(block_m=16, block_n=16) def vllm_kernel_override_batch_invariant(): env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT" is_overridden = False val = os.getenv(env_key, "0") try: is_overridden = int(val) != 0 except ValueError: is_overridden = False return is_overridden def init_batch_invariance(): # this will hit all the csrc overrides as well if vllm_kernel_override_batch_invariant(): os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" enable_batch_invariant_mode()