"vllm/entrypoints/openai/engine/protocol.py" did not exist on "53a56e658b9ab1eabb7339d3001dc3fd9178dd21"
Commit c5460385 authored by Jee Jee Li's avatar Jee Jee Li Committed by khluu
Browse files

[Kernel] Porting the TRTLLM minimax_allreduce_rms kernels (#37045)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
(cherry picked from commit ecd1ea13

)
Signed-off-by: default avatarkhluu <khluu000@gmail.com>
parent 4bbb8faa
......@@ -10,7 +10,20 @@ steps:
- tests/kernels/test_top_k_per_row.py
- tests/kernels/test_concat_mla_q.py
commands:
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
- pytest -v -s kernels/core --ignore=kernels/core/test_minimax_reduce_rms.py kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py
- label: Kernels MiniMax Reduce RMS Test (2 GPUs)
timeout_in_minutes: 15
num_devices: 2
device: h100
source_file_dependencies:
- csrc/minimax_reduce_rms_kernel.cu
- csrc/minimax_reduce_rms_kernel.h
- vllm/model_executor/layers/mamba/linear_attn.py
- vllm/model_executor/layers/mamba/lamport_workspace.py
- tests/kernels/core/test_minimax_reduce_rms.py
commands:
- pytest -v -s kernels/core/test_minimax_reduce_rms.py
- label: Kernels Attention Test %N
timeout_in_minutes: 35
......
......@@ -306,6 +306,8 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC "csrc/minimax_reduce_rms_kernel.cu")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
......
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <torch/cuda.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "cuda_utils.h"
#include "core/registration.h"
#include "minimax_reduce_rms_kernel.h"
#include <algorithm>
#define FINAL_MASK 0xffffffff
#define MINIMAX_REDUCE_RMS_WARP_SIZE 32
namespace vllm {
namespace tensorrt_llm {
// Per-thread handle onto the Lamport communication workspace used by the
// MiniMax reduce+RMS kernels. The workspace is a table of pointers; this
// struct decodes the entries it needs and tracks which of three rotating
// buffer slots is used for data in this launch and which one is cleared.
// NOTE(review): the pointer-table layout is inferred only from the indices
// used below — confirm against the code that allocates the workspace.
template <int NRanks>
struct LamportComm {
// Decodes the workspace, selects the data/clear slots from the current flag
// value and registers this block's arrival on the shared counter.
// Must be called by every thread of the block (contains __syncthreads()).
__device__ __forceinline__ LamportComm(void** workspace, int rank) {
// workspace[NRanks * 3] is viewed as an int array: element 0 is the
// block-arrival counter, element 2 is the rotating slot flag.
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
// workspace[NRanks * 3 + 1] is viewed as an int64 array: element 0 is the
// size to clear, element 1 is the size of one buffer slot (comm_size).
clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
flag_value = *flag_ptr;
auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
clear_size = *clear_ptr;
// Slot used for this launch's data, and the slot (flag + 2) % 3 that this
// launch wipes.
int data_offset = flag_value % 3;
int clear_offset = (flag_value + 2) % 3;
// workspace[2 * NRanks + r] is the base of the buffer associated with rank r
// (presumably a peer-mapped allocation — TODO confirm).
for (int r = 0; r < NRanks; ++r) {
data_bufs[r] = reinterpret_cast<uint8_t*>(workspace[2 * NRanks + r]) +
data_offset * comm_size;
}
// Only this rank's own buffer is cleared.
clear_buf = reinterpret_cast<uint8_t*>(workspace[2 * NRanks + rank]) +
clear_offset * comm_size;
// All threads have read the flag/clear size before thread 0 announces the
// block on the counter that update() waits on.
__syncthreads();
if (threadIdx.x == 0) {
atomicAdd(counter_ptr, 1);
}
}
// Executed for real only by thread 0 of block 0: spins until every block of
// the grid has registered in the constructor, then advances the slot flag,
// stores the clear size for the next launch and resets the counter.
__device__ __forceinline__ void update(int64_t new_clear_size) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
// Volatile read so the spin re-reads memory on every iteration.
while (*reinterpret_cast<int volatile*>(counter_ptr) != gridDim.x) {
}
*flag_ptr = (flag_value + 1) % 3;
*clear_ptr = new_clear_size;
*counter_ptr = 0;
}
}
int* counter_ptr;            // block-arrival counter
int* flag_ptr;               // rotating slot flag (values 0..2)
int64_t* clear_ptr;          // where the clear size is stored
uint8_t* data_bufs[NRanks];  // data slot of each rank's buffer
uint8_t* clear_buf;          // slot of this rank's buffer to be cleared
int64_t clear_size;          // clear size read at construction
int flag_value;              // flag value read at construction
};
// True iff the bit pattern of `v` is exactly 0x80000000 (negative zero),
// the sentinel value the clear loops write into the communication buffers.
__device__ __forceinline__ bool is_neg_zero(float v) {
  constexpr uint32_t kNegZeroBits = 0x80000000u;
  const uint32_t bits = __float_as_uint(v);
  return bits == kNegZeroBits;
}
// True iff at least one of the four components of `v` is negative zero.
__device__ __forceinline__ bool is_neg_zero(float4 v) {
  const float components[4] = {v.x, v.y, v.z, v.w};
  bool any_sentinel = false;
#pragma unroll
  for (int c = 0; c < 4; ++c) {
    any_sentinel = any_sentinel || is_neg_zero(components[c]);
  }
  return any_sentinel;
}
// Builds a float4 whose four components all carry the bit pattern 0x80000000
// (negative zero).
__device__ __forceinline__ float4 get_neg_zero() {
  const float neg_zero = __uint_as_float(0x80000000u);
  return make_float4(neg_zero, neg_zero, neg_zero, neg_zero);
}
// Replaces the sum of squares in `v` by rsqrt(v / Dim + eps) and returns the
// new value. Dim is a compile-time constant so the division becomes a
// multiplication by a constant reciprocal.
template <int Dim>
__device__ __forceinline__ float rms_rsqrt(float& v, float eps) {
  constexpr float kReciprocalDim = 1.0F / static_cast<float>(Dim);
  const float mean_square_plus_eps = (v * kReciprocalDim) + eps;
  v = rsqrtf(mean_square_plus_eps);
  return v;
}
// Component-wise float4 version: each component of `v` is replaced by
// rsqrt(component / Dim + eps); the updated vector is also returned.
template <int Dim>
__device__ __forceinline__ float4 rms_rsqrt(float4& v, float eps) {
  constexpr float kReciprocalDim = 1.0F / static_cast<float>(Dim);
  const float mx = (v.x * kReciprocalDim) + eps;
  const float my = (v.y * kReciprocalDim) + eps;
  const float mz = (v.z * kReciprocalDim) + eps;
  const float mw = (v.w * kReciprocalDim) + eps;
  v = make_float4(rsqrtf(mx), rsqrtf(my), rsqrtf(mz), rsqrtf(mw));
  return v;
}
// Loads one float4 from global memory with a single volatile vectorized PTX
// load (ld.volatile.global.v4.f32). Used in the polling loops so that every
// iteration issues a fresh load instead of reusing a previously read value.
// `addr` is used as a 64-bit address operand of a v4.f32 load; it is expected
// to be 16-byte aligned.
__device__ __forceinline__ float4 ld_global_volatile(float4* addr) {
float4 val;
asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];"
: "=f"(val.x), "=f"(val.y), "=f"(val.z), "=f"(val.w)
: "l"(addr));
return val;
}
// Scalar counterpart: loads one float from global memory with a volatile PTX
// load (ld.volatile.global.f32) so polling loops re-read memory every time.
__device__ __forceinline__ float ld_global_volatile(float* addr) {
float val;
asm volatile("ld.volatile.global.f32 %0, [%1];" : "=f"(val) : "l"(addr));
return val;
}
// Used by the scalar (non-float4) kernel only.
// In-place warp-wide sum of NUM independent values: after the call every
// participating lane holds, in val[i], the sum of val[i] over the warp.
// Uses XOR shuffles with the full-warp mask. The return value is always zero
// and carries no information; results are delivered through `val`.
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) {
#pragma unroll
  for (int slot = 0; slot < NUM; ++slot) {
    T running = val[slot];
#pragma unroll
    for (int offset = 16; offset >= 1; offset /= 2) {
      running += __shfl_xor_sync(FINAL_MASK, running, offset, 32);
    }
    val[slot] = running;
  }
  return static_cast<T>(0.0f);
}
// In-place block-wide sum of NUM independent values via shared memory.
// Step 1: each warp reduces its values with warpReduceSumV2 and lane 0 stores
//         the per-warp partials. Step 2: after a block barrier, threads whose
//         index is below blockDim.x / 32 reload one partial each (all other
//         threads contribute zero) and a second warp reduction is performed.
// Only the first warp ends up with the block total in val[i]; threads of the
// other warps load zeros in step 2. Must be called by all threads of the
// block (contains __syncthreads()). The return value is always zero.
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val) {
  static __shared__ T warp_partials[NUM][33];
  const int lane_id = threadIdx.x & 0x1f;
  const int warp_id = threadIdx.x >> 5;

  warpReduceSumV2<T, NUM>(val);

  if (lane_id == 0) {
#pragma unroll
    for (int slot = 0; slot < NUM; ++slot) {
      warp_partials[slot][warp_id] = val[slot];
    }
  }
  __syncthreads();

  // Float division on purpose: a trailing partial warp still counts as one.
  const bool holds_partial = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
  for (int slot = 0; slot < NUM; ++slot) {
    if (holds_partial) {
      val[slot] = warp_partials[slot][lane_id];
    } else {
      val[slot] = (T)(0.0f);
    }
  }

  warpReduceSumV2<T, NUM>(val);
  return (T)0.0f;
}
// For the float4 kernel.
// In-place XOR-shuffle sum over groups of kNumThreads consecutive lanes for
// each of the ArraySize entries of value_ptr. kNumThreads must lie in
// [1, warp size]; the XOR offsets start at kNumThreads / 2 and halve down to 1.
// `active_mask` is forwarded to __shfl_xor_sync and must name the lanes that
// execute the call.
template <uint32_t kNumThreads, typename T, int ArraySize = 4>
__device__ __forceinline__ void local_warp_reduce_sum_array(
    T* value_ptr, uint32_t active_mask = 0xffffffffu) {
  static_assert(kNumThreads >= 1 &&
                kNumThreads <= MINIMAX_REDUCE_RMS_WARP_SIZE);
#pragma unroll
  for (int slot = 0; slot < ArraySize; ++slot) {
    T running = value_ptr[slot];
#pragma unroll
    for (int offset = kNumThreads / 2; offset > 0; offset >>= 1) {
      running += __shfl_xor_sync(active_mask, running, offset,
                                 MINIMAX_REDUCE_RMS_WARP_SIZE);
    }
    value_ptr[slot] = running;
  }
}
// Smallest power of two that is >= val; returns 1 for val <= 1.
constexpr int next_pow2(int val) {
  int power = 1;
  for (; power < val; power *= 2) {
  }
  return power;
}
// ---------------------------------------------------------------------------
// Computes the per-thread indices used by the scalar reduce+RMS kernel.
// All quantities are expressed in "accesses", i.e. groups of
// kElemsPerAccess<DType> consecutive elements.
template <typename DType>
class IndexHelper {
public:
__device__ __forceinline__ IndexHelper(MiniMaxReduceRMSParams const& params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// SM90+: one cluster handles one token; threads are numbered across the
// whole cluster.
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
cg::grid_group grid = cg::this_grid();
token_id = grid.cluster_rank();
access_id_in_token = cluster.thread_rank();
token_stride = grid.num_clusters();
#else
// Older architectures: one block handles one token.
token_id = blockIdx.x;
access_id_in_token = threadIdx.x;
token_stride = gridDim.x;
#endif
// Flat access index of this thread for its first token; note the
// multiplication happens before the integer division.
access_id = token_id * params.hidden_dim / kElemsPerAccess<DType> +
access_id_in_token;
// Distance, in accesses, between two tokens handled by the same thread.
access_stride = token_stride * params.hidden_dim / kElemsPerAccess<DType>;
// Total number of accesses covering params.size_q elements.
tot_access = params.size_q / kElemsPerAccess<DType>;
}
int token_id;            // first token handled by this thread
int access_id_in_token;  // position of this thread's access inside a token
int token_stride;        // token step between loop iterations
int access_id;           // flat access index for the first token
int access_stride;       // flat access step between loop iterations
int tot_access;          // total number of accesses in the input
};
/**
 * Scalar Lamport variant of the fused "cross-rank sum of squares + RMS norm"
 * kernel used by the MiniMax attention module.
 *
 * Expected launch layout (see minimax_reduce_rms_kernel_launcher): one block
 * per token, hidden_dim / kElemsPerAccess<DType> threads per block, grid-
 * strided over tokens. params.allreduce_in / rms_norm_out are indexed as
 * contiguous [tokens, hidden_dim] arrays of DType and params.rms_gamma as
 * [hidden_dim] of DType, all accessed with 16-byte vector loads/stores.
 *
 * Per token:
 *   1. every thread loads one 16-byte vector and the block reduces the local
 *      sum of squares;
 *   2. thread 0 pushes that sum into slot [rank * tot_tokens + token] of every
 *      rank's data buffer; all threads then poll this rank's own buffer until
 *      no slot still holds the negative-zero "not written yet" sentinel;
 *   3. the NRanks partial sums are added and each element is scaled by
 *      rsqrt(sum / hidden_dim / NRanks + eps) * gamma.
 * Afterwards the grid clears the buffer slot selected by LamportComm and
 * block 0 advances the workspace state via LamportComm::update().
 */
template <typename DType, int NRanks>
__global__ void __launch_bounds__(1024)
    minimax_reduce_rms_kernel_lamport(MiniMaxReduceRMSParams params) {
  IndexHelper<DType> index_helper(params);
  int token_id = index_helper.token_id;
  int access_id_in_token = index_helper.access_id_in_token;
  int token_stride = index_helper.token_stride;
  int access_id = index_helper.access_id;
  int access_stride = index_helper.access_stride;
  int tot_access = index_helper.tot_access;
  int tot_tokens = params.size_q / params.hidden_dim;
  float4 clear_vec = get_neg_zero();
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif
  // The workspace flag, clear size and arrival counter are rewritten by the
  // previous launch in LamportComm::update(). With programmatic dependent
  // launch this grid may start while that launch is still running, so the
  // workspace state must only be read (and the counter only be incremented)
  // after the dependency wait above.
  LamportComm<NRanks> comm(params.workspace, params.rank);
  int clear_access = comm.clear_size / kElemsPerAccess<DType>;

  for (int idx = access_id; idx < tot_access;
       idx += access_stride, token_id += token_stride) {
    alignas(16) DType vals[kElemsPerAccess<DType>];
    float sum_variance = 0.F;
    *reinterpret_cast<float4*>(vals) =
        reinterpret_cast<float4*>(params.allreduce_in)[idx];
#pragma unroll
    for (int i = 0; i < kElemsPerAccess<DType>; ++i) {
      sum_variance += static_cast<float>(vals[i]) * static_cast<float>(vals[i]);
    }
    // After this call only the first warp holds the block total.
    blockReduceSumV2<float, 1>(&sum_variance);
    // Never publish the sentinel value itself.
    if (is_neg_zero(sum_variance)) {
      sum_variance = 0.F;
    }
    // Push this rank's partial sum to every rank's data buffer.
    if (threadIdx.x == 0) {
      for (int r = 0; r < NRanks; ++r) {
        reinterpret_cast<float*>(
            comm.data_bufs[r])[(params.rank * tot_tokens) + token_id] =
            (sum_variance);
      }
    }
    // Poll this rank's buffer until all NRanks contributions have arrived.
    bool done = false;
    float vars_all_ranks[NRanks];
    while (!done) {
      done = true;
#pragma unroll
      for (int r = 0; r < NRanks; ++r) {
        vars_all_ranks[r] = ld_global_volatile(&reinterpret_cast<float*>(
            comm.data_bufs[params.rank])[(r * tot_tokens) + token_id]);
        done &= !is_neg_zero(vars_all_ranks[r]);
      }
    }
    sum_variance = 0.F;
#pragma unroll
    for (int r = 0; r < NRanks; ++r) {
      sum_variance += vars_all_ranks[r];
    }
    // 16-byte aligned because it is filled through a 16-byte vector store,
    // exactly like `vals`.
    alignas(16) DType norm_weight[kElemsPerAccess<DType>];
    *reinterpret_cast<typename ElemsPerAccess<DType>::vec_type*>(norm_weight) =
        reinterpret_cast<typename ElemsPerAccess<DType>::vec_type*>(
            params.rms_gamma)[access_id_in_token];
#pragma unroll
    for (int i = 0; i < kElemsPerAccess<DType>; ++i) {
      vals[i] = static_cast<DType>(
          static_cast<float>(vals[i]) *
          rsqrtf(
              (sum_variance / static_cast<float>(params.hidden_dim) / NRanks) +
              params.rms_eps) *
          static_cast<float>(norm_weight[i]));
    }
    reinterpret_cast<float4*>(params.rms_norm_out)[idx] =
        *reinterpret_cast<float4*>(vals);
  }

  // Refill the slot chosen for clearing with the sentinel value.
  for (int idx = access_id; idx < clear_access; idx += access_stride) {
    reinterpret_cast<float4*>(comm.clear_buf)[idx] = clear_vec;
  }
  comm.update(params.size_q * NRanks);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}
/**
 * Float4 Lamport variant: fused cross-rank RMS norm of Q and K in one launch.
 *
 * Tokens are processed in groups of 4 so that the four per-token sums of
 * squares travel between ranks as a single float4. Sums are always
 * accumulated in float regardless of DType. When tot_tokens % 4 != 0 the
 * last group is padded with zero rows; padded rows are never written to the
 * outputs.
 *
 * Block layout (see minimax_reduce_rms_kernel_launcher_float4): the first
 * NumWarpQ warps cover one row of Q, the following NumWarpK warps cover one
 * row of K; each thread owns one 16-byte access of its row. Lanes beyond the
 * row width in the last warp of each section are idle but still take part in
 * barriers and shuffles.
 *
 * Communication layout: float4 slot [rank * tot_groups * 2 + 2 * g] carries
 * the Q sums of group g and slot [... + 2 * g + 1] the K sums. The first
 * NRanks lanes of the first Q warp / first K warp each push to one rank and
 * poll one rank's contribution, then combine them with a warp shuffle.
 *
 * OriginQDim / OriginKDim are the full (all-rank) row widths; the per-rank
 * widths are OriginQDim / NRanks and OriginKDim / NRanks.
 */
template <typename DType, int NRanks, int OriginQDim, int OriginKDim>
__global__ void __launch_bounds__(1024)
    minimax_reduce_qk_rms_kernel_lamport_float4(MiniMaxReduceRMSParams params) {
  // Compile-time per-rank dimensions
  constexpr int RankQDim = OriginQDim / NRanks;
  constexpr int RankKDim = OriginKDim / NRanks;
  // Threads needed to cover one row of Q / K with float4 accesses
  constexpr int ThreadsPerRowQ = RankQDim / kElemsPerAccess<DType>;
  constexpr int ThreadsPerRowK = RankKDim / kElemsPerAccess<DType>;
  // Number of warps dedicated to Q / K
  constexpr int NumWarpQ = (ThreadsPerRowQ + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) /
                           MINIMAX_REDUCE_RMS_WARP_SIZE;
  constexpr int NumWarpK = (ThreadsPerRowK + MINIMAX_REDUCE_RMS_WARP_SIZE - 1) /
                           MINIMAX_REDUCE_RMS_WARP_SIZE;
  int tot_tokens = params.size_q / RankQDim;
  int tot_groups = (tot_tokens + 3) / 4;  // ceiling; last group may be partial
  // Memory strides for strided qkv tensors (elements -> float4-access units)
  int access_stride_q = (params.stride_q > 0 ? params.stride_q : RankQDim) /
                        kElemsPerAccess<DType>;
  int access_stride_k = (params.stride_k > 0 ? params.stride_k : RankKDim) /
                        kElemsPerAccess<DType>;
  // Output strides: default to contiguous (hidden_dim / hidden_dim_k)
  int access_stride_q_out =
      (params.stride_q_out > 0 ? params.stride_q_out : params.hidden_dim) /
      kElemsPerAccess<DType>;
  int access_stride_k_out =
      (params.stride_k_out > 0 ? params.stride_k_out : params.hidden_dim_k) /
      kElemsPerAccess<DType>;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  namespace cg = cooperative_groups;
  cg::cluster_group cluster = cg::this_cluster();
  cg::grid_group grid = cg::this_grid();
  int group_id = grid.cluster_rank();
  int access_id_in_token = cluster.thread_rank();
  int group_stride = grid.num_clusters();
#else
  int group_id = blockIdx.x;
  int access_id_in_token = threadIdx.x;
  int group_stride = gridDim.x;
#endif
  // Section membership (Q warps first, then K warps) and row-width validity.
  bool is_q = (access_id_in_token < NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE);
  int k_thread_idx =
      access_id_in_token - (NumWarpQ * MINIMAX_REDUCE_RMS_WARP_SIZE);
  bool is_valid_q = (access_id_in_token < ThreadsPerRowQ);
  bool is_valid_k = (k_thread_idx >= 0 && k_thread_idx < ThreadsPerRowK);
  float4 clear_vec = get_neg_zero();
  // Shared memory for two-level block reduction and scale broadcast
  __shared__ float block_reduce_sum[4][MINIMAX_REDUCE_RMS_WARP_SIZE + 1];
  __shared__ float global_scale_q[4];
  __shared__ float global_scale_k[4];
  // 16-byte aligned because it is filled through a 16-byte vector store,
  // exactly like `vals` below.
  alignas(16) DType norm_weight[kElemsPerAccess<DType>]{};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif
  // The workspace flag, clear size and arrival counter are rewritten by the
  // previous launch in LamportComm::update(), which in this kernel runs AFTER
  // griddepcontrol.launch_dependents. The workspace state must therefore only
  // be read (and the counter only be incremented) after the dependency wait.
  LamportComm<NRanks> comm(params.workspace, params.rank);

  // Each thread keeps the gamma values for its column slice in registers.
  if (is_q) {
    if (is_valid_q) {
      *reinterpret_cast<typename ElemsPerAccess<DType>::vec_type*>(
          norm_weight) =
          reinterpret_cast<typename ElemsPerAccess<DType>::vec_type const*>(
              params.rms_gamma)[access_id_in_token];
    }
  } else {
    if (is_valid_k) {
      *reinterpret_cast<typename ElemsPerAccess<DType>::vec_type*>(
          norm_weight) =
          reinterpret_cast<typename ElemsPerAccess<DType>::vec_type const*>(
              params.rms_gamma_k)[k_thread_idx];
    }
  }

  // Main loop: process one group of 4 tokens per iteration.
  for (int g = group_id; g < tot_groups; g += group_stride) {
    alignas(16) DType vals[4][kElemsPerAccess<DType>]{};
    float warp_sum_variance[4]{0.F, 0.F, 0.F, 0.F};
    // Load up to 4 rows and accumulate per-thread sums of squares; rows past
    // the end and idle lanes keep their zero-initialized values.
    if (is_q) {
#pragma unroll
      for (int row = 0; row < 4; ++row) {
        int token_r = g * 4 + row;
        if (token_r >= tot_tokens || !is_valid_q) {
          continue;
        }
        int idx_r = token_r * access_stride_q + access_id_in_token;
        *reinterpret_cast<float4*>(&vals[row][0]) =
            reinterpret_cast<float4 const*>(params.allreduce_in)[idx_r];
#pragma unroll
        for (int i = 0; i < kElemsPerAccess<DType>; ++i) {
          float x = static_cast<float>(vals[row][i]);
          warp_sum_variance[row] += x * x;
        }
      }
    } else {
#pragma unroll
      for (int row = 0; row < 4; ++row) {
        int token_r = g * 4 + row;
        if (token_r >= tot_tokens || !is_valid_k) {
          continue;
        }
        int idx_r = token_r * access_stride_k + k_thread_idx;
        *reinterpret_cast<float4*>(&vals[row][0]) =
            reinterpret_cast<float4 const*>(params.allreduce_in_k)[idx_r];
#pragma unroll
        for (int i = 0; i < kElemsPerAccess<DType>; ++i) {
          float x = static_cast<float>(vals[row][i]);
          warp_sum_variance[row] += x * x;
        }
      }
    }
    // Level 1: reduce inside each warp.
    local_warp_reduce_sum_array<MINIMAX_REDUCE_RMS_WARP_SIZE, float, 4>(
        warp_sum_variance);
    // Warp lane 0 writes its warp's partial sum to shared memory
    int lane = threadIdx.x & (MINIMAX_REDUCE_RMS_WARP_SIZE - 1);
    if (lane == 0) {
#pragma unroll
      for (int t = 0; t < 4; ++t) {
        block_reduce_sum[t][threadIdx.x / MINIMAX_REDUCE_RMS_WARP_SIZE] =
            warp_sum_variance[t];
      }
    }
    __syncthreads();
    int tid = threadIdx.x;
    if (tid < MINIMAX_REDUCE_RMS_WARP_SIZE) {
      // --- Q leader warp (first warp of the block) ---
      constexpr int kNumWarpQPow2 =
          (next_pow2(NumWarpQ) > NRanks) ? next_pow2(NumWarpQ) : NRanks;
      float local_sum[4];
#pragma unroll
      for (int t = 0; t < 4; ++t) {
        local_sum[t] = (tid < NumWarpQ) ? block_reduce_sum[t][tid] : 0.F;
      }
      // After this, all kNumWarpQPow2 lanes (including tid 0..NRanks-1) have
      // the total Q sum-of-squares for all 4 tokens.
      local_warp_reduce_sum_array<kNumWarpQPow2, float, 4>(local_sum);
      if (tid < NRanks) {
        // Never publish the sentinel value itself.
#pragma unroll
        for (int t = 0; t < 4; ++t) {
          if (is_neg_zero(local_sum[t])) {
            local_sum[t] = 0.F;
          }
        }
        // Parallel push: thread tid writes this rank's Q sum to rank tid's buf
        reinterpret_cast<float4*>(
            comm.data_bufs[tid])[(params.rank * tot_groups * 2) + (2 * g)] =
            *reinterpret_cast<float4*>(local_sum);
        // Parallel pull: thread tid reads rank tid's contribution from
        // this rank's (params.rank's) buffer
        bool done = false;
        float4 var_all_ranks;
        while (!done) {
          done = true;
          var_all_ranks = ld_global_volatile(&reinterpret_cast<float4*>(
              comm.data_bufs[params.rank])[(tid * tot_groups * 2) + (2 * g)]);
          done &= !is_neg_zero(var_all_ranks);
        }
        // Warp-level allreduce: each of the NRanks threads holds one rank's
        // partial sum; after this all NRanks threads have the global total.
        constexpr uint32_t kQActiveMask = (1u << NRanks) - 1u;
        local_warp_reduce_sum_array<NRanks, float, 4>(
            reinterpret_cast<float*>(&var_all_ranks), kQActiveMask);
        // Thread 0 computes rsqrt with compile-time Dim and writes to smem
        if (tid == 0) {
          *reinterpret_cast<float4*>(global_scale_q) =
              rms_rsqrt<OriginQDim>(var_all_ranks, params.rms_eps);
        }
      }
    } else if (tid >= MINIMAX_REDUCE_RMS_WARP_SIZE * NumWarpQ &&
               tid < MINIMAX_REDUCE_RMS_WARP_SIZE * (NumWarpQ + 1)) {
      // --- K leader warp (first warp of the K section) ---
      constexpr int kNumWarpKPow2 =
          (next_pow2(NumWarpK) > NRanks) ? next_pow2(NumWarpK) : NRanks;
      float local_sum[4];
#pragma unroll
      for (int t = 0; t < 4; ++t) {
        local_sum[t] = (k_thread_idx < NumWarpK)
                           ? block_reduce_sum[t][NumWarpQ + k_thread_idx]
                           : 0.F;
      }
      local_warp_reduce_sum_array<kNumWarpKPow2, float, 4>(local_sum);
      if (k_thread_idx < NRanks) {
#pragma unroll
        for (int t = 0; t < 4; ++t) {
          if (is_neg_zero(local_sum[t])) {
            local_sum[t] = 0.F;
          }
        }
        // Push this rank's K sums (odd slot) to rank k_thread_idx's buffer.
        reinterpret_cast<float4*>(
            comm.data_bufs[k_thread_idx])[(params.rank * tot_groups * 2) +
                                          (2 * g + 1)] =
            *reinterpret_cast<float4*>(local_sum);
        // Poll rank k_thread_idx's contribution in this rank's buffer.
        bool done = false;
        float4 var_all_ranks;
        while (!done) {
          done = true;
          var_all_ranks = ld_global_volatile(&reinterpret_cast<float4*>(
              comm.data_bufs[params.rank])[(k_thread_idx * tot_groups * 2) +
                                           (2 * g + 1)]);
          done &= !is_neg_zero(var_all_ranks);
        }
        constexpr uint32_t kKActiveMask = (1u << NRanks) - 1u;
        local_warp_reduce_sum_array<NRanks, float, 4>(
            reinterpret_cast<float*>(&var_all_ranks), kKActiveMask);
        if (k_thread_idx == 0) {
          *reinterpret_cast<float4*>(global_scale_k) =
              rms_rsqrt<OriginKDim>(var_all_ranks, params.rms_eps);
        }
      }
    }
    // Makes global_scale_q / global_scale_k visible to the whole block.
    __syncthreads();
    if (is_q) {
      // warp_sum_variance is reused to hold the four per-token scales.
#pragma unroll
      for (int t = 0; t < 4; ++t) {
        warp_sum_variance[t] = global_scale_q[t];
      }
#pragma unroll
      for (int r = 0; r < 4; ++r) {
#pragma unroll
        for (int i = 0; i < kElemsPerAccess<DType>; ++i) {
          vals[r][i] = static_cast<DType>(static_cast<float>(vals[r][i]) *
                                          warp_sum_variance[r] *
                                          static_cast<float>(norm_weight[i]));
        }
        int token_r = g * 4 + r;
        if (token_r >= tot_tokens || !is_valid_q) {
          continue;
        }
        int idx_out = token_r * access_stride_q_out + access_id_in_token;
        reinterpret_cast<float4*>(params.rms_norm_out)[idx_out] =
            *reinterpret_cast<float4*>(&vals[r][0]);
      }
    } else {
#pragma unroll
      for (int t = 0; t < 4; ++t) {
        warp_sum_variance[t] = global_scale_k[t];
      }
#pragma unroll
      for (int r = 0; r < 4; ++r) {
#pragma unroll
        for (int i = 0; i < kElemsPerAccess<DType>; ++i) {
          vals[r][i] = static_cast<DType>(static_cast<float>(vals[r][i]) *
                                          warp_sum_variance[r] *
                                          static_cast<float>(norm_weight[i]));
        }
        int token_r = g * 4 + r;
        if (token_r >= tot_tokens || !is_valid_k) {
          continue;
        }
        int idx_out = token_r * access_stride_k_out + k_thread_idx;
        reinterpret_cast<float4*>(params.rms_norm_out_k)[idx_out] =
            *reinterpret_cast<float4*>(&vals[r][0]);
      }
    }
  }  // end group loop
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
  // Refill the slot chosen for clearing with the sentinel value, then let
  // block 0 advance the workspace state for the next launch.
  int clear_access = static_cast<int>(comm.clear_size / kElemsPerAccess<DType>);
  int clear_stride = group_stride * blockDim.x;
  for (int idx = group_id * blockDim.x + threadIdx.x; idx < clear_access;
       idx += clear_stride) {
    reinterpret_cast<float4*>(comm.clear_buf)[idx] = clear_vec;
  }
  comm.update(static_cast<int64_t>(2) * tot_groups * kElemsPerAccess<DType> *
              NRanks);
}
// Returns the number of streaming multiprocessors of the calling thread's
// current CUDA device.
//
// The value is cached per host thread together with the device it was queried
// for, so switching the current device yields the count of the new device
// instead of a stale one. Every CUDA runtime call is error-checked.
int get_sm_count() {
  static thread_local int cached_device = -1;
  static thread_local int cached_sm_count = 0;
  int device_id = -1;
  CUDA_CHECK(cudaGetDevice(&device_id));
  if (device_id != cached_device || cached_sm_count == 0) {
    int sm_count = 0;
    // Single-attribute query; avoids filling a whole cudaDeviceProp.
    CUDA_CHECK(cudaDeviceGetAttribute(
        &sm_count, cudaDevAttrMultiProcessorCount, device_id));
    cached_sm_count = sm_count;
    cached_device = device_id;
  }
  return cached_sm_count;
}
// Returns the compute capability of the current device encoded as
// major * 10 + minor. A device reporting 121 is reported as 120 unless
// queryRealSmArch is true.
inline int getSMVersion(bool queryRealSmArch = false) {
  int device{-1};
  CUDA_CHECK(cudaGetDevice(&device));

  int cc_major = 0;
  int cc_minor = 0;
  CUDA_CHECK(cudaDeviceGetAttribute(&cc_major,
                                    cudaDevAttrComputeCapabilityMajor, device));
  CUDA_CHECK(cudaDeviceGetAttribute(&cc_minor,
                                    cudaDevAttrComputeCapabilityMinor, device));

  const int sm = cc_major * 10 + cc_minor;
  const bool fold_into_120 = (sm == 121) && !queryRealSmArch;
  return fold_into_120 ? 120 : sm;
}
// Queries how many blocks of `kernel` can be resident on one SM for the given
// block size and dynamic shared memory size; the result is clamped to at
// least 1.
template <typename KernelFunc>
int get_max_active_blocks(KernelFunc kernel, int block_size,
                          int dynamic_smem = 0) {
  int blocks_per_sm = 0;
  CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &blocks_per_sm, kernel, block_size, dynamic_smem));
  if (blocks_per_sm < 1) {
    return 1;
  }
  return blocks_per_sm;
}
// Launches minimax_reduce_rms_kernel_lamport<DType, NRanks> on params.stream.
//
// Launch geometry: one block per token with hidden_dim / kElemsPerAccess<DType>
// threads, with the grid capped at (resident blocks per SM) * (SM count); the
// kernel strides over the remaining tokens. On SM90+ the launch carries the
// programmatic-stream-serialization and cluster-dimension (1x1x1) attributes.
//
// Preconditions (checked): size_q is a multiple of hidden_dim, hidden_dim is a
// multiple of kElemsPerAccess<DType>, and the resulting block size is within
// the kernel's __launch_bounds__(1024). An empty input is a no-op, matching
// the float4 launcher.
template <typename DType, int NRanks>
void minimax_reduce_rms_kernel_launcher(MiniMaxReduceRMSParams const& params) {
  // Nothing to do for an empty input; a zero-sized grid would make the launch
  // fail and hidden_dim may legitimately be 0 here.
  if (params.size_q == 0) {
    return;
  }
  TORCH_CHECK(params.hidden_dim > 0,
              "minimax_reduce_rms: hidden_dim must be positive");
  TORCH_CHECK(params.size_q % params.hidden_dim == 0,
              "minimax_reduce_rms: size_q must be a multiple of hidden_dim");
  TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0,
              "minimax_reduce_rms: hidden_dim must be a multiple of the "
              "16-byte access width");

  static int SM = getSMVersion();
  int token_num = params.size_q / params.hidden_dim;
  int sm_count = get_sm_count();
  int cluster_size = 1;
  int cluster_num = token_num;
  int threads_per_token = params.hidden_dim / kElemsPerAccess<DType>;
  int block_size = threads_per_token;
  TORCH_CHECK(block_size <= 1024,
              "minimax_reduce_rms: hidden_dim too large for a single block");

  int max_blocks_per_sm = get_max_active_blocks(
      minimax_reduce_rms_kernel_lamport<DType, NRanks>, block_size);
  int max_grid = max_blocks_per_sm * sm_count;
  int grid_size =
      (std::min(max_grid, cluster_num * cluster_size) / cluster_size) *
      cluster_size;

  cudaLaunchConfig_t cfg{};
  cfg.gridDim = grid_size;
  cfg.blockDim = block_size;
  cfg.dynamicSmemBytes = 0;
  cfg.stream = params.stream;
  cudaLaunchAttribute attribute[2];
  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attribute[0].val.programmaticStreamSerializationAllowed = 1;
  attribute[1].id = cudaLaunchAttributeClusterDimension;
  attribute[1].val.clusterDim.x = cluster_size;
  attribute[1].val.clusterDim.y = 1;
  attribute[1].val.clusterDim.z = 1;
  cfg.attrs = attribute;
  // The launch attributes are only passed on SM90+.
  cfg.numAttrs = SM >= 90 ? 2 : 0;
  CUDA_CHECK(cudaLaunchKernelEx(
      &cfg, minimax_reduce_rms_kernel_lamport<DType, NRanks>, params));
}
// Validates the Q/K parameters and launches
// minimax_reduce_qk_rms_kernel_lamport_float4 on params.stream.
//
// One block handles one group of 4 tokens. The block holds the Q section
// followed by the K section, each rounded up to a whole number of warps. The
// grid is capped at (resident blocks per SM) * (SM count). On SM90+ the launch
// carries the programmatic-stream-serialization and cluster-dimension
// attributes. Returns without launching when there are no tokens.
template <typename DType, int NRanks, int OriginQDim, int OriginKDim>
void minimax_reduce_rms_kernel_launcher_float4(
    MiniMaxReduceRMSParams const& params) {
  // --- Q checks ---
  TORCH_CHECK(params.size_q % params.hidden_dim == 0);
  TORCH_CHECK(params.hidden_dim % kElemsPerAccess<DType> == 0);
  if (params.stride_q > 0) {
    TORCH_CHECK(params.stride_q % kElemsPerAccess<DType> == 0);
  }
  // --- K checks ---
  TORCH_CHECK(params.allreduce_in_k != nullptr,
              "float4 QK kernel requires K input");
  TORCH_CHECK(params.hidden_dim >= params.hidden_dim_k);
  TORCH_CHECK(params.size_k % params.hidden_dim_k == 0);
  TORCH_CHECK(params.hidden_dim_k % kElemsPerAccess<DType> == 0);
  TORCH_CHECK(params.size_q / params.hidden_dim ==
              params.size_k / params.hidden_dim_k);
  if (params.stride_k > 0) {
    TORCH_CHECK(params.stride_k % kElemsPerAccess<DType> == 0);
  }

  const int num_tokens = params.size_q / params.hidden_dim;
  const int num_groups = (num_tokens + 3) / 4;
  if (num_groups == 0) {
    return;
  }

  static int SM = getSMVersion();
  const int num_sms = get_sm_count();
  const int cluster_size = 1;
  const int cluster_num = num_groups;

  // Round each section up to a warp boundary
  auto round_up_to_multiple = [](int value, int multiple) {
    return (value + multiple - 1) / multiple * multiple;
  };
  const int q_accesses = params.hidden_dim / kElemsPerAccess<DType>;
  const int k_accesses = params.hidden_dim_k / kElemsPerAccess<DType>;
  const int q_threads =
      round_up_to_multiple(q_accesses, MINIMAX_REDUCE_RMS_WARP_SIZE);
  const int k_threads =
      round_up_to_multiple(k_accesses, MINIMAX_REDUCE_RMS_WARP_SIZE);
  const int block_size = q_threads + k_threads;

  auto kfn =
      minimax_reduce_qk_rms_kernel_lamport_float4<DType, NRanks, OriginQDim,
                                                  OriginKDim>;
  const int resident_blocks_per_sm = get_max_active_blocks(kfn, block_size);
  const int grid_cap = resident_blocks_per_sm * num_sms;
  const int wanted_blocks = cluster_num * cluster_size;
  const int grid_size =
      (std::min(grid_cap, wanted_blocks) / cluster_size) * cluster_size;

  cudaLaunchAttribute launch_attrs[2];
  launch_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  launch_attrs[0].val.programmaticStreamSerializationAllowed = 1;
  launch_attrs[1].id = cudaLaunchAttributeClusterDimension;
  launch_attrs[1].val.clusterDim.x = cluster_size;
  launch_attrs[1].val.clusterDim.y = 1;
  launch_attrs[1].val.clusterDim.z = 1;

  cudaLaunchConfig_t cfg;
  cfg.gridDim = grid_size;
  cfg.blockDim = block_size;
  cfg.dynamicSmemBytes = 0;
  cfg.stream = params.stream;
  cfg.attrs = launch_attrs;
  cfg.numAttrs = SM >= 90 ? 2 : 0;
  CUDA_CHECK(cudaLaunchKernelEx(&cfg, kfn, params));
}
// Selects the kernel launcher for params.dtype (Half, BFloat16 or Float).
//
// The fused Q+K float4 kernel is chosen when a K input is present and the
// full (nranks * per-rank) widths are 6144 for Q and 1024 for K, i.e. the
// MiniMax M2 shape; every other case uses the scalar kernel. Any other dtype
// raises via TORCH_CHECK.
template <int NRanks>
void dispatch_dtype(MiniMaxReduceRMSParams const& params) {
  const bool has_k = (params.allreduce_in_k != nullptr);
  const bool use_float4 = has_k &&
                          (params.hidden_dim * params.nranks == 6144) &&
                          (params.hidden_dim_k * params.nranks == 1024);

  switch (params.dtype) {
    case at::ScalarType::Half:
      if (use_float4) {
        minimax_reduce_rms_kernel_launcher_float4<half, NRanks, 6144, 1024>(
            params);
      } else {
        minimax_reduce_rms_kernel_launcher<half, NRanks>(params);
      }
      break;
    case at::ScalarType::BFloat16:
      if (use_float4) {
        minimax_reduce_rms_kernel_launcher_float4<__nv_bfloat16, NRanks, 6144,
                                                  1024>(params);
      } else {
        minimax_reduce_rms_kernel_launcher<__nv_bfloat16, NRanks>(params);
      }
      break;
    case at::ScalarType::Float:
      if (use_float4) {
        minimax_reduce_rms_kernel_launcher_float4<float, NRanks, 6144, 1024>(
            params);
      } else {
        minimax_reduce_rms_kernel_launcher<float, NRanks>(params);
      }
      break;
    default:
      TORCH_CHECK(false, "Unsupported data type for minimax_reduce_rms_op");
  }
}
// Entry point of the reduce+RMS op: maps the runtime rank count
// (2, 4, 8 or 16) onto the matching compile-time instantiation and raises for
// any other value.
void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params) {
  switch (params.nranks) {
    case 2:
      dispatch_dtype<2>(params);
      break;
    case 4:
      dispatch_dtype<4>(params);
      break;
    case 8:
      dispatch_dtype<8>(params);
      break;
    case 16:
      dispatch_dtype<16>(params);
      break;
    default:
      TORCH_CHECK(false, "minimax_reduce_rms_op: unsupported ranks number!");
  }
}
} // namespace tensorrt_llm
} // namespace vllm
// Torch binding for the single-tensor reduce+RMS op.
//
// input:       CUDA tensor whose last dimension is the per-rank hidden size;
//              must be contiguous because the kernel indexes it as a dense
//              [tokens, hidden] array.
// norm_weight: RMS gamma, read by the kernel with the same element type as
//              `input` (one value per element of the last dimension).
// workspace:   tensor whose storage holds the Lamport pointer table.
// rank/nranks: this rank's index and the number of participating ranks.
// eps:         RMS epsilon (converted to float).
// Returns a new tensor shaped like `input` holding the normalized values.
// Raises (TORCH_CHECK) on non-CUDA or non-contiguous input, on an
// out-of-range rank, and downstream on unsupported nranks / dtype.
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
                                    torch::Tensor const& norm_weight,
                                    torch::Tensor workspace, int64_t const rank,
                                    int64_t const nranks, double const eps) {
  TORCH_CHECK(input.is_cuda(),
              "minimax_allreduce_rms: input must be a CUDA tensor");
  TORCH_CHECK(input.dim() >= 1,
              "minimax_allreduce_rms: input must have at least 1 dimension");
  TORCH_CHECK(input.is_contiguous(),
              "minimax_allreduce_rms: input must be contiguous");
  TORCH_CHECK(rank >= 0 && rank < nranks,
              "minimax_allreduce_rms: rank must be in [0, nranks)");
  // Make the input's device current so the occupancy queries and the kernel
  // launch target the device that owns the data and the stream.
  const c10::cuda::CUDAGuard device_guard(input.device());

  auto allreduce_params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
  allreduce_params.nranks = static_cast<int>(nranks);
  allreduce_params.rank = static_cast<int>(rank);
  allreduce_params.dtype = input.scalar_type();
  allreduce_params.size_q = static_cast<int>(input.numel());
  allreduce_params.hidden_dim = static_cast<int>(input.size(-1));
  allreduce_params.stride_q = allreduce_params.hidden_dim;
  allreduce_params.workspace =
      reinterpret_cast<void**>(workspace.mutable_data_ptr());
  allreduce_params.allreduce_in = input.data_ptr();
  allreduce_params.rms_gamma = norm_weight.data_ptr();
  allreduce_params.rms_eps = static_cast<float>(eps);
  allreduce_params.stream = at::cuda::getCurrentCUDAStream(input.get_device());
  torch::Tensor rms_norm_out = torch::empty_like(input);
  allreduce_params.rms_norm_out = rms_norm_out.mutable_data_ptr();
  // K-related fields stay at their zero defaults, which selects the scalar
  // kernel in dispatch_dtype.
  vllm::tensorrt_llm::minimax_reduce_rms_op(allreduce_params);
  return rms_norm_out;
}
// Torch binding for the fused Q+K reduce+RMS op operating directly on a packed
// qkv tensor.
//
// qkv:            contiguous 2D CUDA tensor [num_tokens, q_size + 2 * kv_size];
//                 Q occupies the first q_size columns and K the next kv_size.
// norm_weight_q/k: RMS gammas for Q and K.
// workspace:      tensor whose storage holds the Lamport pointer table.
// q_size/kv_size: per-rank row widths of Q and K.
// rank/nranks:    this rank's index and the number of participating ranks.
// eps:            RMS epsilon (converted to float).
// Returns (q_out [num_tokens, q_size], k_out [num_tokens, kv_size]).
//
// Only the shape served by the fused float4 kernel is accepted
// (q_size * nranks == 6144 and kv_size * nranks == 1024): for any other shape
// the dispatcher would fall back to the scalar kernel, which ignores the qkv
// row stride and never writes k_out.
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
    torch::Tensor qkv, torch::Tensor const& norm_weight_q,
    torch::Tensor const& norm_weight_k, torch::Tensor workspace,
    int64_t const q_size, int64_t const kv_size, int64_t const rank,
    int64_t const nranks, double const eps) {
  TORCH_CHECK(qkv.is_cuda(),
              "minimax_allreduce_rms_qk: qkv must be a CUDA tensor");
  TORCH_CHECK(qkv.dim() == 2, "minimax_allreduce_rms_qk: qkv must be 2D");
  TORCH_CHECK(qkv.is_contiguous(),
              "minimax_allreduce_rms_qk: qkv must be contiguous");
  int64_t qkv_dim = qkv.size(-1);
  TORCH_CHECK(qkv_dim == q_size + 2 * kv_size,
              "minimax_allreduce_rms_qk: qkv last dim must equal "
              "q_size + 2 * kv_size");
  TORCH_CHECK(rank >= 0,
              "minimax_allreduce_rms_qk: rank must be non-negative");
  TORCH_CHECK(rank < nranks,
              "minimax_allreduce_rms_qk: rank must be less than nranks");
  TORCH_CHECK(q_size * nranks == 6144 && kv_size * nranks == 1024,
              "minimax_allreduce_rms_qk: only q_size * nranks == 6144 and "
              "kv_size * nranks == 1024 are supported");
  // Make qkv's device current so allocations, occupancy queries and the
  // kernel launch all target the device that owns the data and the stream.
  const c10::cuda::CUDAGuard device_guard(qkv.device());

  int64_t num_tokens = qkv.size(0);
  int elem_bytes = qkv.element_size();
  torch::Tensor q_out = torch::empty({num_tokens, q_size}, qkv.options());
  torch::Tensor k_out = torch::empty({num_tokens, kv_size}, qkv.options());

  auto params = vllm::tensorrt_llm::MiniMaxReduceRMSParams();
  params.nranks = static_cast<int>(nranks);
  params.rank = static_cast<int>(rank);
  params.dtype = qkv.scalar_type();
  params.size_q = static_cast<int>(num_tokens * q_size);
  params.hidden_dim = static_cast<int>(q_size);
  params.size_k = static_cast<int>(num_tokens * kv_size);
  params.hidden_dim_k = static_cast<int>(kv_size);
  // Q and K rows are read in place from the packed tensor, so both use the
  // full qkv row width as their input stride.
  params.stride_q = static_cast<int>(qkv_dim);
  params.stride_k = static_cast<int>(qkv_dim);
  params.stride_q_out = 0;  // q_out is contiguous; kernel uses hidden_dim
  params.stride_k_out = 0;  // k_out is contiguous; kernel uses hidden_dim_k
  params.workspace = reinterpret_cast<void**>(workspace.mutable_data_ptr());
  uint8_t* base = static_cast<uint8_t*>(qkv.data_ptr());
  params.allreduce_in = base;
  // K starts q_size elements into each row.
  params.allreduce_in_k = base + q_size * elem_bytes;
  params.rms_gamma = norm_weight_q.data_ptr();
  params.rms_gamma_k = norm_weight_k.data_ptr();
  params.rms_eps = static_cast<float>(eps);
  params.stream = at::cuda::getCurrentCUDAStream(qkv.get_device());
  params.rms_norm_out = q_out.mutable_data_ptr();
  params.rms_norm_out_k = k_out.mutable_data_ptr();
  vllm::tensorrt_llm::minimax_reduce_rms_op(params);
  return {q_out, k_out};
}
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/types.h>
namespace vllm {
namespace tensorrt_llm {
// Trait describing one vectorized memory access for element type DType:
//   value    - number of DType elements moved per access
//   vec_type - the vector type used to perform the access
// Only the specializations below are defined; any other DType fails to
// compile.
template <typename DType>
struct ElemsPerAccess;
// half: 8 elements per float4 access.
template <>
struct ElemsPerAccess<half> {
static constexpr int value = 8;
using vec_type = float4;
};
// bfloat16: 8 elements per float4 access.
template <>
struct ElemsPerAccess<nv_bfloat16> {
static constexpr int value = 8;
using vec_type = float4;
};
// float: 4 elements per float4 access.
template <>
struct ElemsPerAccess<float> {
static constexpr int value = 4;
using vec_type = float4;
};
// Convenience shorthand for ElemsPerAccess<DType>::value.
template <typename DType>
static constexpr int kElemsPerAccess = ElemsPerAccess<DType>::value;
// Parameter bundle passed by value to the reduce+RMS kernels. All members are
// value-initialized, so unset pointers are null and unset sizes are 0.
struct MiniMaxReduceRMSParams {
int nranks{};  // number of participating ranks
int rank{};    // index of this rank
at::ScalarType dtype{at::ScalarType::Undefined};  // element type of in/out
int size_q{};        // total number of q elements (tokens * hidden_dim)
int hidden_dim{};    // per-rank q row width (elements)
int size_k{};        // total number of k elements (tokens * hidden_dim_k)
int hidden_dim_k{};  // per-rank k row width (elements)
int stride_q{}; // row stride for q input (elements); when > hidden_dim,
// q is part of a wider qkv tensor
int stride_k{}; // row stride for k input (elements); when > hidden_dim_k,
// k is part of a wider qkv tensor
int stride_q_out{}; // row stride for q output (elements); 0 = contiguous
int stride_k_out{}; // row stride for k output (elements); 0 = contiguous
void** workspace{};      // Lamport workspace pointer table
void* allreduce_in{};    // q input
void* rms_norm_out{};    // q output
void* rms_gamma{};       // q RMS weight
void* allreduce_in_k{};  // k input; null selects the single-tensor path
void* rms_norm_out_k{};  // k output
void* rms_gamma_k{};     // k RMS weight
float rms_eps{};         // RMS epsilon
cudaStream_t stream{};   // stream the kernel is launched on
};
// Launches the reduce+RMS kernel matching params.nranks and params.dtype on
// params.stream; raises for unsupported rank counts or data types.
void minimax_reduce_rms_op(MiniMaxReduceRMSParams const& params);
} // namespace tensorrt_llm
} // namespace vllm
......@@ -392,3 +392,15 @@ int64_t qr_max_size();
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a,
torch::Tensor const& mat_b);
#endif
#ifndef USE_ROCM
// MiniMax allreduce + RMS-norm ops; declared only for non-ROCm builds.
// Single-tensor variant: returns one tensor for `input`.
torch::Tensor minimax_allreduce_rms(torch::Tensor const& input,
                                    torch::Tensor const& norm_weight,
                                    torch::Tensor workspace, int64_t const rank,
                                    int64_t const nranks, double const eps);
