"vscode:/vscode.git/clone" did not exist on "7b5ecf79bd94aab0d782c70126d0dcc37c16bc60"
Unverified Commit ecd1ea13 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Kernel] Porting the TRTLLM minimax_allreduce_rms kernels (#37045)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 8f121f78
...@@ -20,7 +20,20 @@ steps: ...@@ -20,7 +20,20 @@ steps:
- tests/kernels/core - tests/kernels/core
- tests/kernels/test_concat_mla_q.py - tests/kernels/test_concat_mla_q.py
commands: commands:
- pytest -v -s kernels/core kernels/test_concat_mla_q.py - pytest -v -s kernels/core --ignore=kernels/core/test_minimax_reduce_rms.py kernels/test_concat_mla_q.py
- label: Kernels MiniMax Reduce RMS Test (2 GPUs)
timeout_in_minutes: 15
num_devices: 2
device: h100
source_file_dependencies:
- csrc/minimax_reduce_rms_kernel.cu
- csrc/minimax_reduce_rms_kernel.h
- vllm/model_executor/layers/mamba/linear_attn.py
- vllm/model_executor/layers/mamba/lamport_workspace.py
- tests/kernels/core/test_minimax_reduce_rms.py
commands:
- pytest -v -s kernels/core/test_minimax_reduce_rms.py
- label: Kernels Attention Test %N - label: Kernels Attention Test %N
timeout_in_minutes: 35 timeout_in_minutes: 35
......
...@@ -307,6 +307,8 @@ set(VLLM_EXT_SRC ...@@ -307,6 +307,8 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp") "csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC "csrc/minimax_reduce_rms_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
......
This diff is collapsed.
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/types.h>
namespace vllm {
namespace tensorrt_llm {
template <typename DType>
struct ElemsPerAccess;
template <>
struct ElemsPerAccess<half> {
static constexpr int value = 8;
using vec_type = float4;
};
template <>
struct ElemsPerAccess<nv_bfloat16> {
static constexpr int value = 8;
using vec_type = float4;
};
template <>
struct ElemsPerAccess<float> {
static constexpr int value = 4;
using vec_type = float4;
};
template <typename DType>
static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;
// Parameter bundle handed to minimax_reduce_rms_op(). Every member is
// value-initialized, so a default-constructed instance holds zeros / null
// pointers and dtype == Undefined until the caller fills it in.
// NOTE(review): meanings not spelled out by the original stride comments are
// inferred from field names only — confirm against the kernel source.
struct MiniMaxReduceRMSParams {
  int nranks{};  // presumably the number of participating ranks — TODO confirm
  int rank{};    // presumably this process's rank index — TODO confirm
  at::ScalarType dtype{at::ScalarType::Undefined};  // element type; Undefined
                                                    // until set by the caller
  int size_q{};        // NOTE(review): units (rows vs. elements) not visible here
  int hidden_dim{};    // row width of q (elements), per the stride_q comment
  int size_k{};        // NOTE(review): units (rows vs. elements) not visible here
  int hidden_dim_k{};  // row width of k (elements), per the stride_k comment
  int stride_q{};  // row stride for q input (elements); when > hidden_dim,
                   // q is part of a wider qkv tensor
  int stride_k{};  // row stride for k input (elements); when > hidden_dim_k,
                   // k is part of a wider qkv tensor
  int stride_q_out{};  // row stride for q output (elements); 0 = contiguous
  int stride_k_out{};  // row stride for k output (elements); 0 = contiguous
  void** workspace{};  // pointer table; presumably the device-side array built
                       // by the Python Lamport workspace — TODO confirm
  void* allreduce_in{};    // presumably the q input buffer — TODO confirm
  void* rms_norm_out{};    // presumably the q output buffer — TODO confirm
  void* rms_gamma{};       // presumably the q RMS-norm weight — TODO confirm
  void* allreduce_in_k{};  // k counterpart of allreduce_in
  void* rms_norm_out_k{};  // k counterpart of rms_norm_out
  void* rms_gamma_k{};     // k counterpart of rms_gamma
  float rms_eps{};         // presumably the epsilon added to the variance
  cudaStream_t stream{};   // stream handle; a value-initialized handle is the
                           // default stream
};
// Host-side entry point that consumes a populated MiniMaxReduceRMSParams.
// Only the declaration is visible in this header; the definition presumably
// lives in minimax_reduce_rms_kernel.cu — TODO confirm.
void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params);
} // namespace tensorrt_llm
} // namespace vllm
...@@ -309,3 +309,15 @@ int64_t qr_max_size(); ...@@ -309,3 +309,15 @@ int64_t qr_max_size();
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
torch::Tensor const& mat_b); torch::Tensor const& mat_b);
#endif #endif
// The MiniMax fused allreduce + RMS-norm entry points are only declared for
// non-ROCm builds.
#ifndef USE_ROCM
// Single-tensor variant: takes the input, its norm weight, a workspace tensor,
// this rank / the rank count and an epsilon, and returns one tensor.
// NOTE(review): the returned tensor presumably has the same shape and dtype as
// `input` — confirm against the implementation.
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
                                    torch::Tensor const& norm_weight,
                                    torch::Tensor workspace, int64_t const rank,
                                    int64_t const nranks, double const eps);
// Fused Q/K variant: takes the packed qkv tensor plus separate q and k norm
// weights and returns a (q, k) pair. `q_size` / `kv_size` are presumably the
// per-rank widths of the q and k slices inside `qkv` — TODO confirm.
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
    torch::Tensor qkv, torch::Tensor const& norm_weight_q,
    torch::Tensor const& norm_weight_k, torch::Tensor workspace,
    int64_t const q_size, int64_t const kv_size, int64_t const rank,
    int64_t const nranks, double const eps);
#endif
\ No newline at end of file
...@@ -496,6 +496,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -496,6 +496,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? b_qzeros, " "Tensor? b_qzeros, "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt " "SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"); "CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
ops.def(
"minimax_allreduce_rms("
"Tensor input,"
"Tensor norm_weight,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps) -> Tensor");
ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv,"
"Tensor norm_weight_q,"
"Tensor norm_weight_k,"
"Tensor workspace,"
"int q_size,"
"int kv_size,"
"int rank,"
"int nranks,"
"float eps) -> (Tensor, Tensor)");
ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
// conditionally compiled so impl in source file // conditionally compiled so impl in source file
#endif #endif
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for MiniMax QK RMS-norm: NCCL reference vs Lamport fused kernel."""
import pytest
import torch
import torch.nn as nn
from torch.multiprocessing import spawn
from tests.kernels.utils import opcheck
from tests.utils import ensure_current_vllm_config, init_test_distributed_environment
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import set_random_seed
@ensure_current_vllm_config()
def _worker_forward_qk(
    local_rank,
    world_size,
    port,
    num_tokens,
    hidden_q_full,
    hidden_k_full,
    dtype,
    seed,
    eps,
):
    """Per-rank worker: compare NCCL allreduce path vs Lamport fused kernel.

    Args:
        local_rank: Index of this spawned process; also selects the CUDA device.
        world_size: Number of spawned ranks.
        port: Port passed to the distributed test environment setup.
        num_tokens: Number of rows in the random ``qkv`` input.
        hidden_q_full: Full (unsharded) q width; each rank gets
            ``hidden_q_full // world_size`` columns.
        hidden_k_full: Full (unsharded) k width, sharded the same way.
        dtype: dtype of the random weights and inputs.
        seed: Base seed; weights use ``seed`` on every rank, inputs use
            ``seed + 1000 + local_rank``.
        eps: RMS-norm epsilon.

    Returns early (after cleanup) when the fused op is not compiled in.
    """
    if not hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):
        cleanup_dist_env_and_memory()
        return

    device = torch.device(f"cuda:{local_rank}")
    torch.accelerator.set_device_index(device)
    init_test_distributed_environment(
        world_size, 1, local_rank, port, local_rank=local_rank
    )

    # Per-rank shard widths.
    hq = hidden_q_full // world_size
    hk = hidden_k_full // world_size

    q_norm = MiniMaxText01RMSNormTP(hidden_q_full, eps=eps).cuda()
    k_norm = MiniMaxText01RMSNormTP(hidden_k_full, eps=eps).cuda()

    # Same seed on every rank so the full weights agree; each rank then keeps
    # only its own slice.
    set_random_seed(seed)
    qw = torch.randn(hidden_q_full, dtype=dtype, device="cuda")
    kw = torch.randn(hidden_k_full, dtype=dtype, device="cuda")
    q_norm.weight = nn.Parameter(qw[local_rank * hq : (local_rank + 1) * hq])
    k_norm.weight = nn.Parameter(kw[local_rank * hk : (local_rank + 1) * hk])

    # Rank-dependent seed so every rank contributes different activations.
    torch.manual_seed(seed + 1000 + local_rank)
    qkv = torch.randn(num_tokens, hq + hk + hk, dtype=dtype, device="cuda")

    # Reference path works on a copy; the v slice is not needed here.
    q_ref, k_ref, _ = qkv.clone().split([hq, hk, hk], dim=-1)
    ref_q, ref_k = MiniMaxText01RMSNormTP.forward_qk(q_norm, k_norm, q_ref, k_ref)

    # Set up Lamport workspace.
    from vllm.distributed.parallel_state import get_tp_group
    from vllm.model_executor.layers.mamba.lamport_workspace import (
        get_allreduce_workspace,
    )

    workspace = get_allreduce_workspace(
        rank=local_rank,
        world_size=world_size,
        max_tokens=num_tokens,
        process_group=get_tp_group().cpu_group,
    )

    opcheck(
        torch.ops._C.minimax_allreduce_rms_qk,
        (
            qkv.clone(),
            q_norm.weight,
            k_norm.weight,
            workspace,
            hq,
            hk,
            local_rank,
            world_size,
            eps,
        ),
    )

    fused_q, fused_k = torch.ops._C.minimax_allreduce_rms_qk(
        qkv.clone(),
        q_norm.weight,
        k_norm.weight,
        workspace,
        hq,
        hk,
        local_rank,
        world_size,
        eps,
    )
    torch.accelerator.synchronize()

    torch.testing.assert_close(
        fused_q,
        ref_q,
        atol=3e-2,
        rtol=3e-2,
    )
    torch.testing.assert_close(fused_k, ref_k, atol=3e-2, rtol=3e-2)
    cleanup_dist_env_and_memory()
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA required")
@pytest.mark.parametrize("world_size", [2, 4, 8])
@pytest.mark.parametrize("num_tokens", [1, 128, 333])
@pytest.mark.parametrize("hidden_dims", [(6144, 1024)])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("eps", [1e-6])
@pytest.mark.parametrize("seed", [42])
def test_minimax_reduce_rms_qk(world_size, num_tokens, hidden_dims, dtype, eps, seed):
    """Spawn one ``_worker_forward_qk`` process per rank and wait for them.

    The case is skipped when fewer than ``world_size`` devices are reported by
    the current platform.
    """
    available = current_platform.device_count()
    if available < world_size:
        pytest.skip(f"Need >= {world_size} GPUs, have {available}")

    q_full, k_full = hidden_dims
    rendezvous_port = str(get_open_port())
    # spawn() prepends the process index, which becomes the worker's local_rank.
    worker_args = (
        world_size,
        rendezvous_port,
        num_tokens,
        q_full,
        k_full,
        dtype,
        seed,
        eps,
    )
    spawn(_worker_forward_qk, args=worker_args, nprocs=world_size, join=True)
...@@ -3491,3 +3491,38 @@ if hasattr(torch.ops._C, "hadacore_transform"): ...@@ -3491,3 +3491,38 @@ if hasattr(torch.ops._C, "hadacore_transform"):
@register_fake("_C::hadacore_transform") @register_fake("_C::hadacore_transform")
def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor: def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor:
return torch.empty_like(x) if not inplace else x return torch.empty_like(x) if not inplace else x
if hasattr(torch.ops._C, "minimax_allreduce_rms"):

    @register_fake("_C::minimax_allreduce_rms")
    def _minimax_allreduce_rms_fake(
        input: torch.Tensor,
        norm_weight: torch.Tensor,
        workspace: torch.Tensor,
        rank: int,
        nranks: int,
        eps: float,
    ) -> torch.Tensor:
        """Fake implementation: an uninitialized tensor shaped like ``input``."""
        placeholder = torch.empty_like(input)
        return placeholder
if hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):

    @register_fake("_C::minimax_allreduce_rms_qk")
    def _minimax_allreduce_rms_qk_fake(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        workspace: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fake implementation: uninitialized ``[rows, q_size]`` and
        ``[rows, kv_size]`` tensors with the dtype and device of ``qkv``."""
        rows = qkv.shape[0]

        def _blank(width: int) -> torch.Tensor:
            return torch.empty([rows, width], dtype=qkv.dtype, device=qkv.device)

        q_like = _blank(q_size)
        k_like = _blank(kv_size)
        return (q_like, k_like)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fusion pass: replace MiniMax QK allreduce + RMS norm with the Lamport
fused kernel (minimax_allreduce_rms_qk) for decode-size batches.
Pattern (inlined forward_qk in compiled graph):
q, k, v = qkv.split([q_size, kv_size, kv_size], -1)
q_fp32 = q.to(float32); k_fp32 = k.to(float32)
q_var = q_fp32.pow(2).mean(-1, keepdim=True)
k_var = k_fp32.pow(2).mean(-1, keepdim=True)
qk_var = cat([q_var, k_var], -1)
qk_var = allreduce(qk_var) / tp_world
q_var, k_var = qk_var.chunk(2, -1)
q_out = (q_fp32 * rsqrt(q_var + eps) * q_weight).to(orig_dtype)
k_out = (k_fp32 * rsqrt(k_var + eps) * k_weight).to(orig_dtype)
return q_out, k_out, v
Replacement (pure, no in-place on qkv/q/k):
q_out, k_out = minimax_qk_norm_fused(qkv, q_weight, k_weight, workspace, ...)
v = qkv.split([q_size, kv_size, kv_size], -1)[2]
return q_out, k_out, v
is_applicable_for_range: only fires for compile_range.end <= max_decode_tokens
so that large prefill batches fall through to the original forward_qk (= main).
"""
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
# Module-level logger.
logger = init_logger(__name__)
# Upper bound on the token count used by MiniMaxQKNormPass; the pass takes the
# minimum of this value and the scheduler's max_num_batched_tokens.
MAX_TOKEN_NUM = 2048
# Handle to the registered fused custom op. Stays None unless the
# minimax_allreduce_rms_qk kernel is present in torch.ops._C.
_MINIMAX_QK_NORM_FUSED_OP = None
if hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):

    def _minimax_qk_norm_fused(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
        max_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Look up the cached Lamport workspace and run the fused QK kernel."""
        from vllm.distributed.parallel_state import get_tp_group
        from vllm.model_executor.layers.mamba.lamport_workspace import (
            get_allreduce_workspace,
        )

        tp_cpu_group = get_tp_group().cpu_group
        lamport_ws = get_allreduce_workspace(
            rank=rank,
            world_size=nranks,
            max_tokens=max_tokens,
            process_group=tp_cpu_group,
        )
        fused_outputs = torch.ops._C.minimax_allreduce_rms_qk(
            qkv,
            norm_weight_q,
            norm_weight_k,
            lamport_ws,
            q_size,
            kv_size,
            rank,
            nranks,
            eps,
        )
        return fused_outputs

    def _minimax_qk_norm_fused_fake(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
        max_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fake implementation: uninitialized ``[rows, q_size]`` and
        ``[rows, kv_size]`` outputs matching the dtype and device of ``qkv``."""
        rows = qkv.shape[0]
        q_like = torch.empty([rows, q_size], dtype=qkv.dtype, device=qkv.device)
        k_like = torch.empty([rows, kv_size], dtype=qkv.dtype, device=qkv.device)
        return (q_like, k_like)

    # Register under torch.ops.vllm; the op is declared as mutating no inputs.
    direct_register_custom_op(
        op_name="minimax_qk_norm_fused",
        op_func=_minimax_qk_norm_fused,
        fake_impl=_minimax_qk_norm_fused_fake,
        mutates_args=[],
    )
    _MINIMAX_QK_NORM_FUSED_OP = torch.ops.vllm.minimax_qk_norm_fused.default
class MiniMaxQKNormPattern:
    """
    Match the forward_qk allreduce+rms pattern and replace with Lamport kernel.

    Two pattern variants are registered against the same replacement: one with
    a single three-way split of ``qkv`` and one with three separate splits.
    The constructor only stores its arguments; ``register`` captures them in
    the traced pattern and replacement functions.
    """

    def __init__(
        self,
        q_size: int,
        kv_size: int,
        eps: float,
        tp_world: int,
        tp_rank: int,
        max_tokens: int,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        # Widths of the q and k/v slices along the last dim of qkv.
        self.q_size = q_size
        self.kv_size = kv_size
        self.eps = eps
        self.tp_world = tp_world
        self.tp_rank = tp_rank
        # Forwarded unchanged to the fused op as its max_tokens argument.
        self.max_tokens = max_tokens
        self.dtype = dtype
        self.device = device

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example ``[qkv, q_weight, k_weight]`` tensors for tracing."""
        # Example row count used for the traced inputs.
        T = 4
        qkv = torch.empty(
            [T, self.q_size + 2 * self.kv_size],
            device=self.device,
            dtype=self.dtype,
        )
        q_weight = torch.empty([self.q_size], device=self.device, dtype=self.dtype)
        k_weight = torch.empty([self.kv_size], device=self.device, dtype=self.dtype)
        return [qkv, q_weight, k_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both pattern variants and their replacement on ``pm_pass``."""
        # Bind to locals so the nested functions close over plain values.
        q_size = self.q_size
        kv_size = self.kv_size
        eps = self.eps
        tp_world = self.tp_world
        max_tokens = self.max_tokens
        tp_rank = self.tp_rank
        dtype = self.dtype

        def pattern(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Variant 1: a single split feeding q, k and v.
            q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
            q_fp32 = q.to(torch.float32)
            k_fp32 = k.to(torch.float32)
            # Per-row mean of squares on the local shard.
            q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
            k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
            # q and k statistics are concatenated so one all-reduce covers both.
            qk_var = torch.cat([q_var, k_var], dim=-1)
            qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_world
            q_var, k_var = qk_var.chunk(2, dim=-1)
            q_out = (q_fp32 * torch.rsqrt(q_var + eps) * q_weight).to(dtype)
            k_out = (k_fp32 * torch.rsqrt(k_var + eps) * k_weight).to(dtype)
            return q_out, k_out, v

        def replacement(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert _MINIMAX_QK_NORM_FUSED_OP is not None
            q_out, k_out = torch.ops.vllm.minimax_qk_norm_fused(
                qkv,
                q_weight,
                k_weight,
                q_size,
                kv_size,
                tp_rank,
                tp_world,
                eps,
                max_tokens,
            )
            # v is still taken from a split of the original qkv.
            _, _, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
            return q_out, k_out, v

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Second pattern: three separate split_with_sizes nodes (one per output),
        # each with _users=1. This occurs when the QKV projection uses a
        # functional GEMM kernel (e.g. cutlass_scaled_mm via auto_functionalized),
        # which causes inductor to generate one split per consumer.
        def pattern_split3(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q = qkv.split([q_size, kv_size, kv_size], dim=-1)[0]
            k = qkv.split([q_size, kv_size, kv_size], dim=-1)[1]
            v = qkv.split([q_size, kv_size, kv_size], dim=-1)[2]
            q_fp32 = q.to(torch.float32)
            k_fp32 = k.to(torch.float32)
            q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
            k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
            qk_var = torch.cat([q_var, k_var], dim=-1)
            qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_world
            q_var, k_var = qk_var.chunk(2, dim=-1)
            q_out = (q_fp32 * torch.rsqrt(q_var + eps) * q_weight).to(dtype)
            k_out = (k_fp32 * torch.rsqrt(k_var + eps) * k_weight).to(dtype)
            return q_out, k_out, v

        pm.register_replacement(
            pattern_split3, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class MiniMaxQKNormPass(VllmPatternMatcherPass):
    """
    Replace forward_qk allreduce+norm with the Lamport fused kernel.
    Only applied for decode-size compile ranges (small token counts).

    The pass starts disabled and only enables itself once every precondition
    checked in ``__init__`` holds: the fused op is registered, TP size > 1, a
    model config is present, the first architecture is
    ``MiniMaxM2ForCausalLM`` and the attention geometry matches the expected
    values.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        # Stay disabled until every check below has passed.
        self.disabled = True
        if _MINIMAX_QK_NORM_FUSED_OP is None:
            logger.warning_once(
                "minimax_allreduce_rms_qk op not found, MiniMaxQKNormPass disabled."
            )
            return
        tp_world = get_tensor_model_parallel_world_size()
        if tp_world <= 1:
            logger.warning_once("MiniMaxQKNormPass disabled: tp_size <= 1.")
            return
        if config.model_config is None:
            logger.warning_once("MiniMaxQKNormPass disabled: no model_config.")
            return

        hf_cfg = config.model_config.hf_config
        # `architectures` may be missing, None or empty. Indexing it blindly
        # would raise instead of just leaving the pass disabled.
        architectures = getattr(hf_cfg, "architectures", None) or [""]
        model_name = architectures[0]
        if model_name != "MiniMaxM2ForCausalLM":
            return

        num_attention_heads = getattr(hf_cfg, "num_attention_heads", 0)
        num_key_value_heads = getattr(hf_cfg, "num_key_value_heads", 0)
        hidden_size = getattr(hf_cfg, "hidden_size", 0)
        head_dim = getattr(hf_cfg, "head_dim", 0)
        eps: float = getattr(hf_cfg, "rms_norm_eps", 1e-6)
        # Only this exact attention geometry is accepted.
        if (
            num_attention_heads != 48
            or num_key_value_heads != 8
            or hidden_size != 3072
            or head_dim != 128
        ):
            logger.warning_once(
                "MiniMaxQKNormPass disabled: cannot infer model info from hf_config."
            )
            return

        # Per-rank slice widths; at least one KV head is kept per rank.
        num_heads_per_rank = num_attention_heads // tp_world
        num_kv_heads_per_rank = max(1, num_key_value_heads // tp_world)
        q_size = num_heads_per_rank * head_dim
        kv_size = num_kv_heads_per_rank * head_dim

        self.max_token_num = min(
            MAX_TOKEN_NUM, config.scheduler_config.max_num_batched_tokens
        )
        tp_rank = get_tensor_model_parallel_rank()

        # Allocate Lamport workspace first.
        from vllm.distributed.parallel_state import get_tp_group
        from vllm.model_executor.layers.mamba.lamport_workspace import (
            get_allreduce_workspace,
        )

        get_allreduce_workspace(
            rank=tp_rank,
            world_size=tp_world,
            max_tokens=self.max_token_num,
            process_group=get_tp_group().cpu_group,
        )

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="minimax_qk_norm_pass"
        )
        self._register_patterns(q_size, kv_size, eps, tp_world, tp_rank)
        self.dump_patterns(config, self.patterns)
        self.disabled = False

    @enable_fake_mode
    def _register_patterns(
        self,
        q_size: int,
        kv_size: int,
        eps: float,
        tp_world: int,
        tp_rank: int,
    ) -> None:
        """Build a MiniMaxQKNormPattern and register it on ``self.patterns``."""
        MiniMaxQKNormPattern(
            q_size=q_size,
            kv_size=kv_size,
            eps=eps,
            tp_world=tp_world,
            tp_rank=tp_rank,
            max_tokens=self.max_token_num,
            dtype=self.model_dtype,
            device=self.device,
        ).register(self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True only when enabled and the range end fits the token cap."""
        if self.disabled:
            return False
        return bool(compile_range.end <= self.max_token_num)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph``; no-op when disabled."""
        if self.disabled:
            return
        self.matched_count = self.patterns.apply(graph)
        logger.debug("MiniMaxQKNormPass replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash of this pass's source together with MiniMaxQKNormPattern's."""
        return VllmInductorPass.hash_source(self, MiniMaxQKNormPattern)
...@@ -38,6 +38,7 @@ if current_platform.is_cuda_alike(): ...@@ -38,6 +38,7 @@ if current_platform.is_cuda_alike():
if current_platform.is_cuda(): if current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass from .fusion.collective_fusion import AsyncTPPass
from .fusion.minimax_qk_norm_fusion import MiniMaxQKNormPass
from .inductor_pass import ( from .inductor_pass import (
CustomGraphPass, CustomGraphPass,
...@@ -137,6 +138,9 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] ...@@ -137,6 +138,9 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
if self.pass_config.fuse_allreduce_rms: if self.pass_config.fuse_allreduce_rms:
self.passes += [AllReduceFusionPass(config)] self.passes += [AllReduceFusionPass(config)]
if self.pass_config.fuse_minimax_qk_norm:
self.passes += [MiniMaxQKNormPass(config)]
if self.pass_config.fuse_norm_quant: if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)] self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled(): if rocm_aiter_ops.is_enabled():
......
...@@ -134,6 +134,8 @@ class PassConfig: ...@@ -134,6 +134,8 @@ class PassConfig:
"""Enable async TP.""" """Enable async TP."""
fuse_allreduce_rms: bool = None # type: ignore[assignment] fuse_allreduce_rms: bool = None # type: ignore[assignment]
"""Enable flashinfer allreduce fusion.""" """Enable flashinfer allreduce fusion."""
fuse_minimax_qk_norm: bool = None # type: ignore[assignment]
"""Enable fused allreduce+RMSNorm for MiniMax QK norm."""
enable_qk_norm_rope_fusion: bool = False enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass.""" """Enable fused Q/K RMSNorm + RoPE pass."""
......
...@@ -1627,6 +1627,22 @@ class VllmConfig: ...@@ -1627,6 +1627,22 @@ class VllmConfig:
compile_range_end, compile_range_end,
) )
if compilation_config.pass_config.fuse_minimax_qk_norm:
from vllm.compilation.passes.fusion.minimax_qk_norm_fusion import (
MAX_TOKEN_NUM,
)
max_token_num = min(
MAX_TOKEN_NUM, self.scheduler_config.max_num_batched_tokens
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below MiniMax QK norm fusion threshold, "
"MiniMax QK norm fusion enabled for all num_tokens."
)
if compilation_config.compile_ranges_endpoints is not None: if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints: for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int) assert isinstance(x, int)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import array
import contextlib
import struct
import sys
import threading
import torch
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
# Sizes in this module are rounded up to a multiple of this value before
# allocation (see _cuda_malloc and compute_comm_size_for_minimax).
_ALIGN = 1 << 21  # 2 MiB — CUDA IPC allocation alignment
# ---------------------------------------------------------------------------
# CUDA helpers
# ---------------------------------------------------------------------------
def _check(error):
    """Raise ``RuntimeError`` unless ``error`` equals the CUDA success code."""
    ok_code = getattr(cudart.cudaError_t, "cudaSuccess", None)
    if not ok_code:
        # Attribute missing or falsy: construct the success code from value 0.
        ok_code = cudart.cudaError_t(0)
    if error != ok_code:
        raise RuntimeError(f"CUDA runtime error: {error}")
def _cuda_malloc(size: int):
    """Allocate device memory with ``size`` rounded up to a 2 MiB multiple.

    Returns:
        ``(ptr, padded_size)``: the pointer returned by ``cudaMalloc`` and the
        rounded-up byte count that was requested.

    Raises:
        RuntimeError: if ``cudaMalloc`` reports an error.
    """
    units = (size + _ALIGN - 1) >> 21  # number of 2 MiB units, rounded up
    padded_size = units << 21
    status, dev_ptr = cudart.cudaMalloc(padded_size)
    _check(status)
    return dev_ptr, padded_size
def _cuda_free(ptr: int):
    """Free ``ptr`` with ``cudaFree``; a falsy pointer is ignored."""
    if not ptr:
        return
    status = cudart.cudaFree(ptr)[0]
    _check(status)
def _cuda_memset_zero(ptr: int, size: int):
    """Set ``size`` bytes starting at device pointer ``ptr`` to zero."""
    status = cudart.cudaMemset(ptr, 0, size)[0]
    _check(status)
def _cuda_memcpy_d2d(dst: int, src: int, size: int):
    """Copy ``size`` bytes from device pointer ``src`` to device pointer ``dst``."""
    copy_kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice
    status = cudart.cudaMemcpy(dst, src, size, copy_kind)[0]
    _check(status)
# ---------------------------------------------------------------------------
# IPC buffer
# ---------------------------------------------------------------------------
class IpcBuffer:
    """
    Device buffer shared across ranks through CUDA IPC.

    Each rank allocates its own buffer, publishes the IPC handle to the other
    ranks and opens theirs, so ``peer_ptrs[r]`` is a device pointer to rank
    ``r``'s buffer (``peer_ptrs[rank]`` is the local allocation).
    """

    def __init__(self, rank: int, world_size: int, size: int, process_group=None):
        self.rank = rank
        self.world_size = world_size
        self.peer_ptrs: list[int] = [0] * world_size
        self.local_ptr: int = 0
        self._alive = False

        # Nothing to allocate or exchange for a non-positive size.
        if size <= 0:
            return

        self.local_ptr, _ = _cuda_malloc(size)
        _cuda_memset_zero(self.local_ptr, size)
        self._alive = True

        # --- exchange IPC handles via torch.distributed ---
        status, own_handle = cudart.cudaIpcGetMemHandle(self.local_ptr)
        _check(status)
        gathered: list[bytes | None] = [None] * world_size
        torch.distributed.all_gather_object(
            gathered, bytes(own_handle.reserved), group=process_group
        )

        for peer, raw_handle in enumerate(gathered):
            if peer == rank:
                # Our own slot points at the local allocation.
                self.peer_ptrs[peer] = self.local_ptr
                continue
            remote = cudart.cudaIpcMemHandle_t()
            remote.reserved = raw_handle
            status, mapped = cudart.cudaIpcOpenMemHandle(
                remote, cudart.cudaIpcMemLazyEnablePeerAccess
            )
            _check(status)
            self.peer_ptrs[peer] = mapped

    def serialize(self) -> list[int]:
        """Return peer pointers as a list of int64 values (one per rank)."""
        packed = b"".join(struct.pack("P", ptr) for ptr in self.peer_ptrs)
        return array.array("Q", packed).tolist()

    def cleanup(self):
        """Free the local buffer and close every opened peer handle.

        Calling it again, or on a buffer that never allocated, does nothing.
        Errors while closing peer handles are ignored; an error while freeing
        the local buffer propagates.
        """
        if not self._alive:
            return
        self._alive = False
        for peer in range(self.world_size):
            ptr = self.peer_ptrs[peer]
            if ptr == 0:
                continue
            if peer == self.rank:
                _cuda_free(ptr)
            else:
                try:
                    _check(cudart.cudaIpcCloseMemHandle(ptr)[0])
                except RuntimeError:
                    pass
            self.peer_ptrs[peer] = 0
        self.local_ptr = 0

    def __del__(self):
        # Skip cleanup while the interpreter is shutting down.
        if sys.is_finalizing():
            return
        self.cleanup()
# ---------------------------------------------------------------------------
# Lamport negative-zero initialization
# ---------------------------------------------------------------------------
def _lamport_fill_neg_zero(device_ptr: int, size_bytes: int):
    """
    Fill device memory with IEEE-754 negative zero (-0.0f = 0x80000000).
    This is the "slot empty" sentinel for the Lamport protocol: the kernel
    spin-waits until a value is *not* negative zero.

    Args:
        device_ptr: Destination device pointer; ``0`` makes this a no-op.
        size_bytes: Number of bytes to fill; ``0`` makes this a no-op.

    Raises:
        RuntimeError: if the device-to-device copy reports a CUDA error.
    """
    if size_bytes == 0 or device_ptr == 0:
        return
    # Round up so the staging tensor always holds at least size_bytes bytes.
    # With floor division the copy below would read past the end of the source
    # whenever size_bytes is not a multiple of 4, and from an empty tensor
    # when size_bytes < 4.
    n_floats = (size_bytes + 3) // 4
    # torch preserves -0.0 in IEEE-754
    fill = torch.full((n_floats,), -0.0, dtype=torch.float32, device="cuda")
    _cuda_memcpy_d2d(device_ptr, fill.data_ptr(), size_bytes)
    del fill
# ---------------------------------------------------------------------------
# LamportWorkspace — the main class
# ---------------------------------------------------------------------------
class LamportWorkspace:
    """
    Self-contained workspace for Lamport-based cross-GPU AllReduce.

    Parameters
    ----------
    rank : int
        Local rank (0-based).
    world_size : int
        Total number of ranks in the TP group.
    comm_size : int
        Size in bytes of *one* Lamport buffer slot. The total IPC allocation
        per rank is ``3 * comm_size`` (triple-buffering). Must be large enough
        to hold the per-slot data written by the kernel. Use
        ``compute_comm_size_for_minimax()`` for a safe default.
    process_group : optional
        ``torch.distributed`` process group for IPC handle exchange.
        ``None`` uses the default group.
    """

    def __init__(self, rank: int, world_size: int, comm_size: int, process_group=None):
        assert world_size >= 2, "Lamport workspace requires at least 2 ranks"
        assert comm_size > 0, "comm_size must be positive"
        self.rank = rank
        self.world_size = world_size
        self.comm_size = comm_size

        # 1) Lamport triple-buffer (the only IPC memory the kernel reads/writes)
        triple_bytes = 3 * comm_size
        self._lamport = IpcBuffer(rank, world_size, triple_bytes, process_group)
        _lamport_fill_neg_zero(self._lamport.local_ptr, triple_bytes)

        # 2) flag_buffer on device: int32[3] = {counter, unused, lamport_flag}
        #    counter      — used for block-level sync inside the kernel
        #    unused       — reserved (index 1)
        #    lamport_flag — triple-buffer rotation index (0 → 1 → 2 → 0 …)
        self._flag_buf = torch.zeros(3, dtype=torch.int32, device="cuda")

        # 3) layout_buffer on device: int64[2] = {clear_size, comm_size}
        #    clear_size — bytes to clear from *previous* slot (set by kernel)
        #    comm_size  — size of one triple-buffer slot
        self._layout_buf = torch.tensor(
            [0, comm_size], dtype=torch.int64, device="cuda"
        )

        # 4) Assemble device-side void* pointer array:
        #    [0 .. N-1]    ipc_buffers  (placeholder)
        #    [N .. 2N-1]   ipc_barriers (placeholder)
        #    [2N .. 3N-1]  lamport peer ptrs
        #    [3N]          flag_buffer
        #    [3N+1]        layout_buffer
        table: list[int] = [0] * world_size
        table.extend([0] * world_size)
        table.extend(self._lamport.serialize())
        table.append(self._flag_buf.data_ptr())
        table.append(self._layout_buf.data_ptr())
        self._workspace = torch.tensor(table, dtype=torch.int64, device="cuda")

    @property
    def workspace(self) -> torch.Tensor:
        """Device tensor (int64) that can be passed to the kernel
        as ``void** workspace``."""
        return self._workspace

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def compute_comm_size_for_minimax(
        max_tokens: int,
        world_size: int,
        fused_qk: bool = True,
    ) -> int:
        """
        Return a safe ``comm_size`` (in bytes) for MiniMaxReduceRMSKernel.

        The kernel stores per-token variance scalars in the Lamport buffer:
          - single-matrix path: ``world_size × max_tokens × 4`` bytes per slot
          - fused Q+K path: ``world_size × 2 × ceil(max_tokens/4) × 16`` bytes per slot

        The returned value is rounded up to 2 MiB alignment.
        """
        if not fused_qk:
            payload = world_size * max_tokens * 4  # 4 = sizeof(float)
        else:
            float4_groups = (max_tokens + 3) // 4
            payload = world_size * 2 * float4_groups * 16  # 16 = sizeof(float4)
        units = (payload + _ALIGN - 1) >> 21
        return units << 21

    def cleanup(self):
        """Release the Lamport IPC buffer if it was created."""
        if not hasattr(self, "_lamport"):
            return
        self._lamport.cleanup()

    def __del__(self):
        # Skip cleanup while the interpreter is shutting down.
        if sys.is_finalizing():
            return
        self.cleanup()

    def __repr__(self):
        fields = (
            f"rank={self.rank}, world_size={self.world_size}, "
            f"comm_size={self.comm_size}"
        )
        return f"LamportWorkspace({fields})"
# ---------------------------------------------------------------------------
# Cached convenience function (mirrors TRT-LLM's get_allreduce_workspace)
# ---------------------------------------------------------------------------
# Guards every read and write of _workspace_cache in get_allreduce_workspace.
_cache_lock = threading.Lock()
# Maps (rank, world_size, comm_size, process-group id) -> LamportWorkspace.
_workspace_cache: dict = {}
def get_allreduce_workspace(
    rank: int,
    world_size: int,
    comm_size: int | None = None,
    max_tokens: int = 16384,
    process_group=None,
) -> torch.Tensor:
    """
    Return a cached workspace tensor.

    Workspaces are cached per ``(rank, world_size, comm_size, process group)``.
    The first call for a key allocates the workspace and exchanges IPC handles;
    later calls with the same key return the cached tensor.

    Parameters
    ----------
    rank, world_size : int
        TP rank and TP size.
    comm_size : int, optional
        Explicit slot size in bytes. If ``None``, computed automatically
        from ``max_tokens`` and ``world_size`` (fused Q+K path).
    max_tokens : int
        Maximum number of tokens per batch (used when ``comm_size is None``).
    process_group : optional
        ``torch.distributed`` process group.
    """
    if comm_size is not None:
        slot_bytes = comm_size
    else:
        slot_bytes = LamportWorkspace.compute_comm_size_for_minimax(
            max_tokens, world_size, fused_qk=True
        )

    # The default group (None) is keyed as 0; others by object identity.
    group_tag = 0 if process_group is None else id(process_group)
    cache_key = (rank, world_size, slot_bytes, group_tag)

    with _cache_lock:
        entry = _workspace_cache.get(cache_key)
        if entry is None:
            entry = LamportWorkspace(rank, world_size, slot_bytes, process_group)
            _workspace_cache[cache_key] = entry
        return entry.workspace
...@@ -233,9 +233,7 @@ class MiniMaxM2Attention(nn.Module): ...@@ -233,9 +233,7 @@ class MiniMaxM2Attention(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = MiniMaxText01RMSNormTP.forward_qk( q, k = MiniMaxText01RMSNormTP.forward_qk(self.q_norm, self.k_norm, q, k)
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment