Unverified Commit 7d8975de authored by Bram Wasti's avatar Bram Wasti Committed by GitHub
Browse files

Deepseek-v3 Batch Invariant on 8xH100 (#26609)


Signed-off-by: default avatarBram Wasti <bwasti@meta.com>
Co-authored-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 785d8b64
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant RMS normalization against standard implementations.
This test compares the Triton-based batch-invariant RMS norm implementation
with the standard CUDA-based implementation to ensure numerical accuracy.
"""
import pytest
import torch
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
def test_rms_norm_batch_invariant_vs_standard(
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
):
"""
Compare batch-invariant Triton RMS norm against standard CUDA implementation.
Tests that the Triton-based batch-invariant RMS norm produces numerically
equivalent results to the standard CUDA implementation across various
configurations.
"""
device = torch.device("cuda")
# Create test input and weight
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation (CUDA ops)
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation (Triton)
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Compare outputs
# Use looser tolerance for bfloat16 due to its lower precision
if dtype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16
else:
rtol, atol = 1e-2, 1e-2 # 1% for float16/float32
torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for batch_size={batch_size}, "
f"hidden_size={hidden_size}, "
f"dtype={dtype}, eps={eps}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096])
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
"""
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
inputs that are common in transformer models.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
torch.manual_seed(42)
input_tensor = torch.randn(
batch_size, seq_len, hidden_size, dtype=dtype, device=device
)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, "
f"seq_len={seq_len}, hidden_size={hidden_size}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_numerical_stability():
"""
Test RMS norm numerical stability with extreme values.
Ensures that both implementations handle edge cases like very small or large
values without producing NaN or Inf.
"""
device = torch.device("cuda")
dtype = torch.float16
eps = 1e-6
hidden_size = 2048
# Test cases with extreme values
test_cases = [
# Very small values
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5,
# Very large values
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4,
# Mixed small and large
torch.randn(4, hidden_size, dtype=dtype, device=device) * 100,
# Values near zero
torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6,
]
weight = torch.ones(hidden_size, dtype=dtype, device=device)
for idx, input_tensor in enumerate(test_cases):
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Check for NaN or Inf
assert not torch.isnan(standard_output).any(), (
f"Standard RMS norm produced NaN for test case {idx}"
)
assert not torch.isinf(standard_output).any(), (
f"Standard RMS norm produced Inf for test case {idx}"
)
assert not torch.isnan(triton_output).any(), (
f"Triton RMS norm produced NaN for test case {idx}"
)
assert not torch.isinf(triton_output).any(), (
f"Triton RMS norm produced Inf for test case {idx}"
)
# Compare outputs - very lenient for extreme values with float16
torch.testing.assert_close(
triton_output,
standard_output,
rtol=2e-1, # 20% tolerance for extreme values
atol=2e-1,
msg=f"RMS norm mismatch for extreme value test case {idx}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_formula():
"""
Test that RMS norm follows the correct mathematical formula.
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
"""
device = torch.device("cuda")
dtype = torch.float32 # Use float32 for higher precision in formula check
eps = 1e-6
hidden_size = 1024
torch.manual_seed(42)
input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Compute expected output using the formula
variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype)
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Compare against formula
torch.testing.assert_close(
triton_output,
expected_output,
rtol=1e-4,
atol=1e-4,
msg="Triton RMS norm doesn't match expected formula",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int):
"""
Test RMS norm with various hidden sizes to ensure block size handling.
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
correctly handles hidden sizes both smaller and larger than the block size.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
batch_size = 16
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for hidden_size={hidden_size}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_determinism():
"""
Test that batch-invariant RMS norm produces deterministic results.
Runs the same input through the kernel multiple times and verifies
identical outputs.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
hidden_size = 4096
batch_size = 32
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Run multiple times
outputs = []
for _ in range(5):
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
outputs.append(output)
# All outputs should be identical
reference = outputs[0]
for idx, output in enumerate(outputs[1:], start=1):
torch.testing.assert_close(
output,
reference,
rtol=0.0,
atol=0.0,
msg=f"RMS norm not deterministic: run {idx} differs from reference",
)
if __name__ == "__main__":
# Run a quick smoke test
print("Running quick smoke test of RMS norm implementations...")
device = torch.device("cuda")
batch_size = 8
hidden_size = 4096
dtype = torch.bfloat16
eps = 1e-6
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Compare
max_diff = (triton_output - standard_output).abs().max().item()
mean_diff = (triton_output - standard_output).abs().mean().item()
print(f"Max difference: {max_diff:.6e}")
print(f"Mean difference: {mean_diff:.6e}")
print(f"Standard output sample: {standard_output[0, :5].tolist()}")
print(f"Triton output sample: {triton_output[0, :5].tolist()}")
if max_diff < 1e-3:
print("✓ Smoke test passed!")
else:
print("✗ Smoke test failed - differences too large")
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import hashlib import hashlib
import inspect import inspect
import os
import pickle import pickle
from unittest.mock import patch from unittest.mock import patch
...@@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str: ...@@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str:
) )
file_contents = {} file_contents = {}
for filepath in files: for filepath in files:
if filepath == "<string>": # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
if not os.path.isfile(filepath):
file_contents[filepath] = "" file_contents[filepath] = ""
else: else:
with open(filepath) as f: with open(filepath) as f:
......
...@@ -20,6 +20,9 @@ from vllm.config.pooler import PoolerConfig ...@@ -20,6 +20,9 @@ from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config, getattr_iter from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
ConfigFormat, ConfigFormat,
...@@ -419,6 +422,10 @@ class ModelConfig: ...@@ -419,6 +422,10 @@ class ModelConfig:
skip_mm_profiling: bool | None, skip_mm_profiling: bool | None,
video_pruning_rate: float | None, video_pruning_rate: float | None,
) -> None: ) -> None:
# Enable batch invariance settings if requested
if vllm_kernel_override_batch_invariant():
self.enforce_eager = True
# Set the default seed to 0 in V1. # Set the default seed to 0 in V1.
# NOTE(woosuk): In V0, we set the default seed to None because the # NOTE(woosuk): In V0, we set the default seed to None because the
# driver worker shares the same process as the user process, and thus # driver worker shares the same process as the user process, and thus
......
...@@ -14,6 +14,9 @@ from typing_extensions import Self ...@@ -14,6 +14,9 @@ from typing_extensions import Self
import vllm.envs as envs import vllm.envs as envs
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_ports_list from vllm.utils import cuda_device_count_stateless, get_open_ports_list
...@@ -560,7 +563,10 @@ class ParallelConfig: ...@@ -560,7 +563,10 @@ class ParallelConfig:
def _verify_args(self) -> Self: def _verify_args(self) -> Self:
# Lazy import to avoid circular import # Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
# Enable batch invariance settings if requested
if vllm_kernel_override_batch_invariant():
self.disable_custom_all_reduce = True
if ( if (
self.distributed_executor_backend is not None self.distributed_executor_backend is not None
......
...@@ -19,6 +19,9 @@ import torch.multiprocessing as mp ...@@ -19,6 +19,9 @@ import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.utils import cuda_device_count_stateless, update_environment_variables from vllm.utils import cuda_device_count_stateless, update_environment_variables
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -71,6 +74,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) ...@@ -71,6 +74,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
is_symmetric_memory_enabled, is_symmetric_memory_enabled,
) )
if vllm_kernel_override_batch_invariant():
return False
if not is_symmetric_memory_enabled(): if not is_symmetric_memory_enabled():
return False return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
......
...@@ -9,6 +9,9 @@ from vllm.distributed.device_communicators.all_reduce_utils import ( ...@@ -9,6 +9,9 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES, SYMM_MEM_ALL_REDUCE_MAX_SIZES,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
try: try:
...@@ -100,6 +103,8 @@ class SymmMemCommunicator: ...@@ -100,6 +103,8 @@ class SymmMemCommunicator:
return return
self.force_multimem = force_multimem self.force_multimem = force_multimem
self.disabled = False self.disabled = False
if vllm_kernel_override_batch_invariant():
self.disabled = True
def should_use_symm_mem(self, inp: torch.Tensor): def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled: if self.disabled:
......
...@@ -1694,7 +1694,7 @@ class EngineArgs: ...@@ -1694,7 +1694,7 @@ class EngineArgs:
) -> None: ) -> None:
"""Set Default Arguments for V1 Engine.""" """Set Default Arguments for V1 Engine."""
# V1 always uses chunked prefills and prefix caching # V1 uses chunked prefills and prefix caching by default
# for non-pooling tasks. # for non-pooling tasks.
# For pooling tasks the default is False # For pooling tasks the default is False
if model_config.runner_type != "pooling": if model_config.runner_type != "pooling":
......
...@@ -395,7 +395,6 @@ def mean_dim( ...@@ -395,7 +395,6 @@ def mean_dim(
Tensor with mean values along specified dimension Tensor with mean values along specified dimension
""" """
# Validate inputs # Validate inputs
assert input.is_cuda, "Input must be a CUDA tensor"
assert -input.ndim <= dim < input.ndim, ( assert -input.ndim <= dim < input.ndim, (
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
) )
...@@ -470,6 +469,45 @@ def mm_batch_invariant(a, b): ...@@ -470,6 +469,45 @@ def mm_batch_invariant(a, b):
return matmul_persistent(a, b) return matmul_persistent(a, b)
def matmul_batch_invariant(a, b, *, out=None):
# torch.matmul can handle various dimensions
# For 2D x 2D, it's the same as mm
if a.ndim == 2 and b.ndim == 2:
result = matmul_persistent(a, b)
if out is not None:
out.copy_(result)
return out
return result
elif a.ndim == 3 and b.ndim == 3:
# Handle batched case like bmm
return bmm_batch_invariant(a, b, out=out)
else:
raise ValueError(
f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
f"got shapes {a.shape} and {b.shape}"
)
def bmm_batch_invariant(a, b, *, out=None):
# Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
# Process each batch separately with our persistent kernel
if a.ndim == 3 and b.ndim == 3:
results = []
for i in range(a.shape[0]):
results.append(matmul_persistent(a[i], b[i]))
result = torch.stack(results, dim=0)
if out is not None:
out.copy_(result)
return out
return result
else:
raise ValueError(
f"bmm_batch_invariant expects 3D tensors, "
f"got shapes {a.shape} and {b.shape}"
)
def addmm_batch_invariant(bias, a, b): def addmm_batch_invariant(bias, a, b):
return matmul_persistent(a, b, bias=bias) return matmul_persistent(a, b, bias=bias)
...@@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float): ...@@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
return log_softmax(input, dim=dim) return log_softmax(input, dim=dim)
def softmax_batch_invariant(input, dim, dtype=None):
# Compute softmax in a deterministic way
# First subtract max for numerical stability (standard practice)
input_max = torch.amax(input, dim=dim, keepdim=True)
input = input - input_max
exp_x = torch.exp(input)
sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
return exp_x / sum_exp_x
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None): def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
result = input.to(torch.float32) result = input.to(torch.float32)
if len(dim) == 0:
dim = [i for i in range(len(input.shape))]
# Sort dimensions to reduce from largest to smallest to handle shifting dims # Sort dimensions to reduce from largest to smallest to handle shifting dims
# during iterative reduction. # during iterative reduction.
sorted_dims = sorted([d % input.ndim for d in dim], reverse=True) sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
...@@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = ...@@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
return result return result
@triton.jit
def _rms_norm_kernel(
input_ptr,
weight_ptr,
output_ptr,
input_row_stride,
output_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Compute RMS normalization along the last dimension of a 2D tensor.
RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
Each block handles one row of the input tensor.
"""
row_idx = tl.program_id(0).to(tl.int64)
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# Step 1: Compute sum of squares in float32 to avoid overflow
sum_sq = tl.zeros([1], dtype=tl.float32)
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
# Convert to float32 for accumulation to prevent overflow
vals_f32 = vals.to(tl.float32)
sq_vals = vals_f32 * vals_f32
sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
# Step 2: Compute RMS (root mean square) in float32
mean_sq = sum_sq / n_cols
rms = tl.sqrt(mean_sq + eps)
inv_rms = 1.0 / rms
# Step 3: Normalize and apply weight
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
# Compute in float32 then convert back to input dtype
vals_f32 = vals.to(tl.float32)
weight_f32 = weight.to(tl.float32)
output_f32 = vals_f32 * inv_rms * weight_f32
output = output_f32.to(vals.dtype)
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def rms_norm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Compute RMS normalization using Triton kernel.
RMS Norm normalizes the input by the root mean square and scales by weight:
output = input / sqrt(mean(input^2) + eps) * weight
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
Tensor with RMS normalization applied along the last dimension
"""
assert weight.dim() == 1, "Weight must be 1-dimensional"
assert input.shape[-1] == weight.shape[0], (
f"Input last dimension ({input.shape[-1]}) must match "
f"weight dimension ({weight.shape[0]})"
)
# Flatten all dimensions except the last one
original_shape = input.shape
input_2d = input.reshape(-1, input.shape[-1])
input_2d = input_2d.contiguous()
weight = weight.contiguous()
n_rows, n_cols = input_2d.shape
output = torch.empty_like(input_2d)
BLOCK_SIZE = 1024
grid = (n_rows,)
_rms_norm_kernel[grid](
input_2d,
weight,
output,
input_2d.stride(0),
output.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return output.reshape(original_shape)
def rms_norm_batch_invariant(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Batch-invariant wrapper for RMS normalization.
This function provides a deterministic, batch-invariant implementation
of RMS normalization for use with the batch_invariant mode.
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
RMS normalized tensor
"""
return rms_norm(input, weight, eps=eps)
def linear_batch_invariant(input, weight, bias=None):
output = mm_batch_invariant(input, weight.t())
if bias is not None:
output = output + bias
return output
_batch_invariant_MODE = False _batch_invariant_MODE = False
_batch_invariant_LIB = None _batch_invariant_LIB = None
_original_torch_bmm = None
def is_batch_invariant_mode_enabled(): def is_batch_invariant_mode_enabled():
...@@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled(): ...@@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled():
def enable_batch_invariant_mode(): def enable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
if _batch_invariant_MODE: if _batch_invariant_MODE:
return return
...@@ -517,16 +694,28 @@ def enable_batch_invariant_mode(): ...@@ -517,16 +694,28 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB = torch.library.Library("aten", "IMPL") _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
_batch_invariant_LIB.impl( _batch_invariant_LIB.impl(
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA" "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
) )
_batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
# Also monkeypatch torch.bmm directly as a fallback
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant
def disable_batch_invariant_mode(): def disable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
if _batch_invariant_LIB is not None: if _batch_invariant_LIB is not None:
_batch_invariant_LIB._destroy() _batch_invariant_LIB._destroy()
if _original_torch_bmm is not None:
torch.bmm = _original_torch_bmm
_original_torch_bmm = None
_batch_invariant_MODE = False _batch_invariant_MODE = False
_batch_invariant_LIB = None _batch_invariant_LIB = None
...@@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant(): ...@@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant():
return is_overridden return is_overridden
def override_envs_for_invariance():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = [
"FLASH_ATTN", # best supported backend
"FLEX_ATTENTION",
"FLASHINFER",
"FLASH_ATTN_MLA",
"TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLASHINFER_MLA",
]
if curr_attn_backend not in supported_backends:
warning = (
"Forcibly updating attention backend to"
f" {supported_backends[0]} for batch_invariant. "
f" Supported backends: {supported_backends}."
)
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
logger.warning_once(warning)
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# NCCL determinism settings
os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
os.environ["NCCL_COLLNET_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
os.environ["NCCL_P2P_NET_DISABLE"] = "1"
os.environ["NCCL_MIN_NCHANNELS"] = "1"
os.environ["NCCL_MAX_NCHANNELS"] = "1"
os.environ["NCCL_PROTO"] = "Simple"
os.environ["NCCL_ALGO"] = "allreduce:tree"
os.environ["NCCL_NTHREADS"] = "1"
os.environ["NCCL_SOCKET_NTHREADS"] = "1"
def init_batch_invariance(): def init_batch_invariance():
# this will hit all the csrc overrides as well # this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant(): if vllm_kernel_override_batch_invariant():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND override_envs_for_invariance()
supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
if curr_attn_backend not in supported_backends:
warning = (
"Forcibly updating attention backend to"
f" {supported_backends[0]} for batch_invariant. "
f" Supported backends: {supported_backends}."
)
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
enable_batch_invariant_mode() enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
...@@ -15,6 +15,9 @@ import vllm.envs as envs ...@@ -15,6 +15,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG, FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -837,6 +840,10 @@ def get_moe_configs( ...@@ -837,6 +840,10 @@ def get_moe_configs(
be picked and the associated configuration chosen to invoke the kernel. be picked and the associated configuration chosen to invoke the kernel.
""" """
# Avoid optimizing for the batch invariant case. Use default config
if vllm_kernel_override_batch_invariant():
return None
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
# directory # directory
block_shape = [block_n, block_k] if block_n and block_k else None block_shape = [block_n, block_k] if block_n and block_k else None
...@@ -969,6 +976,15 @@ def get_default_config( ...@@ -969,6 +976,15 @@ def get_default_config(
dtype: str | None, dtype: str | None,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
) -> dict[str, int]: ) -> dict[str, int]:
if vllm_kernel_override_batch_invariant():
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
return config
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1] # BLOCK_SIZE_K must be divisible by block_shape[1]
...@@ -1118,7 +1134,10 @@ def fused_topk_bias( ...@@ -1118,7 +1134,10 @@ def fused_topk_bias(
scores_for_choice = scores.view( scores_for_choice = scores.view(
-1, n_routed_experts -1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0) ) + e_score_correction_bias.unsqueeze(0)
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices) topk_weights = scores.gather(1, topk_indices)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
...@@ -1179,7 +1198,10 @@ def grouped_topk( ...@@ -1179,7 +1198,10 @@ def grouped_topk(
group_scores = ( group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group] ) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1 1
] # [n, top_k_group] ] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask = torch.zeros_like(group_scores) # [n, n_group]
...@@ -1192,11 +1214,13 @@ def grouped_topk( ...@@ -1192,11 +1214,13 @@ def grouped_topk(
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
# Use original unbiased scores for the routing weights # Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids) topk_weights = original_scores.gather(1, topk_ids)
else: else:
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=use_sorted
)
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
......
...@@ -8,6 +8,10 @@ import torch.nn.functional as F ...@@ -8,6 +8,10 @@ import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -21,6 +25,8 @@ def rms_norm( ...@@ -21,6 +25,8 @@ def rms_norm(
) -> torch.Tensor: ) -> torch.Tensor:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x) out = torch.empty_like(x)
ops.rms_norm( ops.rms_norm(
out, out,
...@@ -39,6 +45,10 @@ def fused_add_rms_norm( ...@@ -39,6 +45,10 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
ops.fused_add_rms_norm( ops.fused_add_rms_norm(
x, x,
residual, residual,
......
...@@ -160,6 +160,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp): ...@@ -160,6 +160,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
k_pe, k_pe,
output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
) )
return self.o_proj(attn_out)[0] return self.o_proj(attn_out)[0]
def forward_cuda(self, *args, **kwargs): def forward_cuda(self, *args, **kwargs):
......
...@@ -14,6 +14,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -14,6 +14,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
FusedMoEActivationFormat, FusedMoEActivationFormat,
...@@ -353,6 +356,8 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -353,6 +356,8 @@ class Fp8LinearMethod(LinearMethodBase):
# Disable marlin for rocm # Disable marlin for rocm
if current_platform.is_rocm(): if current_platform.is_rocm():
self.use_marlin = False self.use_marlin = False
if vllm_kernel_override_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
...@@ -534,6 +539,66 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -534,6 +539,66 @@ class Fp8LinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
# If batch invariant mode is enabled, dequantize and use BF16 compute
if vllm_kernel_override_batch_invariant():
# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
# Handle different quantization granularities
if self.block_quant:
# Block-wise quantization:
# - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
# - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
assert self.weight_block_size is not None
block_n, block_k = self.weight_block_size # Note: order is [N, K]
N, K = weight_fp8.shape
# Scale is stored transposed: [num_blocks_k, num_blocks_n]
# We need to transpose it to [num_blocks_n, num_blocks_k] first
weight_scale = weight_scale.t()
# Expand scale to match weight dimensions
# scale_expanded should have shape [N, K]
scale_expanded = weight_scale.repeat_interleave(
block_n, dim=0
).repeat_interleave(block_k, dim=1)
# Trim to exact weight size (in case of padding)
scale_expanded = scale_expanded[:N, :K]
weight_bf16 = weight_fp8 * scale_expanded
else:
# Per-tensor quantization: weight IS transposed to [K, N]
# scale should be scalar or [1] or per-output-channel [N]
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
else:
# Multiple scales (fused modules like QKV)
# Try to infer correct broadcasting
# weight is [K, N], scale could be [num_logical_weights]
# Need to figure out how to broadcast - for now just try
# direct multiplication
if (
weight_scale.dim() == 1
and weight_scale.shape[0] == weight_fp8.shape[0]
):
# Per-row scaling
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
# For block quant, weight is [N, K], for per-tensor it's [K, N]
# F.linear expects weight to be [N, K], so:
if self.block_quant:
# Already in correct shape [N, K]
output = torch.nn.functional.linear(x, weight_bf16, bias)
else:
# Need to transpose back: [K, N] -> [N, K]
output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
return output
if self.use_marlin: if self.use_marlin:
return apply_fp8_marlin_linear( return apply_fp8_marlin_linear(
input=x, input=x,
......
...@@ -216,6 +216,7 @@ class TransformerBlock(torch.nn.Module): ...@@ -216,6 +216,7 @@ class TransformerBlock(torch.nn.Module):
else: else:
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.attn(hidden_states, positions) hidden_states = self.attn(hidden_states, positions)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
output = self.mlp(hidden_states) output = self.mlp(hidden_states)
......
...@@ -31,10 +31,12 @@ if is_flash_attn_varlen_func_available(): ...@@ -31,10 +31,12 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata, get_scheduler_metadata,
reshape_and_cache_flash, reshape_and_cache_flash,
) )
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_dcp_group from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionCGSupport,
...@@ -306,6 +308,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -306,6 +308,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# we only set num_splits when using cuda graphs. # we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits max_num_splits = self.max_num_splits
if vllm_kernel_override_batch_invariant():
max_num_splits = 1
def schedule( def schedule(
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
): ):
...@@ -478,6 +483,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -478,6 +483,9 @@ class FlashAttentionImpl(AttentionImpl):
self.attn_type = attn_type self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
raise NotImplementedError( raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device." "FlashAttention does not support fp8 kv-cache on this device."
...@@ -810,6 +818,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -810,6 +818,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=layer._q_scale.expand(descale_shape), q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
num_splits=1 if self.batch_invariant_enabled else 0,
) )
return output return output
...@@ -954,6 +963,7 @@ def cascade_attention( ...@@ -954,6 +963,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel, # s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge. # enabling its effect during the final attention merge.
s_aux=s_aux, s_aux=s_aux,
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
) )
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
...@@ -978,6 +988,7 @@ def cascade_attention( ...@@ -978,6 +988,7 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
) )
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
......
...@@ -211,6 +211,9 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version ...@@ -211,6 +211,9 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
LinearBase, LinearBase,
...@@ -1187,6 +1190,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): ...@@ -1187,6 +1190,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if is_rocm_aiter_fp8bmm_enabled(): if is_rocm_aiter_fp8bmm_enabled():
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = aiter_triton_fp8_bmm( x = aiter_triton_fp8_bmm(
...@@ -1279,6 +1283,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1279,6 +1283,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# ROCm leverages the upstream flash_attn, which takes a parameter # ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse # called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse
if vllm_kernel_override_batch_invariant():
kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func( attn_out = self.flash_attn_varlen_func(
q=q, q=q,
...@@ -1841,9 +1847,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1841,9 +1847,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split( decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
) )
# Convert from (B, N, P) to (N, B, P) # Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1) decode_q_nope = decode_q_nope.transpose(0, 1)
...@@ -1868,17 +1876,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1868,17 +1876,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Pads the head_dim if necessary (for the underlying kernel) # Pads the head_dim if necessary (for the underlying kernel)
N, B, P = decode_q_nope.shape N, B, P = decode_q_nope.shape
_, _, L = self.W_UK_T.shape _, _, L = self.W_UK_T.shape
if self.q_pad_num_heads is not None: if self.q_pad_num_heads is not None:
decode_ql_nope = decode_q_nope.new_empty( decode_ql_nope = decode_q_nope.new_empty(
(self.q_pad_num_heads, B, L) (self.q_pad_num_heads, B, L)
) )
decode_ql_nope.resize_((N, B, L)) decode_ql_nope.resize_((N, B, L))
else: else:
decode_ql_nope = decode_q_nope.new_empty((N, B, L)) decode_ql_nope = decode_q_nope.new_empty((N, B, L))
# Multiply (N, B, P) x (N, P, L) -> (N, B, L) # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
# Convert from (N, B, L) to (B, N, L) # Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
......
...@@ -18,6 +18,9 @@ from vllm.attention.utils.fa_utils import ( ...@@ -18,6 +18,9 @@ from vllm.attention.utils.fa_utils import (
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -107,6 +110,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -107,6 +110,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# pre-allocated during capture. # pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
if vllm_kernel_override_batch_invariant():
self.max_num_splits = 1
def _schedule_decode( def _schedule_decode(
self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
): ):
...@@ -175,7 +181,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -175,7 +181,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs. # we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits max_num_splits = self.max_num_splits
return FlashAttnMLADecodeMetadata( if vllm_kernel_override_batch_invariant():
max_num_splits = 1
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor, block_table=block_table_tensor,
seq_lens=seq_lens_device, seq_lens=seq_lens_device,
query_start_loc=query_start_loc_device, query_start_loc=query_start_loc_device,
...@@ -185,6 +194,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -185,6 +194,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_num_splits=max_num_splits, max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device, dcp_tot_seq_lens=dcp_tot_seq_lens_device,
) )
return metadata
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
......
...@@ -14,6 +14,9 @@ from vllm.attention.ops.flashmla import ( ...@@ -14,6 +14,9 @@ from vllm.attention.ops.flashmla import (
) )
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonBackend,
MLACommonDecodeMetadata, MLACommonDecodeMetadata,
...@@ -223,19 +226,50 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -223,19 +226,50 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if type(q) is tuple: if type(q) is tuple:
q = torch.cat(q, dim=-1) q = torch.cat(q, dim=-1)
# mypy assertion: q is now always a tensor
assert isinstance(q, torch.Tensor) assert isinstance(q, torch.Tensor)
num_decodes = attn_metadata.num_decodes num_decodes = attn_metadata.num_decodes
q = reshape_query_for_spec_decode(q, num_decodes) q = reshape_query_for_spec_decode(q, num_decodes)
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
num_splits = attn_metadata.decode.num_splits
if vllm_kernel_override_batch_invariant():
device = q.device
dtype = torch.int32
B = q.shape[0]
# block_table shape: [batch_size, max_num_blocks_per_seq]
# The number of blocks per sequence is in the second dimension
topk = attn_metadata.decode.block_table.shape[-1]
B_TOPK = 64
assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
end_block_idx = topk // B_TOPK
# Single partition => num_sm_parts = 1
# TileSchedulerMetaDataSize = 8, layout:
# [begin_idx, begin_block_idx, end_idx, end_block_idx,
# begin_n_split_idx, _, _, _]
tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
tile_scheduler_metadata[0, 0] = 0 # begin_idx
tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx
tile_scheduler_metadata[0, 2] = B - 1 # end_idx
tile_scheduler_metadata[0, 3] = end_block_idx
tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx
# fields [5..7] stay 0
# Non-split path ignores num_splits, but the API requires it:
# zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
o, lse = flash_mla_with_kvcache( o, lse = flash_mla_with_kvcache(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata, tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits, num_splits=num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
descale_q=layer._q_scale.reshape(1), descale_q=layer._q_scale.reshape(1),
......
...@@ -13,6 +13,9 @@ from vllm.attention.backends.abstract import ( ...@@ -13,6 +13,9 @@ from vllm.attention.backends.abstract import (
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import ( from vllm.v1.attention.backends.mla.common import (
...@@ -158,7 +161,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -158,7 +161,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
) )
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
num_kv_splits = 4 # TODO: heuristic
# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
# TODO(lucas) Allocate ahead of time # TODO(lucas) Allocate ahead of time
attn_logits = torch.empty( attn_logits = torch.empty(
......
...@@ -231,9 +231,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -231,9 +231,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance()
model_config = self.model_config model_config = self.model_config
cache_config = self.cache_config cache_config = self.cache_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment