# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Batch-invariant replacements for aten matmul / softmax / mean / norm ops.

The Triton kernels here use fixed tile sizes, so the reduction order for a
given row does not depend on how many rows are in the batch.
``enable_batch_invariant_mode`` registers them as CUDA overrides of the
corresponding ``aten`` ops.
"""
import contextlib
import os
from collections import namedtuple
from collections.abc import Callable
from typing import Any

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton

logger = init_logger(__name__)


def _matmul_launch_metadata(
    grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
) -> dict[str, Any]:
    """Build profiler launch metadata for a matmul kernel launch.

    Registered as ``launch_metadata`` on ``matmul_kernel_persistent``. Reads
    the bound kernel arguments and reports a display name, the FLOP count
    (keyed by output element width in bits) and the bytes touched by A, B
    and C.
    """
    ret = {}
    m, n, k = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
    if "tiles_per_update" in args:
        ret["name"] = (
            f"{kernel.name} [M={m}, N={n}, K={k}, "
            f"tiles_per_update={args['tiles_per_update']:02}]"
        )
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
    ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
    return ret


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    """Map a linear tile id to ``(pid_m, pid_n)`` using grouped ordering.

    Within a group of ``GROUP_SIZE_M`` tile rows, ``pid_m`` varies fastest,
    so consecutive tile ids share the same ``pid_n`` (same column block of
    B). ``NUM_SMS`` is unused; it is kept for call-site uniformity.
    """
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M tile rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n


@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,
    bias_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    NUM_SMS: tl.constexpr,
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """Persistent tiled matmul ``C = A @ B (+ bias)`` with fp32 accumulation.

    Each program walks output tiles ``start_pid, start_pid + NUM_SMS, ...``.
    For every tile the K reduction runs over ``ki = 0 .. k_tiles - 1`` in the
    same fixed order regardless of M, so a row's result does not depend on
    the number of rows. ``*_LARGE`` switch offset arithmetic to int64 for
    tensors whose element count exceeds 2**31.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # Separate tile counter for the epilogue; it tracks tile_id step by step.
    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        # Clamp out-of-range rows/cols to 0 so the A/B loads below need only
        # a K mask; the garbage they produce is dropped by c_mask on store.
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(
            tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M
        )
        offs_bn = tl.max_contiguous(
            tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N
        )

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )

            # Zero-fill the K tail so it does not contribute to the dot.
            a = tl.load(
                a_ptrs,
                mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K,
                other=0.0,
            )
            b = tl.load(
                b_ptrs,
                mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K,
                other=0.0,
            )
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            # Bias is indexed by column with unit stride (caller makes it
            # contiguous) and added in fp32 before the output cast.
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
            accumulator += bias
        # Cast straight to the output element type. Previously every non-fp8
        # output went through fp16 first, which overflowed bf16 results above
        # 65504 and truncated fp32 results to fp16 precision.
        c = accumulator.to(c_ptr.dtype.element_ty)
        tl.store(c_ptrs, c, mask=c_mask)


# Fixed (non-autotuned) tile configurations per input dtype. Keeping the tile
# shape independent of the problem size keeps the K-reduction order fixed.
_MATMUL_CONFIGS: dict[torch.dtype, dict[str, int]] = {
    torch.bfloat16: {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 8,
        "num_stages": 3,
        "num_warps": 8,
    },
    torch.float16: {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 8,
        "num_stages": 3,
        "num_warps": 8,
    },
    torch.float32: {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
        "num_stages": 3,
        "num_warps": 8,
    },
}


def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
    """Compute ``a @ b (+ bias)`` with the batch-invariant persistent kernel.

    Args:
        a: 2-D tensor of shape (M, K).
        b: 2-D tensor of shape (K, N) with the same dtype as ``a``.
        bias: Optional 1-D tensor of length N (or 1, broadcast to N), added
            in fp32 in the kernel epilogue.

    Returns:
        Tensor of shape (M, N) with dtype ``a.dtype``.

    Raises:
        AssertionError: on mismatched inner dimensions, dtypes, or a
            non-1-D bias.
        KeyError: if ``a.dtype`` has no entry in ``_MATMUL_CONFIGS``.
    """
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert bias is None or bias.dim() == 1, (
        "Currently assuming bias is 1D, let Horace know if you run into this"
    )
    NUM_SMS = torch.cuda.get_device_properties(a.device).multi_processor_count
    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    if bias is not None:
        # The kernel reads bias[col] with unit stride for every col < N, so
        # materialize a dense length-N vector (also broadcasts a size-1 bias).
        bias = bias.expand(N).contiguous()

    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=dtype)

    # 1D launch: one persistent program per SM, capped at the tile count.
    def grid(META):
        return (
            min(
                NUM_SMS,
                triton.cdiv(M, META["BLOCK_SIZE_M"])
                * triton.cdiv(N, META["BLOCK_SIZE_N"]),
            ),
        )

    matmul_kernel_persistent[grid](
        a,
        b,
        c,
        bias,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        NUM_SMS=NUM_SMS,
        A_LARGE=a.numel() > 2**31,
        B_LARGE=b.numel() > 2**31,
        C_LARGE=c.numel() > 2**31,
        HAS_BIAS=bias is not None,
        **_MATMUL_CONFIGS[dtype],
    )
    return c


@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute log_softmax along the last dimension of a 2D tensor.
    Each program handles one row, in BLOCK_SIZE-wide chunks (three passes:
    max, sum of exp, write).
    """
    # Get the row index for this program (int64 to avoid offset overflow).
    row_idx = tl.program_id(0).to(tl.int64)

    # Compute base pointers for input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Find maximum value in the row for numerical stability
    max_val = -float("inf")
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        # Masked lanes load -inf so they never win the max.
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))
        max_val = tl.max(tl.maximum(vals, max_val))

    # Step 2: Compute sum of exp(x - max_val)
    sum_exp = 0.0
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        exp_vals = tl.exp(vals - max_val)
        # Masked lanes loaded 0.0, so exp would be nonzero; zero them out.
        sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))

    log_sum_exp = tl.log(sum_exp)

    # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask)
        output = vals - max_val - log_sum_exp
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)


def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax using Triton kernel.

    Args:
        input: Input tensor
        dim: Dimension along which to compute log_softmax
             (only -1 or last dim supported)

    Returns:
        Tensor with log_softmax applied along the specified dimension

    Raises:
        ValueError: if ``dim`` is not the last dimension.
    """
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()
    n_rows, n_cols = input_2d.shape

    output = torch.empty_like(input_2d)

    # Fixed chunk width: the per-row reduction order is the same no matter
    # how many rows are in the batch.
    BLOCK_SIZE = 1024

    # Launch kernel with one program per row
    grid = (n_rows,)
    _log_softmax_kernel[grid](
        input_2d,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output.reshape(original_shape)


@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced;
    each program computes one (m, k) output element.
    """
    pid = tl.program_id(0)

    # Compute output indices
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check
    if m_idx >= M or k_idx >= K:
        return

    # Accumulate sum across the reduction dimension in fixed-size chunks.
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N
        input_idx = (
            m_idx * input_stride0 + n_offsets * input_stride1 + k_idx * input_stride2
        )
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)


def mean_dim(
    input: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    dtype: torch.dtype | None = None,
) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single dimension reduction.

    Args:
        input: Input tensor
        dim: Single dimension along which to compute mean
        keepdim: Whether to keep the reduced dimension
        dtype: Output dtype. If None, uses input dtype for floating-point
            inputs and float32 for all other (integer / bool) inputs.

    Returns:
        Tensor with mean values along specified dimension

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    assert -input.ndim <= dim < input.ndim, (
        f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
    )

    if dim < 0:
        dim = dim + input.ndim

    if dtype is None:
        # Non-floating inputs (including uint8 and bool, which the previous
        # int8..int64 list missed) are averaged in float32.
        dtype = input.dtype if input.is_floating_point() else torch.float32

    if input.dtype != dtype:
        input = input.to(dtype)

    shape = list(input.shape)

    # Collapse to (M, N, K): dims before `dim`, `dim` itself, dims after.
    M = 1
    for i in range(dim):
        M *= shape[i]
    N = shape[dim]
    K = 1
    for i in range(dim + 1, len(shape)):
        K *= shape[i]

    input_3d = input.reshape(M, N, K)

    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1 :]

    output = torch.empty(output_shape, dtype=dtype, device=input.device)

    output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K)

    # One program per output element.
    grid = (M * K,)
    BLOCK_SIZE = 1024

    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1) if output_2d.ndim > 1 else 0,
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output


def mm_batch_invariant(a, b):
    """Batch-invariant ``aten::mm`` for 2-D ``a`` and ``b``."""
    return matmul_persistent(a, b)


def matmul_batch_invariant(a, b, *, out=None):
    """Batch-invariant ``aten::matmul``.

    Supports 2-D x 2-D, 3-D x 3-D (delegated to ``bmm_batch_invariant``)
    and N-D x 2-D (N >= 3). The N-D x 2-D case folds all leading dims of
    ``a`` into M. Because each output row is computed independently by the
    persistent kernel, this gives the same per-row result as a 2-D call.

    Raises:
        ValueError: for any other combination of ranks.
    """
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
    elif a.ndim == 3 and b.ndim == 3:
        return bmm_batch_invariant(a, b, out=out)
    elif a.ndim > 2 and b.ndim == 2:
        result_2d = matmul_persistent(a.reshape(-1, a.shape[-1]), b)
        result = result_2d.reshape(*a.shape[:-1], b.shape[-1])
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D "
            f"and ND x 2D, got shapes {a.shape} and {b.shape}"
        )
    if out is not None:
        out.copy_(result)
        return out
    return result


def bmm_batch_invariant(a, b, *, out=None):
    """Batch-invariant ``aten::bmm``: (B, M, K) x (B, K, N) -> (B, M, N).

    Each batch entry is multiplied separately with the persistent kernel.

    Raises:
        ValueError: if either input is not 3-D.
    """
    if a.ndim != 3 or b.ndim != 3:
        raise ValueError(
            f"bmm_batch_invariant expects 3D tensors, "
            f"got shapes {a.shape} and {b.shape}"
        )
    if a.shape[0] == 0:
        # torch.stack rejects an empty list; build the empty result directly.
        result = torch.empty(
            (0, a.shape[1], b.shape[2]), dtype=a.dtype, device=a.device
        )
    else:
        result = torch.stack(
            [matmul_persistent(a[i], b[i]) for i in range(a.shape[0])], dim=0
        )
    if out is not None:
        out.copy_(result)
        return out
    return result


def addmm_batch_invariant(bias, a, b, *, beta=1, alpha=1):
    """Batch-invariant ``aten::addmm``: ``beta * bias + alpha * (a @ b)``.

    The common case (``beta == alpha == 1`` with a 1-D bias) is fused into
    the persistent kernel's epilogue. Other cases compute ``a @ b`` with the
    same kernel and apply scaling / broadcasting with elementwise ops.
    """
    if beta == 1 and alpha == 1 and bias.dim() == 1:
        return matmul_persistent(a, b, bias=bias)
    result = matmul_persistent(a, b)
    if alpha != 1:
        result = result * alpha
    if beta == 0:
        # As in torch.addmm, beta == 0 ignores the input entirely (nan/inf
        # in it are not propagated).
        return result
    if beta != 1:
        bias = bias * beta
    return result + bias


def _log_softmax_batch_invariant(input, dim, _half_to_float):
    """Batch-invariant ``aten::_log_softmax`` (``half_to_float`` unsupported)."""
    assert not _half_to_float, "not implemented"
    return log_softmax(input, dim=dim)


def softmax_batch_invariant(input, dim, dtype=None):
    """Deterministic softmax registered for ``aten::softmax`` and ``aten::_softmax``.

    The third positional argument is ``dtype`` (a ``torch.dtype`` or None)
    for ``aten::softmax``, and the ``half_to_float`` bool for
    ``aten::_softmax``. In both cases the input is cast first, so the
    output has the requested dtype (float32 when ``half_to_float`` is True).
    """
    if isinstance(dtype, torch.dtype):
        input = input.to(dtype)
    elif dtype is True:
        input = input.to(torch.float32)
    # Subtract the max for numerical stability (standard practice).
    input_max = torch.amax(input, dim=dim, keepdim=True)
    input = input - input_max
    exp_x = torch.exp(input)
    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
    return exp_x / sum_exp_x


def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
    """Batch-invariant ``aten::mean.dim`` built from repeated ``mean_dim`` calls.

    Args:
        input: Input tensor.
        dim: Dimension(s) to reduce. An int, a sequence, or empty/None to
            reduce over every dimension.
        keepdim: Whether to keep the reduced dimensions (size 1).
        dtype: None or torch.float32. Accumulation is always in float32.
            With None, floating-point inputs get their own dtype back,
            matching ``torch.mean``.

    Raises:
        AssertionError: for any other ``dtype``.
    """
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"

    result = input.to(torch.float32)

    if isinstance(dim, int):
        dim = [dim]
    if not dim:
        # None or empty: reduce over all dimensions.
        dim = range(input.ndim)

    # Reduce from the largest index down so earlier reductions don't shift
    # the indices of dims still to be reduced (duplicates collapsed).
    sorted_dims = sorted({d % input.ndim for d in dim}, reverse=True)

    for d in sorted_dims:
        result = mean_dim(result, dim=d, keepdim=True)

    if not keepdim:
        for d in sorted_dims:
            result = result.squeeze(d)

    if dtype is None and input.is_floating_point():
        result = result.to(input.dtype)
    return result


@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.
    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
    Each program handles one row of the input tensor.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Compute sum of squares in float32 to avoid overflow
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        vals_f32 = vals.to(tl.float32)
        sq_vals = vals_f32 * vals_f32
        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

    # Step 2: Compute RMS (root mean square) in float32
    mean_sq = sum_sq / n_cols
    rms = tl.sqrt(mean_sq + eps)
    inv_rms = 1.0 / rms

    # Step 3: Normalize and apply weight
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
        # Compute in float32 then convert back to the input dtype.
        vals_f32 = vals.to(tl.float32)
        weight_f32 = weight.to(tl.float32)
        output_f32 = vals_f32 * inv_rms * weight_f32
        output = output_f32.to(vals.dtype)
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)


def rms_norm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Compute RMS normalization using Triton kernel.

    RMS Norm normalizes the input by the root mean square and scales by weight:
    output = input / sqrt(mean(input^2) + eps) * weight

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        Tensor with RMS normalization applied along the last dimension

    Raises:
        AssertionError: if ``weight`` is not 1-D or its length does not
            match the last dimension of ``input``.
    """
    assert weight.dim() == 1, "Weight must be 1-dimensional"
    assert input.shape[-1] == weight.shape[0], (
        f"Input last dimension ({input.shape[-1]}) must match "
        f"weight dimension ({weight.shape[0]})"
    )

    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()
    # The kernel reads weight[col] with unit stride.
    weight = weight.contiguous()

    n_rows, n_cols = input_2d.shape
    output = torch.empty_like(input_2d)
    BLOCK_SIZE = 1024
    grid = (n_rows,)
    _rms_norm_kernel[grid](
        input_2d,
        weight,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output.reshape(original_shape)


def rms_norm_batch_invariant(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Batch-invariant wrapper for RMS normalization.

    This function provides a deterministic, batch-invariant implementation
    of RMS normalization for use with the batch_invariant mode.

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS normalized tensor
    """
    return rms_norm(input, weight, eps=eps)


def linear_batch_invariant(input, weight, bias=None):
    """Batch-invariant ``aten::linear``: ``input @ weight.T (+ bias)``.

    ``input`` may have any number of leading dims (``(*, in_features)``).
    They are folded into M for the 2-D kernel and restored afterwards.
    """
    input_2d = input.reshape(-1, input.shape[-1])
    output = mm_batch_invariant(input_2d, weight.t())
    if bias is not None:
        output = output + bias
    return output.reshape(*input.shape[:-1], output.shape[-1])


_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None


def is_batch_invariant_mode_enabled():
    """Return whether the aten overrides are currently registered."""
    return _batch_invariant_MODE


def enable_batch_invariant_mode():
    """Register the batch-invariant kernels as CUDA overrides of aten ops.

    No-op if already enabled. Also replaces ``torch.bmm`` directly, keeping
    the original so ``disable_batch_invariant_mode`` can restore it.
    """
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    if _batch_invariant_MODE:
        return

    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl(
        "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
    )
    _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")

    # Also monkeypatch torch.bmm directly as a fallback
    _original_torch_bmm = torch.bmm
    torch.bmm = bmm_batch_invariant


def disable_batch_invariant_mode():
    """Tear down the aten overrides and restore ``torch.bmm``. Safe to call twice."""
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm
        _original_torch_bmm = None
    _batch_invariant_MODE = False
    _batch_invariant_LIB = None


@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
    """Temporarily enable or disable batch-invariant mode.

    On exit (including via an exception) the previous mode is restored with
    ``enable_batch_invariant_mode`` / ``disable_batch_invariant_mode``, so
    overrides and the ``torch.bmm`` patch are re-registered or torn down
    consistently. Nesting inside an already-enabled scope leaves the outer
    registration intact.
    """
    was_enabled = _batch_invariant_MODE
    if enabled:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()
    try:
        yield
    finally:
        if was_enabled:
            enable_batch_invariant_mode()
        else:
            disable_batch_invariant_mode()


AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])


def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    """Return the fixed attention tile size used in batch-invariant mode."""
    return AttentionBlockSize(block_m=16, block_n=16)


def vllm_is_batch_invariant():
    """Return True when ``VLLM_BATCH_INVARIANT`` parses as a non-zero int.

    Unset or non-integer values are treated as disabled.
    """
    env_key = "VLLM_BATCH_INVARIANT"
    val = os.getenv(env_key, "0")
    try:
        return int(val) != 0
    except ValueError:
        return False


def override_envs_for_invariance():
    """Set environment variables that batch invariance relies on.

    Forces a supported attention backend (warning when it does so, or when a
    supported-but-not-preferred backend is in use). Also sets cuBLAS and
    NCCL variables that pin algorithm / channel choices.
    """
    attn_backend = envs.VLLM_ATTENTION_BACKEND
    supported_backends = [
        "FLASH_ATTN",  # best supported backend
        "FLEX_ATTENTION",
        "FLASHINFER",
        "FLASH_ATTN_MLA",
        "TRITON_MLA",
        # Not yet supported MLA backends
        # "FLASHMLA",
        # "FLASHINFER_MLA",
    ]
    if attn_backend not in supported_backends:
        warning = (
            "Forcibly updating attention backend to"
            f" {supported_backends[0]} for batch_invariant. "
            f" Supported backends: {supported_backends}."
        )
        logger.warning_once(warning)
        attn_backend = supported_backends[0]
        os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
    # Use the tracked value rather than os.environ[...]: the backend may have
    # come from `envs` without the variable being present in the environment.
    if attn_backend != supported_backends[0]:
        warning = (
            "You are using a decode-invariant form of batch invariance. "
            "This will not be invariant between prefill and decode."
        )
        logger.warning_once(warning)
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # NCCL determinism settings
    os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
    os.environ["NCCL_COLLNET_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["NCCL_P2P_NET_DISABLE"] = "1"
    os.environ["NCCL_MIN_NCHANNELS"] = "1"
    os.environ["NCCL_MAX_NCHANNELS"] = "1"
    os.environ["NCCL_PROTO"] = "Simple"
    os.environ["NCCL_ALGO"] = "allreduce:tree"
    os.environ["NCCL_NTHREADS"] = "1"
    os.environ["NCCL_SOCKET_NTHREADS"] = "1"


def init_batch_invariance():
    """Enable batch-invariant mode if ``VLLM_BATCH_INVARIANT`` is set."""
    # this will hit all the csrc overrides as well
    if vllm_is_batch_invariant():
        override_envs_for_invariance()
        enable_batch_invariant_mode()

        # Disable TF32 for batch invariance - it causes non-deterministic rounding
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False