// Fused Q/K variant operating on a packed `qkv` tensor; `q_size` and `kv_size`
// give the q and k widths, and the result is a (q, k) tensor pair.
std::tuple<torch::Tensor, torch::Tensor> minimax_allreduce_rms_qk(
    torch::Tensor qkv, torch::Tensor const& norm_weight_q,
    torch::Tensor const& norm_weight_k, torch::Tensor workspace,
    int64_t const q_size, int64_t const kv_size, int64_t const rank,
    int64_t const nranks, double const eps);
#endif
\ No newline at end of file
......@@ -668,6 +668,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? b_qzeros, "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
ops.def(
"minimax_allreduce_rms("
"Tensor input,"
"Tensor norm_weight,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps) -> Tensor");
ops.impl("minimax_allreduce_rms", torch::kCUDA, &minimax_allreduce_rms);
ops.def(
"minimax_allreduce_rms_qk("
"Tensor qkv,"
"Tensor norm_weight_q,"
"Tensor norm_weight_k,"
"Tensor workspace,"
"int q_size,"
"int kv_size,"
"int rank,"
"int nranks,"
"float eps) -> (Tensor, Tensor)");
ops.impl("minimax_allreduce_rms_qk", torch::kCUDA, &minimax_allreduce_rms_qk);
// conditionally compiled so impl in source file
#endif
}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for MiniMax QK RMS-norm: NCCL reference vs Lamport fused kernel."""
import pytest
import torch
import torch.nn as nn
from torch.multiprocessing import spawn
from tests.kernels.utils import opcheck
from tests.utils import ensure_current_vllm_config, init_test_distributed_environment
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import set_random_seed
@ensure_current_vllm_config()
def _worker_forward_qk(
    local_rank,
    world_size,
    port,
    num_tokens,
    hidden_q_full,
    hidden_k_full,
    dtype,
    seed,
    eps,
):
    """Per-rank worker: compare NCCL allreduce path vs Lamport fused kernel.

    Each rank builds its shard of the Q/K norm weights and a random packed
    ``qkv`` tensor, computes the reference through
    ``MiniMaxText01RMSNormTP.forward_qk`` and checks that the fused op
    returns matching Q/K outputs while leaving its ``qkv`` input untouched.
    Returns early (after cleanup) when the fused op is not compiled in.
    """
    if not hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):
        cleanup_dist_env_and_memory()
        return

    device = torch.device(f"cuda:{local_rank}")
    torch.accelerator.set_device_index(device)
    init_test_distributed_environment(
        world_size, 1, local_rank, port, local_rank=local_rank
    )

    # Per-rank shard widths.
    hq = hidden_q_full // world_size
    hk = hidden_k_full // world_size
    split_sizes = [hq, hk, hk]

    q_norm = MiniMaxText01RMSNormTP(hidden_q_full, eps=eps).cuda()
    k_norm = MiniMaxText01RMSNormTP(hidden_k_full, eps=eps).cuda()

    # Same seed on every rank so all ranks slice the same full weight.
    set_random_seed(seed)
    qw = torch.randn(hidden_q_full, dtype=dtype, device="cuda")
    kw = torch.randn(hidden_k_full, dtype=dtype, device="cuda")
    q_norm.weight = nn.Parameter(qw[local_rank * hq : (local_rank + 1) * hq])
    k_norm.weight = nn.Parameter(kw[local_rank * hk : (local_rank + 1) * hk])

    # Rank-dependent seed so each rank contributes different activations.
    torch.manual_seed(seed + 1000 + local_rank)
    qkv = torch.randn(num_tokens, hq + hk + hk, dtype=dtype, device="cuda")

    q_ref, k_ref, _ = qkv.clone().split(split_sizes, dim=-1)
    ref_q, ref_k = MiniMaxText01RMSNormTP.forward_qk(q_norm, k_norm, q_ref, k_ref)

    # Set up Lamport workspace.
    from vllm.distributed.parallel_state import get_tp_group
    from vllm.model_executor.layers.mamba.lamport_workspace import (
        get_allreduce_workspace,
    )

    workspace = get_allreduce_workspace(
        rank=local_rank,
        world_size=world_size,
        max_tokens=num_tokens,
        process_group=get_tp_group().cpu_group,
    )

    opcheck(
        torch.ops._C.minimax_allreduce_rms_qk,
        (
            qkv.clone(),
            q_norm.weight,
            k_norm.weight,
            workspace,
            hq,
            hk,
            local_rank,
            world_size,
            eps,
        ),
    )

    # Keep a handle on the tensor actually given to the kernel so that we can
    # verify afterwards that the op did not write into it.
    fused_input = qkv.clone()
    fused_q, fused_k = torch.ops._C.minimax_allreduce_rms_qk(
        fused_input,
        q_norm.weight,
        k_norm.weight,
        workspace,
        hq,
        hk,
        local_rank,
        world_size,
        eps,
    )
    torch.accelerator.synchronize()

    torch.testing.assert_close(
        fused_q,
        ref_q,
        atol=3e-2,
        rtol=3e-2,
    )
    torch.testing.assert_close(fused_k, ref_k, atol=3e-2, rtol=3e-2)
    # The op is registered without mutated arguments: the packed input
    # (including the V slice consumed later by attention) must be bit-identical.
    torch.testing.assert_close(fused_input, qkv, atol=0, rtol=0)

    cleanup_dist_env_and_memory()
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="CUDA required",
)
@pytest.mark.parametrize("world_size", [2, 4, 8])
@pytest.mark.parametrize("num_tokens", [1, 128, 333])
@pytest.mark.parametrize(
    "hidden_dims",
    [(6144, 1024)],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("eps", [1e-6])
@pytest.mark.parametrize("seed", [42])
def test_minimax_reduce_rms_qk(
    world_size,
    num_tokens,
    hidden_dims,
    dtype,
    eps,
    seed,
):
    """Spawn one worker per rank and run the fused-vs-reference comparison.

    Skipped when fewer than ``world_size`` GPUs are available.
    """
    num_gpus = current_platform.device_count()
    if not num_gpus >= world_size:
        pytest.skip(f"Need >= {world_size} GPUs, have {num_gpus}")

    hidden_q_full = hidden_dims[0]
    hidden_k_full = hidden_dims[1]
    port = str(get_open_port())

    # `spawn` prepends the process index (used as local_rank) to these args.
    worker_args = (
        world_size,
        port,
        num_tokens,
        hidden_q_full,
        hidden_k_full,
        dtype,
        seed,
        eps,
    )
    spawn(_worker_forward_qk, args=worker_args, nprocs=world_size, join=True)
......@@ -3397,3 +3397,38 @@ if hasattr(torch.ops._C, "hadacore_transform"):
@register_fake("_C::hadacore_transform")
def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor:
return torch.empty_like(x) if not inplace else x
if hasattr(torch.ops._C, "minimax_allreduce_rms"):

    @register_fake("_C::minimax_allreduce_rms")
    def _minimax_allreduce_rms_fake(
        input: torch.Tensor,
        norm_weight: torch.Tensor,
        workspace: torch.Tensor,
        rank: int,
        nranks: int,
        eps: float,
    ) -> torch.Tensor:
        """Fake (meta) impl: the output has the same layout as ``input``."""
        fake_out = torch.empty_like(input)
        return fake_out
if hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):

    @register_fake("_C::minimax_allreduce_rms_qk")
    def _minimax_allreduce_rms_qk_fake(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        workspace: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fake (meta) impl: returns ``[tokens, q_size]`` and
        ``[tokens, kv_size]`` tensors with the dtype/device of ``qkv``."""
        rows = qkv.shape[0]
        like_qkv = {"dtype": qkv.dtype, "device": qkv.device}
        q_fake = torch.empty([rows, q_size], **like_qkv)
        k_fake = torch.empty([rows, kv_size], **like_qkv)
        return (q_fake, k_fake)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fusion pass: replace MiniMax QK allreduce + RMS norm with the Lamport
fused kernel (minimax_allreduce_rms_qk) for decode-size batches.
Pattern (inlined forward_qk in compiled graph):
q, k, v = qkv.split([q_size, kv_size, kv_size], -1)
q_fp32 = q.to(float32); k_fp32 = k.to(float32)
q_var = q_fp32.pow(2).mean(-1, keepdim=True)
k_var = k_fp32.pow(2).mean(-1, keepdim=True)
qk_var = cat([q_var, k_var], -1)
qk_var = allreduce(qk_var) / tp_world
q_var, k_var = qk_var.chunk(2, -1)
q_out = (q_fp32 * rsqrt(q_var + eps) * q_weight).to(orig_dtype)
k_out = (k_fp32 * rsqrt(k_var + eps) * k_weight).to(orig_dtype)
return q_out, k_out, v
Replacement (pure, no in-place on qkv/q/k):
q_out, k_out = minimax_qk_norm_fused(qkv, q_weight, k_weight, ...)  (the Lamport workspace is looked up inside the op)
v = qkv.split([q_size, kv_size, kv_size], -1)[2]
return q_out, k_out, v
is_applicable_for_range: only fires when compile_range.end <= max_token_num
(= min(MAX_TOKEN_NUM, max_num_batched_tokens)), so that larger prefill batches
fall through to the original forward_qk (= main).
"""
import torch
import torch._inductor.pattern_matcher as pm
import torch.fx as fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.utils.torch_utils import direct_register_custom_op
from ..inductor_pass import enable_fake_mode
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
MAX_TOKEN_NUM = 2048
_MINIMAX_QK_NORM_FUSED_OP = None
if hasattr(torch.ops._C, "minimax_allreduce_rms_qk"):

    def _minimax_qk_norm_fused(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
        max_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the fused Q/K allreduce + RMS-norm kernel.

        Fetches the (cached) Lamport workspace for this rank/world size and
        forwards everything to ``torch.ops._C.minimax_allreduce_rms_qk``.
        """
        # Imported lazily, at call time.
        from vllm.distributed.parallel_state import get_tp_group
        from vllm.model_executor.layers.mamba.lamport_workspace import (
            get_allreduce_workspace,
        )

        lamport_ws = get_allreduce_workspace(
            rank=rank,
            world_size=nranks,
            max_tokens=max_tokens,
            process_group=get_tp_group().cpu_group,
        )
        kernel_args = (
            qkv,
            norm_weight_q,
            norm_weight_k,
            lamport_ws,
            q_size,
            kv_size,
            rank,
            nranks,
            eps,
        )
        return torch.ops._C.minimax_allreduce_rms_qk(*kernel_args)

    def _minimax_qk_norm_fused_fake(
        qkv: torch.Tensor,
        norm_weight_q: torch.Tensor,
        norm_weight_k: torch.Tensor,
        q_size: int,
        kv_size: int,
        rank: int,
        nranks: int,
        eps: float,
        max_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Shape-only counterpart of ``_minimax_qk_norm_fused``."""
        num_rows = qkv.shape[0]
        q_fake = torch.empty([num_rows, q_size], dtype=qkv.dtype, device=qkv.device)
        k_fake = torch.empty([num_rows, kv_size], dtype=qkv.dtype, device=qkv.device)
        return (q_fake, k_fake)

    # Expose the wrapper as torch.ops.vllm.minimax_qk_norm_fused; it is
    # declared as not mutating any of its arguments.
    direct_register_custom_op(
        op_name="minimax_qk_norm_fused",
        op_func=_minimax_qk_norm_fused,
        fake_impl=_minimax_qk_norm_fused_fake,
        mutates_args=[],
    )
    _MINIMAX_QK_NORM_FUSED_OP = torch.ops.vllm.minimax_qk_norm_fused.default
class MiniMaxQKNormPattern:
    """
    Pattern/replacement pair that swaps the inlined ``forward_qk``
    allreduce + RMS-norm subgraph for the Lamport fused op.
    """

    def __init__(
        self,
        q_size: int,
        kv_size: int,
        eps: float,
        tp_world: int,
        tp_rank: int,
        max_tokens: int,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        self.q_size, self.kv_size = q_size, kv_size
        self.eps = eps
        self.tp_world, self.tp_rank = tp_world, tp_rank
        self.max_tokens = max_tokens
        self.dtype, self.device = dtype, device

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs (qkv, q_weight, k_weight) used to trace the patterns."""
        opts = {"device": self.device, "dtype": self.dtype}
        T = 4
        return [
            torch.empty([T, self.q_size + 2 * self.kv_size], **opts),
            torch.empty([self.q_size], **opts),
            torch.empty([self.kv_size], **opts),
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both pattern variants (shared replacement) on ``pm_pass``."""
        q_size, kv_size = self.q_size, self.kv_size
        eps = self.eps
        tp_world, tp_rank = self.tp_world, self.tp_rank
        max_tokens = self.max_tokens
        dtype = self.dtype

        def _tp_rms_norm(
            q: torch.Tensor,
            k: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Shared body of both patterns; the op order is kept exactly as in
            # the inlined forward_qk so the traced graphs are unchanged.
            q_fp32 = q.to(torch.float32)
            k_fp32 = k.to(torch.float32)
            q_var = q_fp32.pow(2).mean(dim=-1, keepdim=True)
            k_var = k_fp32.pow(2).mean(dim=-1, keepdim=True)
            qk_var = torch.cat([q_var, k_var], dim=-1)
            qk_var = tensor_model_parallel_all_reduce(qk_var) / tp_world
            q_var, k_var = qk_var.chunk(2, dim=-1)
            q_out = (q_fp32 * torch.rsqrt(q_var + eps) * q_weight).to(dtype)
            k_out = (k_fp32 * torch.rsqrt(k_var + eps) * k_weight).to(dtype)
            return q_out, k_out

        def pattern(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Variant 1: a single split feeding q, k and v.
            q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
            q_out, k_out = _tp_rms_norm(q, k, q_weight, k_weight)
            return q_out, k_out, v

        def replacement(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert _MINIMAX_QK_NORM_FUSED_OP is not None
            q_out, k_out = torch.ops.vllm.minimax_qk_norm_fused(
                qkv,
                q_weight,
                k_weight,
                q_size,
                kv_size,
                tp_rank,
                tp_world,
                eps,
                max_tokens,
            )
            # v is passed through untouched from the packed projection.
            _, _, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
            return q_out, k_out, v

        # Second pattern: three separate split_with_sizes nodes (one per output),
        # each with _users=1. This occurs when the QKV projection uses a
        # functional GEMM kernel (e.g. cutlass_scaled_mm via auto_functionalized),
        # which causes inductor to generate one split per consumer.
        def pattern_split3(
            qkv: torch.Tensor,
            q_weight: torch.Tensor,
            k_weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            q = qkv.split([q_size, kv_size, kv_size], dim=-1)[0]
            k = qkv.split([q_size, kv_size, kv_size], dim=-1)[1]
            v = qkv.split([q_size, kv_size, kv_size], dim=-1)[2]
            q_out, k_out = _tp_rms_norm(q, k, q_weight, k_weight)
            return q_out, k_out, v

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        pm.register_replacement(
            pattern_split3, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
class MiniMaxQKNormPass(VllmPatternMatcherPass):
    """
    Replace forward_qk allreduce+norm with the Lamport fused kernel.
    Only applied for decode-size compile ranges (small token counts).

    The pass starts out disabled and only enables itself when the fused op is
    available, tensor parallelism is in use, and the model is the supported
    MiniMaxM2 configuration.
    """

    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)
        self.disabled = True

        if _MINIMAX_QK_NORM_FUSED_OP is None:
            logger.warning_once(
                "minimax_allreduce_rms_qk op not found, MiniMaxQKNormPass disabled."
            )
            return

        tp_world = get_tensor_model_parallel_world_size()
        if tp_world <= 1:
            logger.warning_once("MiniMaxQKNormPass disabled: tp_size <= 1.")
            return

        if config.model_config is None:
            logger.warning_once("MiniMaxQKNormPass disabled: no model_config.")
            return

        hf_cfg = config.model_config.hf_config
        # `architectures` may be absent, None or empty on some configs; treat
        # all of those as "not the supported model" instead of raising.
        architectures = getattr(hf_cfg, "architectures", None) or []
        model_name = architectures[0] if architectures else ""
        if model_name != "MiniMaxM2ForCausalLM":
            return

        num_attention_heads = getattr(hf_cfg, "num_attention_heads", 0)
        num_key_value_heads = getattr(hf_cfg, "num_key_value_heads", 0)
        hidden_size = getattr(hf_cfg, "hidden_size", 0)
        head_dim = getattr(hf_cfg, "head_dim", 0)
        eps: float = getattr(hf_cfg, "rms_norm_eps", 1e-6)
        # Only this exact model shape is accepted by the pass.
        if (
            num_attention_heads != 48
            or num_key_value_heads != 8
            or hidden_size != 3072
            or head_dim != 128
        ):
            logger.warning_once(
                "MiniMaxQKNormPass disabled: cannot infer model info from hf_config."
            )
            return

        # Per-rank Q/K widths; at least one KV head is kept on every rank.
        num_heads_per_rank = num_attention_heads // tp_world
        num_kv_heads_per_rank = max(1, num_key_value_heads // tp_world)
        q_size = num_heads_per_rank * head_dim
        kv_size = num_kv_heads_per_rank * head_dim

        self.max_token_num = min(
            MAX_TOKEN_NUM, config.scheduler_config.max_num_batched_tokens
        )
        tp_rank = get_tensor_model_parallel_rank()

        # Allocate Lamport workspace first.
        from vllm.distributed.parallel_state import get_tp_group
        from vllm.model_executor.layers.mamba.lamport_workspace import (
            get_allreduce_workspace,
        )

        get_allreduce_workspace(
            rank=tp_rank,
            world_size=tp_world,
            max_tokens=self.max_token_num,
            process_group=get_tp_group().cpu_group,
        )

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="minimax_qk_norm_pass"
        )
        self._register_patterns(q_size, kv_size, eps, tp_world, tp_rank)
        self.dump_patterns(config, self.patterns)
        self.disabled = False

    @enable_fake_mode
    def _register_patterns(
        self,
        q_size: int,
        kv_size: int,
        eps: float,
        tp_world: int,
        tp_rank: int,
    ) -> None:
        """Build the pattern object and register it on ``self.patterns``."""
        MiniMaxQKNormPattern(
            q_size=q_size,
            kv_size=kv_size,
            eps=eps,
            tp_world=tp_world,
            tp_rank=tp_rank,
            max_tokens=self.max_token_num,
            dtype=self.model_dtype,
            device=self.device,
        ).register(self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """True only when enabled and the range ends at or below
        ``max_token_num``."""
        if self.disabled:
            return False
        return bool(compile_range.end <= self.max_token_num)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        if self.disabled:
            return
        self.matched_count = self.patterns.apply(graph)
        logger.debug("MiniMaxQKNormPass replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        return VllmInductorPass.hash_source(self, MiniMaxQKNormPattern)
......@@ -36,6 +36,7 @@ if current_platform.is_cuda_alike():
if current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass
from .fusion.minimax_qk_norm_fusion import MiniMaxQKNormPass
from .inductor_pass import (
CustomGraphPass,
......@@ -124,6 +125,9 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
if self.pass_config.fuse_allreduce_rms:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.fuse_minimax_qk_norm:
self.passes += [MiniMaxQKNormPass(config)]
if self.pass_config.fuse_norm_quant:
self.passes += [RMSNormQuantFusionPass(config)]
if rocm_aiter_ops.is_enabled():
......
......@@ -132,6 +132,8 @@ class PassConfig:
"""Enable async TP."""
fuse_allreduce_rms: bool = None # type: ignore[assignment]
"""Enable flashinfer allreduce fusion."""
fuse_minimax_qk_norm: bool = None # type: ignore[assignment]
"""Enable fused allreduce+RMSNorm for MiniMax QK norm."""
enable_qk_norm_rope_fusion: bool = False
"""Enable fused Q/K RMSNorm + RoPE pass."""
......@@ -282,7 +284,7 @@ class PassConfig:
"""
enabled_fusions = [
f.name[len("fuse_") :]
for f in fields(self)
for f in fields(self) # type: ignore[arg-type]
if getattr(self, f.name) and f.name.startswith("fuse_")
]
......
......@@ -1577,6 +1577,22 @@ class VllmConfig:
compile_range_end,
)
if compilation_config.pass_config.fuse_minimax_qk_norm:
from vllm.compilation.passes.fusion.minimax_qk_norm_fusion import (
MAX_TOKEN_NUM,
)
max_token_num = min(
MAX_TOKEN_NUM, self.scheduler_config.max_num_batched_tokens
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_endpoints.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below MiniMax QK norm fusion threshold, "
"MiniMax QK norm fusion enabled for all num_tokens."
)
if compilation_config.compile_ranges_endpoints is not None:
for x in compilation_config.compile_ranges_endpoints:
assert isinstance(x, int)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import array
import contextlib
import struct
import sys
import threading
import torch
try:
from cuda.bindings import runtime as cudart
except ImportError:
from cuda import cudart
_ALIGN = 1 << 21 # 2 MiB — CUDA IPC allocation alignment
# ---------------------------------------------------------------------------
# CUDA helpers
# ---------------------------------------------------------------------------
def _check(error):
    """Raise ``RuntimeError`` unless ``error`` is the CUDA success code."""
    ok = getattr(cudart.cudaError_t, "cudaSuccess", None)
    if not ok:
        # Attribute missing or falsy: build the success code from its value.
        ok = cudart.cudaError_t(0)
    if error != ok:
        raise RuntimeError(f"CUDA runtime error: {error}")
def _cuda_malloc(size: int):
    """Allocate device memory, rounding ``size`` up to a 2 MiB multiple.

    Returns ``(device_pointer, rounded_size)``.
    """
    rounded = ((size + _ALIGN - 1) >> 21) << 21
    status, dev_ptr = cudart.cudaMalloc(rounded)
    _check(status)
    return dev_ptr, rounded
def _cuda_free(ptr: int):
    """Free a device pointer; a null/zero pointer is ignored."""
    if not ptr:
        return
    (status,) = cudart.cudaFree(ptr)[:1]
    _check(status)
def _cuda_memset_zero(ptr: int, size: int):
    """Zero ``size`` bytes of device memory starting at ``ptr``."""
    status = cudart.cudaMemset(ptr, 0, size)[0]
    _check(status)
def _cuda_memcpy_d2d(dst: int, src: int, size: int):
    """Copy ``size`` bytes from device pointer ``src`` to device pointer ``dst``."""
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice
    status = cudart.cudaMemcpy(dst, src, size, kind)[0]
    _check(status)
# ---------------------------------------------------------------------------
# IPC buffer
# ---------------------------------------------------------------------------
class IpcBuffer:
    """
    Device allocation shared across ranks through CUDA IPC.

    Each rank allocates its own buffer, publishes the IPC handle to the other
    ranks and opens theirs, so ``peer_ptrs[r]`` is a device pointer to rank
    ``r``'s buffer (``peer_ptrs[rank]`` being the local allocation). A
    non-positive ``size`` yields an inert object with all pointers zero.
    """

    def __init__(self, rank: int, world_size: int, size: int, process_group=None):
        self.rank = rank
        self.world_size = world_size
        self.peer_ptrs: list[int] = [0] * world_size
        self.local_ptr: int = 0
        self._alive = False
        if size <= 0:
            return

        self.local_ptr, _ = _cuda_malloc(size)
        _cuda_memset_zero(self.local_ptr, size)
        self._alive = True

        # --- exchange IPC handles via torch.distributed ---
        err, local_handle = cudart.cudaIpcGetMemHandle(self.local_ptr)
        _check(err)
        gathered: list[bytes | None] = [None] * world_size
        torch.distributed.all_gather_object(
            gathered, bytes(local_handle.reserved), group=process_group
        )

        for peer in range(world_size):
            if peer == rank:
                # Own buffer: no handle to open.
                self.peer_ptrs[peer] = self.local_ptr
                continue
            peer_handle = cudart.cudaIpcMemHandle_t()
            peer_handle.reserved = gathered[peer]
            err, mapped = cudart.cudaIpcOpenMemHandle(
                peer_handle, cudart.cudaIpcMemLazyEnablePeerAccess
            )
            _check(err)
            self.peer_ptrs[peer] = mapped

    def serialize(self) -> list[int]:
        """Return peer pointers as a list of int64 values (one per rank)."""
        count = len(self.peer_ptrs)
        # Pack as native pointers, reinterpret as unsigned 64-bit integers.
        packed = struct.pack(f"{count}P", *self.peer_ptrs)
        return list(struct.unpack(f"{count}Q", packed))

    def cleanup(self):
        """Free the local allocation and close peer mappings (idempotent)."""
        if not self._alive:
            return
        self._alive = False
        for peer, ptr in enumerate(self.peer_ptrs[: self.world_size]):
            if ptr == 0:
                continue
            if peer == self.rank:
                _cuda_free(ptr)
            else:
                # Best effort: a failure to close a peer mapping is ignored.
                with contextlib.suppress(RuntimeError):
                    _check(cudart.cudaIpcCloseMemHandle(ptr)[0])
            self.peer_ptrs[peer] = 0
        self.local_ptr = 0

    def __del__(self):
        # Skip cleanup during interpreter shutdown.
        if sys.is_finalizing():
            return
        self.cleanup()
# ---------------------------------------------------------------------------
# Lamport negative-zero initialization
# ---------------------------------------------------------------------------
def _lamport_fill_neg_zero(device_ptr: int, size_bytes: int):
"""
Fill device memory with IEEE-754 negative zero (-0.0f = 0x80000000).
This is the "slot empty" sentinel for the Lamport protocol: the kernel
spin-waits until a value is *not* negative zero.
"""
if size_bytes == 0 or device_ptr == 0:
return
n_floats = size_bytes // 4
# torch preserves -0.0 in IEEE-754
fill = torch.full((n_floats,), -0.0, dtype=torch.float32, device="cuda")
_cuda_memcpy_d2d(device_ptr, fill.data_ptr(), size_bytes)
del fill
# ---------------------------------------------------------------------------
# LamportWorkspace — the main class
# ---------------------------------------------------------------------------
class LamportWorkspace:
    """
    Self-contained workspace for Lamport-based cross-GPU AllReduce.

    Construction is collective: every rank of ``process_group`` must create
    its workspace together, since IPC handles are exchanged and a barrier is
    taken before the constructor returns.

    Parameters
    ----------
    rank : int
        Local rank (0-based).
    world_size : int
        Total number of ranks in the TP group.
    comm_size : int
        Size in bytes of *one* Lamport buffer slot. The total IPC allocation
        per rank is ``3 * comm_size`` (triple-buffering). Must be large enough
        to hold the per-slot data written by the kernel. Use
        ``compute_comm_size_for_minimax()`` for a safe default.
    process_group : optional
        ``torch.distributed`` process group for IPC handle exchange.
        ``None`` uses the default group.
    """

    def __init__(self, rank: int, world_size: int, comm_size: int, process_group=None):
        assert world_size >= 2, "Lamport workspace requires at least 2 ranks"
        assert comm_size > 0, "comm_size must be positive"
        self.rank = rank
        self.world_size = world_size
        self.comm_size = comm_size

        # 1) Lamport triple-buffer (the only IPC memory the kernel reads/writes)
        lamport_total = 3 * comm_size
        self._lamport = IpcBuffer(rank, world_size, lamport_total, process_group)
        _lamport_fill_neg_zero(self._lamport.local_ptr, lamport_total)

        # 2) flag_buffer on device: int32[3] = {counter, unused, lamport_flag}
        #    counter      — used for block-level sync inside the kernel
        #    unused       — reserved (index 1)
        #    lamport_flag — triple-buffer rotation index (0 → 1 → 2 → 0 …)
        self._flag_buf = torch.zeros(3, dtype=torch.int32, device="cuda")

        # 3) layout_buffer on device: int64[2] = {clear_size, comm_size}
        #    clear_size — bytes to clear from *previous* slot (set by kernel)
        #    comm_size  — size of one triple-buffer slot
        self._layout_buf = torch.tensor(
            [0, comm_size], dtype=torch.int64, device="cuda"
        )

        # 4) Assemble device-side void* pointer array
        N = world_size
        ptrs: list[int] = []
        ptrs += [0] * N  # [0 .. N-1]     ipc_buffers  (placeholder)
        ptrs += [0] * N  # [N .. 2N-1]    ipc_barriers (placeholder)
        ptrs += self._lamport.serialize()  # [2N .. 3N-1]  lamport peer ptrs
        ptrs.append(self._flag_buf.data_ptr())  # [3N]    flag_buffer
        ptrs.append(self._layout_buf.data_ptr())  # [3N+1]  layout_buffer
        self._workspace = torch.tensor(ptrs, dtype=torch.int64, device="cuda")

        # 5) Publish the initialized state to the peers. The sentinel fill and
        #    the tensor uploads above are only enqueued on the device, and the
        #    peers already hold a mapping of our buffer. Without this, a
        #    faster rank could start writing into our slots before the -0.0
        #    fill has landed, and the fill would then erase its data.
        torch.cuda.synchronize()
        torch.distributed.barrier(group=process_group)

    @property
    def workspace(self) -> torch.Tensor:
        """Device tensor (int64) that can be passed to the kernel
        as ``void** workspace``."""
        return self._workspace

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def compute_comm_size_for_minimax(
        max_tokens: int,
        world_size: int,
        fused_qk: bool = True,
    ) -> int:
        """
        Return a safe ``comm_size`` (in bytes) for MiniMaxReduceRMSKernel.
        The kernel stores per-token variance scalars in the Lamport buffer:
        - single-matrix path: ``world_size × max_tokens × 4`` bytes per slot
        - fused Q+K path: ``world_size × 2 × ceil(max_tokens/4) × 16`` bytes per slot
        The returned value is rounded up to 2 MiB alignment.
        """
        if fused_qk:
            groups = (max_tokens + 3) // 4
            slot_bytes = world_size * 2 * groups * 16  # 16 = sizeof(float4)
        else:
            slot_bytes = world_size * max_tokens * 4  # 4 = sizeof(float)
        # Round up to the next multiple of _ALIGN (2 MiB).
        return ((slot_bytes + _ALIGN - 1) // _ALIGN) * _ALIGN

    def cleanup(self):
        """Release the IPC buffer; safe to call on a partially built object."""
        if hasattr(self, "_lamport"):
            self._lamport.cleanup()

    def __del__(self):
        # Skip cleanup during interpreter shutdown.
        if not sys.is_finalizing():
            self.cleanup()

    def __repr__(self):
        return (
            f"LamportWorkspace(rank={self.rank}, world_size={self.world_size}, "
            f"comm_size={self.comm_size})"
        )
# ---------------------------------------------------------------------------
# Cached convenience function (mirrors TRT-LLM's get_allreduce_workspace)
# ---------------------------------------------------------------------------
_cache_lock = threading.Lock()
_workspace_cache: dict = {}
def get_allreduce_workspace(
    rank: int,
    world_size: int,
    comm_size: int | None = None,
    max_tokens: int = 16384,
    process_group=None,
) -> torch.Tensor:
    """
    Return the cached workspace tensor for this configuration.

    Workspaces are cached per ``(rank, world_size, comm_size, process group)``.
    The first call for a given key allocates a ``LamportWorkspace`` (which
    exchanges IPC handles); later calls with the same key return the same
    tensor.

    Parameters
    ----------
    rank, world_size : int
        TP rank and TP size.
    comm_size : int, optional
        Explicit slot size in bytes. If ``None``, computed automatically
        from ``max_tokens`` and ``world_size`` (fused Q+K path).
    max_tokens : int
        Maximum number of tokens per batch (used when ``comm_size is None``).
    process_group : optional
        ``torch.distributed`` process group.
    """
    if comm_size is None:
        comm_size = LamportWorkspace.compute_comm_size_for_minimax(
            max_tokens, world_size, fused_qk=True
        )
    # The group is identified by object identity; 0 stands for "default group".
    group_token = 0 if process_group is None else id(process_group)
    cache_key = (rank, world_size, comm_size, group_token)
    with _cache_lock:
        if cache_key in _workspace_cache:
            entry = _workspace_cache[cache_key]
        else:
            entry = LamportWorkspace(rank, world_size, comm_size, process_group)
            _workspace_cache[cache_key] = entry
        return entry.workspace
......@@ -232,9 +232,7 @@ class MiniMaxM2Attention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = MiniMaxText01RMSNormTP.forward_qk(
self.q_norm, self.k_norm, q.contiguous(), k.contiguous()
)
q, k = MiniMaxText01RMSNormTP.forward_qk(self.q_norm, self.k_norm, q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment