Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ecd1ea13
Unverified
Commit
ecd1ea13
authored
Apr 11, 2026
by
Jee Jee Li
Committed by
GitHub
Apr 11, 2026
Browse files
[Kernel] Porting the TRTLLM minimax_allreduce_rms kernels (#37045)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
8f121f78
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1861 additions
and
4 deletions
+1861
-4
.buildkite/test_areas/kernels.yaml
.buildkite/test_areas/kernels.yaml
+14
-1
CMakeLists.txt
CMakeLists.txt
+2
-0
csrc/minimax_reduce_rms_kernel.cu
csrc/minimax_reduce_rms_kernel.cu
+879
-0
csrc/minimax_reduce_rms_kernel.h
csrc/minimax_reduce_rms_kernel.h
+79
-0
csrc/ops.h
csrc/ops.h
+12
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+23
-0
tests/kernels/core/test_minimax_reduce_rms.py
tests/kernels/core/test_minimax_reduce_rms.py
+152
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+35
-0
vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py
vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py
+340
-0
vllm/compilation/passes/pass_manager.py
vllm/compilation/passes/pass_manager.py
+4
-0
vllm/config/compilation.py
vllm/config/compilation.py
+2
-0
vllm/config/vllm.py
vllm/config/vllm.py
+16
-0
vllm/model_executor/layers/mamba/lamport_workspace.py
vllm/model_executor/layers/mamba/lamport_workspace.py
+302
-0
vllm/model_executor/models/minimax_m2.py
vllm/model_executor/models/minimax_m2.py
+1
-3
No files found.
.buildkite/test_areas/kernels.yaml
View file @
ecd1ea13
...
@@ -20,7 +20,20 @@ steps:
...
@@ -20,7 +20,20 @@ steps:
-
tests/kernels/core
-
tests/kernels/core
-
tests/kernels/test_concat_mla_q.py
-
tests/kernels/test_concat_mla_q.py
commands
:
commands
:
-
pytest -v -s kernels/core kernels/test_concat_mla_q.py
-
pytest -v -s kernels/core --ignore=kernels/core/test_minimax_reduce_rms.py kernels/test_concat_mla_q.py
-
label
:
Kernels MiniMax Reduce RMS Test (2 GPUs)
timeout_in_minutes
:
15
num_devices
:
2
device
:
h100
source_file_dependencies
:
-
csrc/minimax_reduce_rms_kernel.cu
-
csrc/minimax_reduce_rms_kernel.h
-
vllm/model_executor/layers/mamba/linear_attn.py
-
vllm/model_executor/layers/mamba/lamport_workspace.py
-
tests/kernels/core/test_minimax_reduce_rms.py
commands
:
-
pytest -v -s kernels/core/test_minimax_reduce_rms.py
-
label
:
Kernels Attention Test %N
-
label
:
Kernels Attention Test %N
timeout_in_minutes
:
35
timeout_in_minutes
:
35
...
...
CMakeLists.txt
View file @
ecd1ea13
...
@@ -307,6 +307,8 @@ set(VLLM_EXT_SRC
...
@@ -307,6 +307,8 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp"
)
"csrc/torch_bindings.cpp"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
list
(
APPEND VLLM_EXT_SRC
"csrc/minimax_reduce_rms_kernel.cu"
)
SET
(
CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL
"Enable only the header library"
)
SET
(
CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL
"Enable only the header library"
)
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
...
...
csrc/minimax_reduce_rms_kernel.cu
0 → 100644
View file @
ecd1ea13
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <torch/cuda.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "cuda_utils.h"
#include "core/registration.h"
#include "minimax_reduce_rms_kernel.h"
#include <algorithm>
#define FINAL_MASK 0xffffffff
#define MINIMAX_REDUCE_RMS_WARP_SIZE 32
namespace
vllm
{
namespace
tensorrt_llm
{
template
<
int
NRanks
>
struct
LamportComm
{
__device__
__forceinline__
LamportComm
(
void
**
workspace
,
int
rank
)
{
counter_ptr
=
&
reinterpret_cast
<
int
*>
(
workspace
[
NRanks
*
3
])[
0
];
flag_ptr
=
&
reinterpret_cast
<
int
*>
(
workspace
[
NRanks
*
3
])[
2
];
clear_ptr
=
&
reinterpret_cast
<
int64_t
*>
(
workspace
[
NRanks
*
3
+
1
])[
0
];
flag_value
=
*
flag_ptr
;
auto
comm_size
=
reinterpret_cast
<
int64_t
*>
(
workspace
[
NRanks
*
3
+
1
])[
1
];
clear_size
=
*
clear_ptr
;
int
data_offset
=
flag_value
%
3
;
int
clear_offset
=
(
flag_value
+
2
)
%
3
;
for
(
int
r
=
0
;
r
<
NRanks
;
++
r
)
{
data_bufs
[
r
]
=
reinterpret_cast
<
uint8_t
*>
(
workspace
[
2
*
NRanks
+
r
])
+
data_offset
*
comm_size
;
}
clear_buf
=
reinterpret_cast
<
uint8_t
*>
(
workspace
[
2
*
NRanks
+
rank
])
+
clear_offset
*
comm_size
;
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
atomicAdd
(
counter_ptr
,
1
);
}
}
__device__
__forceinline__
void
update
(
int64_t
new_clear_size
)
{
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
while
(
*
reinterpret_cast
<
int
volatile
*>
(
counter_ptr
)
!=
gridDim
.
x
)
{
}
*
flag_ptr
=
(
flag_value
+
1
)
%
3
;
*
clear_ptr
=
new_clear_size
;
*
counter_ptr
=
0
;
}
}
int
*
counter_ptr
;
int
*
flag_ptr
;
int64_t
*
clear_ptr
;
uint8_t
*
data_bufs
[
NRanks
];
uint8_t
*
clear_buf
;
int64_t
clear_size
;
int
flag_value
;
};
__device__
__forceinline__
bool
is_neg_zero
(
float
v
)
{
return
*
reinterpret_cast
<
uint32_t
*>
(
&
v
)
==
0x80000000
;
}
__device__
__forceinline__
bool
is_neg_zero
(
float4
v
)
{
return
is_neg_zero
(
v
.
x
)
||
is_neg_zero
(
v
.
y
)
||
is_neg_zero
(
v
.
z
)
||
is_neg_zero
(
v
.
w
);
}
__device__
__forceinline__
float4
get_neg_zero
()
{
float4
vec
;
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
reinterpret_cast
<
uint32_t
*>
(
&
vec
)[
i
]
=
0x80000000
;
}
return
vec
;
}
template
<
int
Dim
>
__device__
__forceinline__
float
rms_rsqrt
(
float
&
v
,
float
eps
)
{
constexpr
float
kInvDim
=
1.0
F
/
static_cast
<
float
>
(
Dim
);
v
=
rsqrtf
((
v
*
kInvDim
)
+
eps
);
return
v
;
}
template
<
int
Dim
>
__device__
__forceinline__
float4
rms_rsqrt
(
float4
&
v
,
float
eps
)
{
constexpr
float
kInvDim
=
1.0
F
/
static_cast
<
float
>
(
Dim
);
v
.
x
=
rsqrtf
((
v
.
x
*
kInvDim
)
+
eps
);
v
.
y
=
rsqrtf
((
v
.
y
*
kInvDim
)
+
eps
);
v
.
z
=
rsqrtf
((
v
.
z
*
kInvDim
)
+
eps
);
v
.
w
=
rsqrtf
((
v
.
w
*
kInvDim
)
+
eps
);
return
v
;
}
__device__
__forceinline__
float4
ld_global_volatile
(
float4
*
addr
)
{
float4
val
;
asm
volatile
(
"ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];"
:
"=f"
(
val
.
x
),
"=f"
(
val
.
y
),
"=f"
(
val
.
z
),
"=f"
(
val
.
w
)
:
"l"
(
addr
));
return
val
;
}
__device__
__forceinline__
float
ld_global_volatile
(
float
*
addr
)
{
float
val
;
asm
volatile
(
"ld.volatile.global.f32 %0, [%1];"
:
"=f"
(
val
)
:
"l"
(
addr
));
return
val
;
}
// Used by the scalar (non-float4) kernel only
template
<
typename
T
,
int
NUM
>
__inline__
__device__
T
warpReduceSumV2
(
T
*
val
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
val
[
i
]
+=
__shfl_xor_sync
(
FINAL_MASK
,
val
[
i
],
mask
,
32
);
}
return
(
T
)(
0.0
f
);
}
template
<
typename
T
,
int
NUM
>
__inline__
__device__
T
blockReduceSumV2
(
T
*
val
)
{
static
__shared__
T
shared
[
NUM
][
33
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
warpReduceSumV2
<
T
,
NUM
>
(
val
);
if
(
lane
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
shared
[
i
][
wid
]
=
val
[
i
];
}
}
__syncthreads
();
bool
is_mask
=
threadIdx
.
x
<
(
blockDim
.
x
/
32.
f
);
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
val
[
i
]
=
is_mask
?
shared
[
i
][
lane
]
:
(
T
)(
0.0
f
);
}
warpReduceSumV2
<
T
,
NUM
>
(
val
);
return
(
T
)
0.0
f
;
}
// for float4 version
template
<
uint32_t
kNumThreads
,
typename
T
,
int
ArraySize
=
4
>
__device__
__forceinline__
void
local_warp_reduce_sum_array
(
T
*
value_ptr
,
uint32_t
active_mask
=
0xffffffffu
)
{
static_assert
(
kNumThreads
>=
1
&&
kNumThreads
<=
MINIMAX_REDUCE_RMS_WARP_SIZE
);
#pragma unroll
for
(
int
i
=
0
;
i
<
ArraySize
;
++
i
)
{
#pragma unroll
for
(
int
mask
=
kNumThreads
/
2
;
mask
>
0
;
mask
>>=
1
)
{
value_ptr
[
i
]
+=
__shfl_xor_sync
(
active_mask
,
value_ptr
[
i
],
mask
,
MINIMAX_REDUCE_RMS_WARP_SIZE
);
}
}
}
constexpr
int
next_pow2
(
int
val
)
{
int
result
=
1
;
while
(
result
<
val
)
{
result
<<=
1
;
}
return
result
;
}
// ---------------------------------------------------------------------------
template
<
typename
DType
>
class
IndexHelper
{
public:
__device__
__forceinline__
IndexHelper
(
MiniMaxReduceRMSParams
const
&
params
)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
namespace
cg
=
cooperative_groups
;
cg
::
cluster_group
cluster
=
cg
::
this_cluster
();
cg
::
grid_group
grid
=
cg
::
this_grid
();
token_id
=
grid
.
cluster_rank
();
access_id_in_token
=
cluster
.
thread_rank
();
token_stride
=
grid
.
num_clusters
();
#else
token_id
=
blockIdx
.
x
;
access_id_in_token
=
threadIdx
.
x
;
token_stride
=
gridDim
.
x
;
#endif
access_id
=
token_id
*
params
.
hidden_dim
/
kElemsPerAccess
<
DType
>
+
access_id_in_token
;
access_stride
=
token_stride
*
params
.
hidden_dim
/
kElemsPerAccess
<
DType
>
;
tot_access
=
params
.
size_q
/
kElemsPerAccess
<
DType
>
;
}
int
token_id
;
int
access_id_in_token
;
int
token_stride
;
int
access_id
;
int
access_stride
;
int
tot_access
;
};
/**
* this kernel is used to for minimax attention module
* input tensor [total_tokens, hidden_dim / tp_size], fp32
* rms weight [hidden_dim / tp_size], bf16
step 1: reduce from single rank to get the variance sum (reduce(input^2,
dim=-1)) step 2: reduce from all ranks to get the variance sum
(all_reduce(variance_sum)) step 3: calculate the rms norm (input *
rsqrt(variance + eps)) in this case, max hidden_dim is 6144 (float data), for
each token, we only need 6144 / 4 / tp_size = (1536 / tp_size) threads so we can
assume cluster size is 1 (tp_size >= 2)
*/
template
<
typename
DType
,
int
NRanks
>
__global__
void
__launch_bounds__
(
1024
)
minimax_reduce_rms_kernel_lamport
(
MiniMaxReduceRMSParams
params
)
{
IndexHelper
<
DType
>
index_helper
(
params
);
int
token_id
=
index_helper
.
token_id
;
int
access_id_in_token
=
index_helper
.
access_id_in_token
;
int
token_stride
=
index_helper
.
token_stride
;
int
access_id
=
index_helper
.
access_id
;
int
access_stride
=
index_helper
.
access_stride
;
int
tot_access
=
index_helper
.
tot_access
;
int
tot_tokens
=
params
.
size_q
/
params
.
hidden_dim
;
float4
clear_vec
=
get_neg_zero
();
LamportComm
<
NRanks
>
comm
(
params
.
workspace
,
params
.
rank
);
int
clear_access
=
comm
.
clear_size
/
kElemsPerAccess
<
DType
>
;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
#endif
for
(
int
idx
=
access_id
;
idx
<
tot_access
;
idx
+=
access_stride
,
token_id
+=
token_stride
)
{
alignas
(
16
)
DType
vals
[
kElemsPerAccess
<
DType
>
];
float
sum_variance
=
0.
F
;
*
reinterpret_cast
<
float4
*>
(
vals
)
=
reinterpret_cast
<
float4
*>
(
params
.
allreduce_in
)[
idx
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kElemsPerAccess
<
DType
>
;
++
i
)
{
sum_variance
+=
static_cast
<
float
>
(
vals
[
i
])
*
static_cast
<
float
>
(
vals
[
i
]);
}
blockReduceSumV2
<
float
,
1
>
(
&
sum_variance
);
if
(
is_neg_zero
(
sum_variance
))
{
sum_variance
=
0.
F
;
}
if
(
threadIdx
.
x
==
0
)
{
for
(
int
r
=
0
;
r
<
NRanks
;
++
r
)
{
reinterpret_cast
<
float
*>
(
comm
.
data_bufs
[
r
])[(
params
.
rank
*
tot_tokens
)
+
token_id
]
=
(
sum_variance
);
}
}
bool
done
=
false
;
float
vars_all_ranks
[
NRanks
];
while
(
!
done
)
{
done
=
true
;
#pragma unroll
for
(
int
r
=
0
;
r
<
NRanks
;
++
r
)
{
vars_all_ranks
[
r
]
=
ld_global_volatile
(
&
reinterpret_cast
<
float
*>
(
comm
.
data_bufs
[
params
.
rank
])[(
r
*
tot_tokens
)
+
token_id
]);
done
&=
!
is_neg_zero
(
vars_all_ranks
[
r
]);
}
}
sum_variance
=
0.
F
;
#pragma unroll
for
(
int
r
=
0
;
r
<
NRanks
;
++
r
)
{
sum_variance
+=
vars_all_ranks
[
r
];
}
DType
norm_weight
[
kElemsPerAccess
<
DType
>
];
*
reinterpret_cast
<
typename
ElemsPerAccess
<
DType
>::
vec_type
*>
(
norm_weight
)
=
reinterpret_cast
<
typename
ElemsPerAccess
<
DType
>::
vec_type
*>
(
params
.
rms_gamma
)[
access_id_in_token
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kElemsPerAccess
<
DType
>
;
++
i
)
{
vals
[
i
]
=
static_cast
<
DType
>
(
static_cast
<
float
>
(
vals
[
i
])
*
rsqrtf
(
(
sum_variance
/
static_cast
<
float
>
(
params
.
hidden_dim
)
/
NRanks
)
+
params
.
rms_eps
)
*
static_cast
<
float
>
(
norm_weight
[
i
]));
}
reinterpret_cast
<
float4
*>
(
params
.
rms_norm_out
)[
idx
]
=
*
reinterpret_cast
<
float4
*>
(
vals
);
}
for
(
int
idx
=
access_id
;
idx
<
clear_access
;
idx
+=
access_stride
)
{
reinterpret_cast
<
float4
*>
(
comm
.
clear_buf
)[
idx
]
=
clear_vec
;
}
comm
.
update
(
params
.
size_q
*
NRanks
);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
#endif
}
/**
* Float4 variant: process 4 rows at once, allreduce variance sums as float4 for
* better memory coalescing. sum_variance is always float; applies to all DTypes
* (half, bf16, float). When tot_tokens % 4 != 0, the last group pads rows with
* zeros; padded rows are not written to rms_norm_out. IsQK: when true, process
* Q+K in one loop with doubled comm buffer; when false, single-matrix (Q only).
*/
template
<
typename
DType
,
int
NRanks
,
int
OriginQDim
,
int
OriginKDim
>
__global__
void
__launch_bounds__
(
1024
)
minimax_reduce_qk_rms_kernel_lamport_float4
(
MiniMaxReduceRMSParams
params
)
{
// Compile-time per-rank dimensions
constexpr
int
RankQDim
=
OriginQDim
/
NRanks
;
constexpr
int
RankKDim
=
OriginKDim
/
NRanks
;
// Threads needed to cover one row of Q / K with float4 accesses
constexpr
int
ThreadsPerRowQ
=
RankQDim
/
kElemsPerAccess
<
DType
>
;
constexpr
int
ThreadsPerRowK
=
RankKDim
/
kElemsPerAccess
<
DType
>
;
// Number of warps dedicated to Q / K
constexpr
int
NumWarpQ
=
(
ThreadsPerRowQ
+
MINIMAX_REDUCE_RMS_WARP_SIZE
-
1
)
/
MINIMAX_REDUCE_RMS_WARP_SIZE
;
constexpr
int
NumWarpK
=
(
ThreadsPerRowK
+
MINIMAX_REDUCE_RMS_WARP_SIZE
-
1
)
/
MINIMAX_REDUCE_RMS_WARP_SIZE
;
int
tot_tokens
=
params
.
size_q
/
RankQDim
;
int
tot_groups
=
(
tot_tokens
+
3
)
/
4
;
// ceiling; last group may be partial
// Memory strides for strided qkv tensors (elements -> float4-access units)
int
access_stride_q
=
(
params
.
stride_q
>
0
?
params
.
stride_q
:
RankQDim
)
/
kElemsPerAccess
<
DType
>
;
int
access_stride_k
=
(
params
.
stride_k
>
0
?
params
.
stride_k
:
RankKDim
)
/
kElemsPerAccess
<
DType
>
;
// Output strides: default to contiguous (hidden_dim / hidden_dim_k)
int
access_stride_q_out
=
(
params
.
stride_q_out
>
0
?
params
.
stride_q_out
:
params
.
hidden_dim
)
/
kElemsPerAccess
<
DType
>
;
int
access_stride_k_out
=
(
params
.
stride_k_out
>
0
?
params
.
stride_k_out
:
params
.
hidden_dim_k
)
/
kElemsPerAccess
<
DType
>
;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
namespace
cg
=
cooperative_groups
;
cg
::
cluster_group
cluster
=
cg
::
this_cluster
();
cg
::
grid_group
grid
=
cg
::
this_grid
();
int
group_id
=
grid
.
cluster_rank
();
int
access_id_in_token
=
cluster
.
thread_rank
();
int
group_stride
=
grid
.
num_clusters
();
#else
int
group_id
=
blockIdx
.
x
;
int
access_id_in_token
=
threadIdx
.
x
;
int
group_stride
=
gridDim
.
x
;
#endif
bool
is_q
=
(
access_id_in_token
<
NumWarpQ
*
MINIMAX_REDUCE_RMS_WARP_SIZE
);
int
k_thread_idx
=
access_id_in_token
-
(
NumWarpQ
*
MINIMAX_REDUCE_RMS_WARP_SIZE
);
bool
is_valid_q
=
(
access_id_in_token
<
ThreadsPerRowQ
);
bool
is_valid_k
=
(
k_thread_idx
>=
0
&&
k_thread_idx
<
ThreadsPerRowK
);
float4
clear_vec
=
get_neg_zero
();
// Shared memory for two-level block reduction and scale broadcast
__shared__
float
block_reduce_sum
[
4
][
MINIMAX_REDUCE_RMS_WARP_SIZE
+
1
];
__shared__
float
global_scale_q
[
4
];
__shared__
float
global_scale_k
[
4
];
LamportComm
<
NRanks
>
comm
(
params
.
workspace
,
params
.
rank
);
DType
norm_weight
[
kElemsPerAccess
<
DType
>
]{};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.wait;"
);
#endif
if
(
is_q
)
{
if
(
is_valid_q
)
{
*
reinterpret_cast
<
typename
ElemsPerAccess
<
DType
>::
vec_type
*>
(
norm_weight
)
=
reinterpret_cast
<
typename
ElemsPerAccess
<
DType
>::
vec_type
const
*>
(
params
.
rms_gamma
)[
access_id_in_token
];
}
}
else
{
if
(
is_valid_k
)
{
*
reinterpret_cast
<
typename
ElemsPerAccess
<
DType
>::
vec_type
*>
(
norm_weight
)
=
reinterpret_cast
<
typename
ElemsPerAccess
<
DType
>::
vec_type
const
*>
(
params
.
rms_gamma_k
)[
k_thread_idx
];
}
}
// Main loop: process one group of 4 tokens per iteration.
for
(
int
g
=
group_id
;
g
<
tot_groups
;
g
+=
group_stride
)
{
alignas
(
16
)
DType
vals
[
4
][
kElemsPerAccess
<
DType
>
]{};
float
warp_sum_variance
[
4
]{
0.
F
,
0.
F
,
0.
F
,
0.
F
};
if
(
is_q
)
{
#pragma unroll
for
(
int
row
=
0
;
row
<
4
;
++
row
)
{
int
token_r
=
g
*
4
+
row
;
if
(
token_r
>=
tot_tokens
||
!
is_valid_q
)
{
continue
;
}
int
idx_r
=
token_r
*
access_stride_q
+
access_id_in_token
;
*
reinterpret_cast
<
float4
*>
(
&
vals
[
row
][
0
])
=
reinterpret_cast
<
float4
const
*>
(
params
.
allreduce_in
)[
idx_r
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kElemsPerAccess
<
DType
>
;
++
i
)
{
float
x
=
static_cast
<
float
>
(
vals
[
row
][
i
]);
warp_sum_variance
[
row
]
+=
x
*
x
;
}
}
}
else
{
#pragma unroll
for
(
int
row
=
0
;
row
<
4
;
++
row
)
{
int
token_r
=
g
*
4
+
row
;
if
(
token_r
>=
tot_tokens
||
!
is_valid_k
)
{
continue
;
}
int
idx_r
=
token_r
*
access_stride_k
+
k_thread_idx
;
*
reinterpret_cast
<
float4
*>
(
&
vals
[
row
][
0
])
=
reinterpret_cast
<
float4
const
*>
(
params
.
allreduce_in_k
)[
idx_r
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kElemsPerAccess
<
DType
>
;
++
i
)
{
float
x
=
static_cast
<
float
>
(
vals
[
row
][
i
]);
warp_sum_variance
[
row
]
+=
x
*
x
;
}
}
}
local_warp_reduce_sum_array
<
MINIMAX_REDUCE_RMS_WARP_SIZE
,
float
,
4
>
(
warp_sum_variance
);
// Warp lane 0 writes its warp's partial sum to shared memory
int
lane
=
threadIdx
.
x
&
(
MINIMAX_REDUCE_RMS_WARP_SIZE
-
1
);
if
(
lane
==
0
)
{
#pragma unroll
for
(
int
t
=
0
;
t
<
4
;
++
t
)
{
block_reduce_sum
[
t
][
threadIdx
.
x
/
MINIMAX_REDUCE_RMS_WARP_SIZE
]
=
warp_sum_variance
[
t
];
}
}
__syncthreads
();
int
tid
=
threadIdx
.
x
;
if
(
tid
<
MINIMAX_REDUCE_RMS_WARP_SIZE
)
{
constexpr
int
kNumWarpQPow2
=
(
next_pow2
(
NumWarpQ
)
>
NRanks
)
?
next_pow2
(
NumWarpQ
)
:
NRanks
;
float
local_sum
[
4
];
#pragma unroll
for
(
int
t
=
0
;
t
<
4
;
++
t
)
{
local_sum
[
t
]
=
(
tid
<
NumWarpQ
)
?
block_reduce_sum
[
t
][
tid
]
:
0.
F
;
}
// After this, all kNumWarpQPow2 lanes (including tid 0..NRanks-1) have
// the total Q sum-of-squares for all 4 tokens.
local_warp_reduce_sum_array
<
kNumWarpQPow2
,
float
,
4
>
(
local_sum
);
if
(
tid
<
NRanks
)
{
#pragma unroll
for
(
int
t
=
0
;
t
<
4
;
++
t
)
{
if
(
is_neg_zero
(
local_sum
[
t
]))
{
local_sum
[
t
]
=
0.
F
;
}
}
// Parallel push: thread tid writes this rank's Q sum to rank tid's buf
reinterpret_cast
<
float4
*>
(
comm
.
data_bufs
[
tid
])[(
params
.
rank
*
tot_groups
*
2
)
+
(
2
*
g
)]
=
*
reinterpret_cast
<
float4
*>
(
local_sum
);
// Parallel pull: thread tid reads rank tid's contribution from
// this rank's (params.rank's) buffer
bool
done
=
false
;
float4
var_all_ranks
;
while
(
!
done
)
{
done
=
true
;
var_all_ranks
=
ld_global_volatile
(
&
reinterpret_cast
<
float4
*>
(
comm
.
data_bufs
[
params
.
rank
])[(
tid
*
tot_groups
*
2
)
+
(
2
*
g
)]);
done
&=
!
is_neg_zero
(
var_all_ranks
);
}
// Warp-level allreduce: each of the NRanks threads holds one rank's
// partial sum; after this all NRanks threads have the global total.
constexpr
uint32_t
kQActiveMask
=
(
1u
<<
NRanks
)
-
1u
;
local_warp_reduce_sum_array
<
NRanks
,
float
,
4
>
(
reinterpret_cast
<
float
*>
(
&
var_all_ranks
),
kQActiveMask
);
// Thread 0 computes rsqrt with compile-time Dim and writes to smem
if
(
tid
==
0
)
{
*
reinterpret_cast
<
float4
*>
(
global_scale_q
)
=
rms_rsqrt
<
OriginQDim
>
(
var_all_ranks
,
params
.
rms_eps
);
}
}
}
else
if
(
tid
>=
MINIMAX_REDUCE_RMS_WARP_SIZE
*
NumWarpQ
&&
tid
<
MINIMAX_REDUCE_RMS_WARP_SIZE
*
(
NumWarpQ
+
1
))
{
// --- K leader warp ---
constexpr
int
kNumWarpKPow2
=
(
next_pow2
(
NumWarpK
)
>
NRanks
)
?
next_pow2
(
NumWarpK
)
:
NRanks
;
float
local_sum
[
4
];
#pragma unroll
for
(
int
t
=
0
;
t
<
4
;
++
t
)
{
local_sum
[
t
]
=
(
k_thread_idx
<
NumWarpK
)
?
block_reduce_sum
[
t
][
NumWarpQ
+
k_thread_idx
]
:
0.
F
;
}
local_warp_reduce_sum_array
<
kNumWarpKPow2
,
float
,
4
>
(
local_sum
);
if
(
k_thread_idx
<
NRanks
)
{
#pragma unroll
for
(
int
t
=
0
;
t
<
4
;
++
t
)
{
if
(
is_neg_zero
(
local_sum
[
t
]))
{
local_sum
[
t
]
=
0.
F
;
}
}
reinterpret_cast
<
float4
*>
(
comm
.
data_bufs
[
k_thread_idx
])[(
params
.
rank
*
tot_groups
*
2
)
+
(
2
*
g
+
1
)]
=
*
reinterpret_cast
<
float4
*>
(
local_sum
);
bool
done
=
false
;
float4
var_all_ranks
;
while
(
!
done
)
{
done
=
true
;
var_all_ranks
=
ld_global_volatile
(
&
reinterpret_cast
<
float4
*>
(
comm
.
data_bufs
[
params
.
rank
])[(
k_thread_idx
*
tot_groups
*
2
)
+
(
2
*
g
+
1
)]);
done
&=
!
is_neg_zero
(
var_all_ranks
);
}
constexpr
uint32_t
kKActiveMask
=
(
1u
<<
NRanks
)
-
1u
;
local_warp_reduce_sum_array
<
NRanks
,
float
,
4
>
(
reinterpret_cast
<
float
*>
(
&
var_all_ranks
),
kKActiveMask
);
if
(
k_thread_idx
==
0
)
{
*
reinterpret_cast
<
float4
*>
(
global_scale_k
)
=
rms_rsqrt
<
OriginKDim
>
(
var_all_ranks
,
params
.
rms_eps
);
}
}
}
__syncthreads
();
if
(
is_q
)
{
#pragma unroll
for
(
int
t
=
0
;
t
<
4
;
++
t
)
{
warp_sum_variance
[
t
]
=
global_scale_q
[
t
];
}
#pragma unroll
for
(
int
r
=
0
;
r
<
4
;
++
r
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
kElemsPerAccess
<
DType
>
;
++
i
)
{
vals
[
r
][
i
]
=
static_cast
<
DType
>
(
static_cast
<
float
>
(
vals
[
r
][
i
])
*
warp_sum_variance
[
r
]
*
static_cast
<
float
>
(
norm_weight
[
i
]));
}
int
token_r
=
g
*
4
+
r
;
if
(
token_r
>=
tot_tokens
||
!
is_valid_q
)
{
continue
;
}
int
idx_out
=
token_r
*
access_stride_q_out
+
access_id_in_token
;
reinterpret_cast
<
float4
*>
(
params
.
rms_norm_out
)[
idx_out
]
=
*
reinterpret_cast
<
float4
*>
(
&
vals
[
r
][
0
]);
}
}
else
{
#pragma unroll
for
(
int
t
=
0
;
t
<
4
;
++
t
)
{
warp_sum_variance
[
t
]
=
global_scale_k
[
t
];
}
#pragma unroll
for
(
int
r
=
0
;
r
<
4
;
++
r
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
kElemsPerAccess
<
DType
>
;
++
i
)
{
vals
[
r
][
i
]
=
static_cast
<
DType
>
(
static_cast
<
float
>
(
vals
[
r
][
i
])
*
warp_sum_variance
[
r
]
*
static_cast
<
float
>
(
norm_weight
[
i
]));
}
int
token_r
=
g
*
4
+
r
;
if
(
token_r
>=
tot_tokens
||
!
is_valid_k
)
{
continue
;
}
int
idx_out
=
token_r
*
access_stride_k_out
+
k_thread_idx
;
reinterpret_cast
<
float4
*>
(
params
.
rms_norm_out_k
)[
idx_out
]
=
*
reinterpret_cast
<
float4
*>
(
&
vals
[
r
][
0
]);
}
}
}
// end group loop
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm
volatile
(
"griddepcontrol.launch_dependents;"
);
#endif
int
clear_access
=
static_cast
<
int
>
(
comm
.
clear_size
/
kElemsPerAccess
<
DType
>
);
int
clear_stride
=
group_stride
*
blockDim
.
x
;
for
(
int
idx
=
group_id
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
clear_access
;
idx
+=
clear_stride
)
{
reinterpret_cast
<
float4
*>
(
comm
.
clear_buf
)[
idx
]
=
clear_vec
;
}
comm
.
update
(
static_cast
<
int64_t
>
(
2
)
*
tot_groups
*
kElemsPerAccess
<
DType
>
*
NRanks
);
}
int
get_sm_count
()
{
static
int
sm_count
=
0
;
if
(
sm_count
==
0
)
{
int
device_id
;
CUDA_CHECK
(
cudaGetDevice
(
&
device_id
));
cudaDeviceProp
device_prop
;
cudaGetDeviceProperties
(
&
device_prop
,
device_id
);
sm_count
=
device_prop
.
multiProcessorCount
;
}
return
sm_count
;
}
inline
int
getSMVersion
(
bool
queryRealSmArch
=
false
)
{
int
device
{
-
1
};
CUDA_CHECK
(
cudaGetDevice
(
&
device
));
int
sm_major
=
0
;
int
sm_minor
=
0
;
CUDA_CHECK
(
cudaDeviceGetAttribute
(
&
sm_major
,
cudaDevAttrComputeCapabilityMajor
,
device
));
CUDA_CHECK
(
cudaDeviceGetAttribute
(
&
sm_minor
,
cudaDevAttrComputeCapabilityMinor
,
device
));
int
sm
=
sm_major
*
10
+
sm_minor
;
if
(
sm
==
121
&&
!
queryRealSmArch
)
{
return
120
;
}
return
sm
;
}
template
<
typename
KernelFunc
>
int
get_max_active_blocks
(
KernelFunc
kernel
,
int
block_size
,
int
dynamic_smem
=
0
)
{
int
max_active
=
0
;
CUDA_CHECK
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active
,
kernel
,
block_size
,
dynamic_smem
));
return
std
::
max
(
max_active
,
1
);
}
template
<
typename
DType
,
int
NRanks
>
void
minimax_reduce_rms_kernel_launcher
(
MiniMaxReduceRMSParams
const
&
params
)
{
static
int
SM
=
getSMVersion
();
int
token_num
=
params
.
size_q
/
params
.
hidden_dim
;
int
sm_count
=
get_sm_count
();
int
cluster_size
=
1
;
int
cluster_num
=
token_num
;
int
threads_per_token
=
params
.
hidden_dim
/
kElemsPerAccess
<
DType
>
;
int
block_size
=
threads_per_token
;
int
max_blocks_per_sm
=
get_max_active_blocks
(
minimax_reduce_rms_kernel_lamport
<
DType
,
NRanks
>
,
block_size
);
int
max_grid
=
max_blocks_per_sm
*
sm_count
;
int
grid_size
=
(
std
::
min
(
max_grid
,
cluster_num
*
cluster_size
)
/
cluster_size
)
*
cluster_size
;
cudaLaunchConfig_t
cfg
;
cfg
.
gridDim
=
grid_size
;
cfg
.
blockDim
=
block_size
;
cfg
.
dynamicSmemBytes
=
0
;
cfg
.
stream
=
params
.
stream
;
cudaLaunchAttribute
attribute
[
2
];
attribute
[
0
].
id
=
cudaLaunchAttributeProgrammaticStreamSerialization
;
attribute
[
0
].
val
.
programmaticStreamSerializationAllowed
=
1
;
attribute
[
1
].
id
=
cudaLaunchAttributeClusterDimension
;
attribute
[
1
].
val
.
clusterDim
.
x
=
cluster_size
;
attribute
[
1
].
val
.
clusterDim
.
y
=
1
;
attribute
[
1
].
val
.
clusterDim
.
z
=
1
;
cfg
.
attrs
=
attribute
;
cfg
.
numAttrs
=
SM
>=
90
?
2
:
0
;
CUDA_CHECK
(
cudaLaunchKernelEx
(
&
cfg
,
minimax_reduce_rms_kernel_lamport
<
DType
,
NRanks
>
,
params
));
}
template
<
typename
DType
,
int
NRanks
,
int
OriginQDim
,
int
OriginKDim
>
void
minimax_reduce_rms_kernel_launcher_float4
(
MiniMaxReduceRMSParams
const
&
params
)
{
TORCH_CHECK
(
params
.
size_q
%
params
.
hidden_dim
==
0
);
TORCH_CHECK
(
params
.
hidden_dim
%
kElemsPerAccess
<
DType
>
==
0
);
if
(
params
.
stride_q
>
0
)
{
TORCH_CHECK
(
params
.
stride_q
%
kElemsPerAccess
<
DType
>
==
0
);
}
TORCH_CHECK
(
params
.
allreduce_in_k
!=
nullptr
,
"float4 QK kernel requires K input"
);
TORCH_CHECK
(
params
.
hidden_dim
>=
params
.
hidden_dim_k
);
TORCH_CHECK
(
params
.
size_k
%
params
.
hidden_dim_k
==
0
);
TORCH_CHECK
(
params
.
hidden_dim_k
%
kElemsPerAccess
<
DType
>
==
0
);
TORCH_CHECK
(
params
.
size_q
/
params
.
hidden_dim
==
params
.
size_k
/
params
.
hidden_dim_k
);
if
(
params
.
stride_k
>
0
)
{
TORCH_CHECK
(
params
.
stride_k
%
kElemsPerAccess
<
DType
>
==
0
);
}
int
token_num
=
params
.
size_q
/
params
.
hidden_dim
;
int
tot_groups
=
(
token_num
+
3
)
/
4
;
if
(
tot_groups
==
0
)
{
return
;
}
static
int
SM
=
getSMVersion
();
int
sm_count
=
get_sm_count
();
int
cluster_size
=
1
;
int
cluster_num
=
tot_groups
;
int
access_per_row_q
=
params
.
hidden_dim
/
kElemsPerAccess
<
DType
>
;
int
access_per_row_k
=
params
.
hidden_dim_k
/
kElemsPerAccess
<
DType
>
;
// Round each section up to a warp boundary
auto
divUp
=
[](
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
*
b
;
};
int
block_size
=
divUp
(
access_per_row_q
,
MINIMAX_REDUCE_RMS_WARP_SIZE
)
+
divUp
(
access_per_row_k
,
MINIMAX_REDUCE_RMS_WARP_SIZE
);
auto
kfn
=
minimax_reduce_qk_rms_kernel_lamport_float4
<
DType
,
NRanks
,
OriginQDim
,
OriginKDim
>
;
int
max_blocks_per_sm
=
get_max_active_blocks
(
kfn
,
block_size
);
int
max_grid
=
max_blocks_per_sm
*
sm_count
;
int
grid_size
=
(
std
::
min
(
max_grid
,
cluster_num
*
cluster_size
)
/
cluster_size
)
*
cluster_size
;
cudaLaunchConfig_t
cfg
;
cfg
.
gridDim
=
grid_size
;
cfg
.
blockDim
=
block_size
;
cfg
.
dynamicSmemBytes
=
0
;
cfg
.
stream
=
params
.
stream
;
cudaLaunchAttribute
attribute
[
2
];
attribute
[
0
].
id
=
cudaLaunchAttributeProgrammaticStreamSerialization
;
attribute
[
0
].
val
.
programmaticStreamSerializationAllowed
=
1
;
attribute
[
1
].
id
=
cudaLaunchAttributeClusterDimension
;
attribute
[
1
].
val
.
clusterDim
.
x
=
cluster_size
;
attribute
[
1
].
val
.
clusterDim
.
y
=
1
;
attribute
[
1
].
val
.
clusterDim
.
z
=
1
;
cfg
.
attrs
=
attribute
;
cfg
.
numAttrs
=
SM
>=
90
?
2
:
0
;
CUDA_CHECK
(
cudaLaunchKernelEx
(
&
cfg
,
kfn
,
params
));
}
template
<
int
NRanks
>
void
dispatch_dtype
(
MiniMaxReduceRMSParams
const
&
params
)
{
// Use the optimized QK float4 kernel when:
// - K input is present, AND
// - the full (NRanks * per-rank) dimensions match the MiniMax M2 shape.
// Otherwise fall back to the scalar kernel.
bool
use_float4
=
(
params
.
allreduce_in_k
!=
nullptr
)
&&
(
params
.
hidden_dim
*
params
.
nranks
==
6144
)
&&
(
params
.
hidden_dim_k
*
params
.
nranks
==
1024
);
if
(
params
.
dtype
==
at
::
ScalarType
::
Half
)
{
if
(
use_float4
)
{
minimax_reduce_rms_kernel_launcher_float4
<
half
,
NRanks
,
6144
,
1024
>
(
params
);
}
else
{
minimax_reduce_rms_kernel_launcher
<
half
,
NRanks
>
(
params
);
}
}
else
if
(
params
.
dtype
==
at
::
ScalarType
::
BFloat16
)
{
if
(
use_float4
)
{
minimax_reduce_rms_kernel_launcher_float4
<
__nv_bfloat16
,
NRanks
,
6144
,
1024
>
(
params
);
}
else
{
minimax_reduce_rms_kernel_launcher
<
__nv_bfloat16
,
NRanks
>
(
params
);
}
}
else
if
(
params
.
dtype
==
at
::
ScalarType
::
Float
)
{
if
(
use_float4
)
{
minimax_reduce_rms_kernel_launcher_float4
<
float
,
NRanks
,
6144
,
1024
>
(
params
);
}
else
{
minimax_reduce_rms_kernel_launcher
<
float
,
NRanks
>
(
params
);
}
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type for minimax_reduce_rms_op"
);
}
}
void
minimax_reduce_rms_op
(
MiniMaxReduceRMSParams
const
&
params
)
{
if
(
params
.
nranks
==
2
)
{
dispatch_dtype
<
2
>
(
params
);
}
else
if
(
params
.
nranks
==
4
)
{
dispatch_dtype
<
4
>
(
params
);
}
else
if
(
params
.
nranks
==
8
)
{
dispatch_dtype
<
8
>
(
params
);
}
else
if
(
params
.
nranks
==
16
)
{
dispatch_dtype
<
16
>
(
params
);
}
else
{
TORCH_CHECK
(
false
,
"minimax_reduce_rms_op: unsupported ranks number!"
);
}
}
}
// namespace tensorrt_llm
}
// namespace vllm
torch
::
Tensor
minimax_allreduce_rms
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
norm_weight
,
torch
::
Tensor
workspace
,
int64_t
const
rank
,
int64_t
const
nranks
,
double
const
eps
)
{
auto
allreduce_params
=
vllm
::
tensorrt_llm
::
MiniMaxReduceRMSParams
();
allreduce_params
.
nranks
=
static_cast
<
int
>
(
nranks
);
allreduce_params
.
rank
=
static_cast
<
int
>
(
rank
);
allreduce_params
.
dtype
=
input
.
scalar_type
();
allreduce_params
.
size_q
=
static_cast
<
int
>
(
input
.
numel
());
allreduce_params
.
hidden_dim
=
static_cast
<
int
>
(
input
.
size
(
-
1
));
allreduce_params
.
stride_q
=
allreduce_params
.
hidden_dim
;
allreduce_params
.
workspace
=
reinterpret_cast
<
void
**>
(
workspace
.
mutable_data_ptr
());
allreduce_params
.
allreduce_in
=
input
.
data_ptr
();
allreduce_params
.
rms_gamma
=
norm_weight
.
data_ptr
();
allreduce_params
.
rms_eps
=
static_cast
<
float
>
(
eps
);
allreduce_params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
torch
::
Tensor
rms_norm_out
=
torch
::
empty_like
(
input
);
allreduce_params
.
rms_norm_out
=
rms_norm_out
.
mutable_data_ptr
();
vllm
::
tensorrt_llm
::
minimax_reduce_rms_op
(
allreduce_params
);
return
rms_norm_out
;
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
minimax_allreduce_rms_qk
(
torch
::
Tensor
qkv
,
torch
::
Tensor
const
&
norm_weight_q
,
torch
::
Tensor
const
&
norm_weight_k
,
torch
::
Tensor
workspace
,
int64_t
const
q_size
,
int64_t
const
kv_size
,
int64_t
const
rank
,
int64_t
const
nranks
,
double
const
eps
)
{
TORCH_CHECK
(
qkv
.
dim
()
==
2
,
"minimax_allreduce_rms_qk: qkv must be 2D"
);
TORCH_CHECK
(
qkv
.
is_contiguous
(),
"minimax_allreduce_rms_qk: qkv must be contiguous"
);
int64_t
qkv_dim
=
qkv
.
size
(
-
1
);
TORCH_CHECK
(
qkv_dim
==
q_size
+
2
*
kv_size
,
"minimax_allreduce_rms_qk: qkv last dim must equal "
"q_size + 2 * kv_size"
);
TORCH_CHECK
(
rank
<
nranks
,
"minimax_allreduce_rms_qk: rank must be less than nranks"
);
int64_t
num_tokens
=
qkv
.
size
(
0
);
int
elem_bytes
=
qkv
.
element_size
();
torch
::
Tensor
q_out
=
torch
::
empty
({
num_tokens
,
q_size
},
qkv
.
options
());
torch
::
Tensor
k_out
=
torch
::
empty
({
num_tokens
,
kv_size
},
qkv
.
options
());
auto
params
=
vllm
::
tensorrt_llm
::
MiniMaxReduceRMSParams
();
params
.
nranks
=
static_cast
<
int
>
(
nranks
);
params
.
rank
=
static_cast
<
int
>
(
rank
);
params
.
dtype
=
qkv
.
scalar_type
();
params
.
size_q
=
static_cast
<
int
>
(
num_tokens
*
q_size
);
params
.
hidden_dim
=
static_cast
<
int
>
(
q_size
);
params
.
size_k
=
static_cast
<
int
>
(
num_tokens
*
kv_size
);
params
.
hidden_dim_k
=
static_cast
<
int
>
(
kv_size
);
params
.
stride_q
=
static_cast
<
int
>
(
qkv_dim
);
params
.
stride_k
=
static_cast
<
int
>
(
qkv_dim
);
params
.
stride_q_out
=
0
;
// q_out is contiguous; kernel uses hidden_dim
params
.
stride_k_out
=
0
;
// k_out is contiguous; kernel uses hidden_dim_k
params
.
workspace
=
reinterpret_cast
<
void
**>
(
workspace
.
mutable_data_ptr
());
uint8_t
*
base
=
static_cast
<
uint8_t
*>
(
qkv
.
data_ptr
());
params
.
allreduce_in
=
base
;
params
.
allreduce_in_k
=
base
+
q_size
*
elem_bytes
;
params
.
rms_gamma
=
norm_weight_q
.
data_ptr
();
params
.
rms_gamma_k
=
norm_weight_k
.
data_ptr
();
params
.
rms_eps
=
static_cast
<
float
>
(
eps
);
params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
qkv
.
get_device
());
params
.
rms_norm_out
=
q_out
.
mutable_data_ptr
();
params
.
rms_norm_out_k
=
k_out
.
mutable_data_ptr
();
vllm
::
tensorrt_llm
::
minimax_reduce_rms_op
(
params
);
return
{
q_out
,
k_out
};
}
csrc/minimax_reduce_rms_kernel.h
0 → 100644
View file @
ecd1ea13
/*
* Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/types.h>
namespace
vllm
{
namespace
tensorrt_llm
{
template
<
typename
DType
>
struct
ElemsPerAccess
;
template
<
>
struct
ElemsPerAccess
<
half
>
{
static
constexpr
int
value
=
8
;
using
vec_type
=
float4
;
};
template
<
>
struct
ElemsPerAccess
<
nv_bfloat16
>
{
static
constexpr
int
value
=
8
;
using
vec_type
=
float4
;
};
template
<
>
struct
ElemsPerAccess
<
float
>
{
static
constexpr
int
value
=
4
;
using
vec_type
=
float4
;
};
template
<
typename
DType
>
static
constexpr
int
kElemsPerAccess
=
ElemsPerAccess
<
DType
>::
value
;
struct
MiniMaxReduceRMSParams
{
int
nranks
{};
int
rank
{};
at
::
ScalarType
dtype
{
at
::
ScalarType
::
Undefined
};
int
size_q
{};
int
hidden_dim
{};
int
size_k
{};
int
hidden_dim_k
{};
int
stride_q
{};
// row stride for q input (elements); when > hidden_dim,
// q is part of a wider qkv tensor
int
stride_k
{};
// row stride for k input (elements); when > hidden_dim_k,
// k is part of a wider qkv tensor
int
stride_q_out
{};
// row stride for q output (elements); 0 = contiguous
int
stride_k_out
{};
// row stride for k output (elements); 0 = contiguous
void
**
workspace
{};
void
*
allreduce_in
{};
void
*
rms_norm_out
{};
void
*
rms_gamma
{};
void
*
allreduce_in_k
{};
void
*
rms_norm_out_k
{};
void
*
rms_gamma_k
{};
float
rms_eps
{};
cudaStream_t
stream
{};
};
void
minimax_reduce_rms_op
(
MiniMaxReduceRMSParams
const
&
params
);
}
// namespace tensorrt_llm
}
// namespace vllm
csrc/ops.h
View file @
ecd1ea13
...
@@ -309,3 +309,15 @@ int64_t qr_max_size();
...
@@ -309,3 +309,15 @@ int64_t qr_max_size();
void
dsv3_fused_a_gemm
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
mat_a
,
void
dsv3_fused_a_gemm
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
mat_a
,
torch
::
Tensor
const
&
mat_b
);
torch
::
Tensor
const
&
mat_b
);
#endif
#endif
#ifndef USE_ROCM
torch
::
Tensor
minimax_allreduce_rms
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
norm_weight
,
torch
::
Tensor
workspace
,
int64_t
const
rank
,
int64_t
const
nranks
,
double
const
eps
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
minimax_allreduce_rms_qk
(
torch
::
Tensor
qkv
,
torch
::
Tensor
const
&
norm_weight_q
,
torch
::
Tensor
const
&
norm_weight_k
,
torch
::
Tensor
workspace
,
int64_t
const
q_size
,
int64_t
const
kv_size
,
int64_t
const
rank
,
int64_t
const
nranks
,
double
const
eps
);
#endif
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
ecd1ea13
...
@@ -496,6 +496,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -496,6 +496,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? b_qzeros, "
"Tensor? b_qzeros, "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"
);
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor"
);
ops
.
def
(
"minimax_allreduce_rms("
"Tensor input,"
"Tensor norm_weight,"
"Tensor workspace,"
"int rank,"
"int nranks,"
"float eps) -> Tensor"
);
ops
.
impl
(
"minimax_allreduce_rms"
,
torch
::
kCUDA
,
&
minimax_allreduce_rms
);
ops
.
def
(
"minimax_allreduce_rms_qk("
"Tensor qkv,"
"Tensor norm_weight_q,"
"Tensor norm_weight_k,"
"Tensor workspace,"
"int q_size,"
"int kv_size,"
"int rank,"
"int nranks,"
"float eps) -> (Tensor, Tensor)"
);
ops
.
impl
(
"minimax_allreduce_rms_qk"
,
torch
::
kCUDA
,
&
minimax_allreduce_rms_qk
);
// conditionally compiled so impl in source file
// conditionally compiled so impl in source file
#endif
#endif
}
}
...
...
tests/kernels/core/test_minimax_reduce_rms.py
0 → 100644
View file @
ecd1ea13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for MiniMax QK RMS-norm: NCCL reference vs Lamport fused kernel."""
import
pytest
import
torch
import
torch.nn
as
nn
from
torch.multiprocessing
import
spawn
from
tests.kernels.utils
import
opcheck
from
tests.utils
import
ensure_current_vllm_config
,
init_test_distributed_environment
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.model_executor.layers.mamba.linear_attn
import
MiniMaxText01RMSNormTP
from
vllm.platforms
import
current_platform
from
vllm.utils.network_utils
import
get_open_port
from
vllm.utils.torch_utils
import
set_random_seed
@
ensure_current_vllm_config
()
def
_worker_forward_qk
(
local_rank
,
world_size
,
port
,
num_tokens
,
hidden_q_full
,
hidden_k_full
,
dtype
,
seed
,
eps
,
):
"""Per-rank worker: compare NCCL allreduce path vs Lamport fused kernel."""
if
not
hasattr
(
torch
.
ops
.
_C
,
"minimax_allreduce_rms_qk"
):
cleanup_dist_env_and_memory
()
return
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
torch
.
accelerator
.
set_device_index
(
device
)
init_test_distributed_environment
(
world_size
,
1
,
local_rank
,
port
,
local_rank
=
local_rank
)
hq
=
hidden_q_full
//
world_size
hk
=
hidden_k_full
//
world_size
q_norm
=
MiniMaxText01RMSNormTP
(
hidden_q_full
,
eps
=
eps
).
cuda
()
k_norm
=
MiniMaxText01RMSNormTP
(
hidden_k_full
,
eps
=
eps
).
cuda
()
set_random_seed
(
seed
)
qw
=
torch
.
randn
(
hidden_q_full
,
dtype
=
dtype
,
device
=
"cuda"
)
kw
=
torch
.
randn
(
hidden_k_full
,
dtype
=
dtype
,
device
=
"cuda"
)
q_norm
.
weight
=
nn
.
Parameter
(
qw
[
local_rank
*
hq
:
(
local_rank
+
1
)
*
hq
])
k_norm
.
weight
=
nn
.
Parameter
(
kw
[
local_rank
*
hk
:
(
local_rank
+
1
)
*
hk
])
torch
.
manual_seed
(
seed
+
1000
+
local_rank
)
qkv
=
torch
.
randn
(
num_tokens
,
hq
+
hk
+
hk
,
dtype
=
dtype
,
device
=
"cuda"
)
q_ref
,
k_ref
,
v_ref
=
qkv
.
clone
().
split
([
hq
,
hk
,
hk
],
dim
=-
1
)
ref_q
,
ref_k
=
MiniMaxText01RMSNormTP
.
forward_qk
(
q_norm
,
k_norm
,
q_ref
,
k_ref
)
# Set up Lamport workspace.
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.model_executor.layers.mamba.lamport_workspace
import
(
get_allreduce_workspace
,
)
workspace
=
get_allreduce_workspace
(
rank
=
local_rank
,
world_size
=
world_size
,
max_tokens
=
num_tokens
,
process_group
=
get_tp_group
().
cpu_group
,
)
opcheck
(
torch
.
ops
.
_C
.
minimax_allreduce_rms_qk
,
(
qkv
.
clone
(),
q_norm
.
weight
,
k_norm
.
weight
,
workspace
,
hq
,
hk
,
local_rank
,
world_size
,
eps
,
),
)
fused_q
,
fused_k
=
torch
.
ops
.
_C
.
minimax_allreduce_rms_qk
(
qkv
.
clone
(),
q_norm
.
weight
,
k_norm
.
weight
,
workspace
,
hq
,
hk
,
local_rank
,
world_size
,
eps
,
)
_
,
_
,
fused_v
=
qkv
.
split
([
hq
,
hk
,
hk
],
dim
=-
1
)
torch
.
accelerator
.
synchronize
()
torch
.
testing
.
assert_close
(
fused_q
,
ref_q
,
atol
=
3e-2
,
rtol
=
3e-2
,
)
torch
.
testing
.
assert_close
(
fused_k
,
ref_k
,
atol
=
3e-2
,
rtol
=
3e-2
)
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"CUDA required"
,
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
1
,
128
,
333
])
@
pytest
.
mark
.
parametrize
(
"hidden_dims"
,
[(
6144
,
1024
)],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-6
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
42
])
def
test_minimax_reduce_rms_qk
(
world_size
,
num_tokens
,
hidden_dims
,
dtype
,
eps
,
seed
,
):
num_gpus
=
current_platform
.
device_count
()
if
num_gpus
<
world_size
:
pytest
.
skip
(
f
"Need >=
{
world_size
}
GPUs, have
{
num_gpus
}
"
)
hidden_q_full
,
hidden_k_full
=
hidden_dims
port
=
str
(
get_open_port
())
spawn
(
_worker_forward_qk
,
args
=
(
world_size
,
port
,
num_tokens
,
hidden_q_full
,
hidden_k_full
,
dtype
,
seed
,
eps
,
),
nprocs
=
world_size
,
join
=
True
,
)
vllm/_custom_ops.py
View file @
ecd1ea13
...
@@ -3491,3 +3491,38 @@ if hasattr(torch.ops._C, "hadacore_transform"):
...
@@ -3491,3 +3491,38 @@ if hasattr(torch.ops._C, "hadacore_transform"):
@
register_fake
(
"_C::hadacore_transform"
)
@
register_fake
(
"_C::hadacore_transform"
)
def
_hadacore_transform_fake
(
x
:
torch
.
Tensor
,
inplace
:
bool
)
->
torch
.
Tensor
:
def
_hadacore_transform_fake
(
x
:
torch
.
Tensor
,
inplace
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
x
)
if
not
inplace
else
x
return
torch
.
empty_like
(
x
)
if
not
inplace
else
x
if
hasattr
(
torch
.
ops
.
_C
,
"minimax_allreduce_rms"
):
@
register_fake
(
"_C::minimax_allreduce_rms"
)
def
_minimax_allreduce_rms_fake
(
input
:
torch
.
Tensor
,
norm_weight
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
rank
:
int
,
nranks
:
int
,
eps
:
float
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
input
)
if
hasattr
(
torch
.
ops
.
_C
,
"minimax_allreduce_rms_qk"
):
@
register_fake
(
"_C::minimax_allreduce_rms_qk"
)
def
_minimax_allreduce_rms_qk_fake
(
qkv
:
torch
.
Tensor
,
norm_weight_q
:
torch
.
Tensor
,
norm_weight_k
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
q_size
:
int
,
kv_size
:
int
,
rank
:
int
,
nranks
:
int
,
eps
:
float
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
token_num
=
qkv
.
shape
[
0
]
return
(
torch
.
empty
([
token_num
,
q_size
],
dtype
=
qkv
.
dtype
,
device
=
qkv
.
device
),
torch
.
empty
([
token_num
,
kv_size
],
dtype
=
qkv
.
dtype
,
device
=
qkv
.
device
),
)
vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py
0 → 100644
View file @
ecd1ea13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Fusion pass: replace MiniMax QK allreduce + RMS norm with the Lamport
fused kernel (minimax_allreduce_rms_qk) for decode-size batches.
Pattern (inlined forward_qk in compiled graph):
q, k, v = qkv.split([q_size, kv_size, kv_size], -1)
q_fp32 = q.to(float32); k_fp32 = k.to(float32)
q_var = q_fp32.pow(2).mean(-1, keepdim=True)
k_var = k_fp32.pow(2).mean(-1, keepdim=True)
qk_var = cat([q_var, k_var], -1)
qk_var = allreduce(qk_var) / tp_world
q_var, k_var = qk_var.chunk(2, -1)
q_out = (q_fp32 * rsqrt(q_var + eps) * q_weight).to(orig_dtype)
k_out = (k_fp32 * rsqrt(k_var + eps) * k_weight).to(orig_dtype)
return q_out, k_out, v
Replacement (pure, no in-place on qkv/q/k):
q_out, k_out = minimax_qk_norm_fused(qkv, q_weight, k_weight, workspace, ...)
v = qkv.split([q_size, kv_size, kv_size], -1)[2]
return q_out, k_out, v
is_applicable_for_range: only fires for compile_range.end <= max_decode_tokens
so that large prefill batches fall through to the original forward_qk (= main).
"""
import
torch
import
torch._inductor.pattern_matcher
as
pm
import
torch.fx
as
fx
from
torch._inductor.pattern_matcher
import
PatternMatcherPass
from
vllm.config
import
VllmConfig
from
vllm.config.utils
import
Range
from
vllm.distributed
import
tensor_model_parallel_all_reduce
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
..inductor_pass
import
enable_fake_mode
from
..vllm_inductor_pass
import
VllmInductorPass
,
VllmPatternMatcherPass
logger
=
init_logger
(
__name__
)
MAX_TOKEN_NUM
=
2048
_MINIMAX_QK_NORM_FUSED_OP
=
None
if
hasattr
(
torch
.
ops
.
_C
,
"minimax_allreduce_rms_qk"
):
def
_minimax_qk_norm_fused
(
qkv
:
torch
.
Tensor
,
norm_weight_q
:
torch
.
Tensor
,
norm_weight_k
:
torch
.
Tensor
,
q_size
:
int
,
kv_size
:
int
,
rank
:
int
,
nranks
:
int
,
eps
:
float
,
max_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.model_executor.layers.mamba.lamport_workspace
import
(
get_allreduce_workspace
,
)
workspace
=
get_allreduce_workspace
(
rank
=
rank
,
world_size
=
nranks
,
max_tokens
=
max_tokens
,
process_group
=
get_tp_group
().
cpu_group
,
)
return
torch
.
ops
.
_C
.
minimax_allreduce_rms_qk
(
qkv
,
norm_weight_q
,
norm_weight_k
,
workspace
,
q_size
,
kv_size
,
rank
,
nranks
,
eps
,
)
def
_minimax_qk_norm_fused_fake
(
qkv
:
torch
.
Tensor
,
norm_weight_q
:
torch
.
Tensor
,
norm_weight_k
:
torch
.
Tensor
,
q_size
:
int
,
kv_size
:
int
,
rank
:
int
,
nranks
:
int
,
eps
:
float
,
max_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
T
=
qkv
.
shape
[
0
]
return
(
torch
.
empty
([
T
,
q_size
],
dtype
=
qkv
.
dtype
,
device
=
qkv
.
device
),
torch
.
empty
([
T
,
kv_size
],
dtype
=
qkv
.
dtype
,
device
=
qkv
.
device
),
)
direct_register_custom_op
(
op_name
=
"minimax_qk_norm_fused"
,
op_func
=
_minimax_qk_norm_fused
,
fake_impl
=
_minimax_qk_norm_fused_fake
,
mutates_args
=
[],
)
_MINIMAX_QK_NORM_FUSED_OP
=
torch
.
ops
.
vllm
.
minimax_qk_norm_fused
.
default
class
MiniMaxQKNormPattern
:
"""
Match the forward_qk allreduce+rms pattern and replace with Lamport kernel.
"""
def
__init__
(
self
,
q_size
:
int
,
kv_size
:
int
,
eps
:
float
,
tp_world
:
int
,
tp_rank
:
int
,
max_tokens
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
|
None
,
)
->
None
:
self
.
q_size
=
q_size
self
.
kv_size
=
kv_size
self
.
eps
=
eps
self
.
tp_world
=
tp_world
self
.
tp_rank
=
tp_rank
self
.
max_tokens
=
max_tokens
self
.
dtype
=
dtype
self
.
device
=
device
def
get_inputs
(
self
)
->
list
[
torch
.
Tensor
]:
T
=
4
qkv
=
torch
.
empty
(
[
T
,
self
.
q_size
+
2
*
self
.
kv_size
],
device
=
self
.
device
,
dtype
=
self
.
dtype
,
)
q_weight
=
torch
.
empty
([
self
.
q_size
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
k_weight
=
torch
.
empty
([
self
.
kv_size
],
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
[
qkv
,
q_weight
,
k_weight
]
def
register
(
self
,
pm_pass
:
PatternMatcherPass
)
->
None
:
q_size
=
self
.
q_size
kv_size
=
self
.
kv_size
eps
=
self
.
eps
tp_world
=
self
.
tp_world
max_tokens
=
self
.
max_tokens
tp_rank
=
self
.
tp_rank
dtype
=
self
.
dtype
def
pattern
(
qkv
:
torch
.
Tensor
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
q
,
k
,
v
=
qkv
.
split
([
q_size
,
kv_size
,
kv_size
],
dim
=-
1
)
q_fp32
=
q
.
to
(
torch
.
float32
)
k_fp32
=
k
.
to
(
torch
.
float32
)
q_var
=
q_fp32
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
k_var
=
k_fp32
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
qk_var
=
torch
.
cat
([
q_var
,
k_var
],
dim
=-
1
)
qk_var
=
tensor_model_parallel_all_reduce
(
qk_var
)
/
tp_world
q_var
,
k_var
=
qk_var
.
chunk
(
2
,
dim
=-
1
)
q_out
=
(
q_fp32
*
torch
.
rsqrt
(
q_var
+
eps
)
*
q_weight
).
to
(
dtype
)
k_out
=
(
k_fp32
*
torch
.
rsqrt
(
k_var
+
eps
)
*
k_weight
).
to
(
dtype
)
return
q_out
,
k_out
,
v
def
replacement
(
qkv
:
torch
.
Tensor
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
_MINIMAX_QK_NORM_FUSED_OP
is
not
None
q_out
,
k_out
=
torch
.
ops
.
vllm
.
minimax_qk_norm_fused
(
qkv
,
q_weight
,
k_weight
,
q_size
,
kv_size
,
tp_rank
,
tp_world
,
eps
,
max_tokens
,
)
_
,
_
,
v
=
qkv
.
split
([
q_size
,
kv_size
,
kv_size
],
dim
=-
1
)
return
q_out
,
k_out
,
v
pm
.
register_replacement
(
pattern
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
# Second pattern: three separate split_with_sizes nodes (one per output),
# each with _users=1. This occurs when the QKV projection uses a
# functional GEMM kernel (e.g. cutlass_scaled_mm via auto_functionalized),
# which causes inductor to generate one split per consumer.
def
pattern_split3
(
qkv
:
torch
.
Tensor
,
q_weight
:
torch
.
Tensor
,
k_weight
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
q
=
qkv
.
split
([
q_size
,
kv_size
,
kv_size
],
dim
=-
1
)[
0
]
k
=
qkv
.
split
([
q_size
,
kv_size
,
kv_size
],
dim
=-
1
)[
1
]
v
=
qkv
.
split
([
q_size
,
kv_size
,
kv_size
],
dim
=-
1
)[
2
]
q_fp32
=
q
.
to
(
torch
.
float32
)
k_fp32
=
k
.
to
(
torch
.
float32
)
q_var
=
q_fp32
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
k_var
=
k_fp32
.
pow
(
2
).
mean
(
dim
=-
1
,
keepdim
=
True
)
qk_var
=
torch
.
cat
([
q_var
,
k_var
],
dim
=-
1
)
qk_var
=
tensor_model_parallel_all_reduce
(
qk_var
)
/
tp_world
q_var
,
k_var
=
qk_var
.
chunk
(
2
,
dim
=-
1
)
q_out
=
(
q_fp32
*
torch
.
rsqrt
(
q_var
+
eps
)
*
q_weight
).
to
(
dtype
)
k_out
=
(
k_fp32
*
torch
.
rsqrt
(
k_var
+
eps
)
*
k_weight
).
to
(
dtype
)
return
q_out
,
k_out
,
v
pm
.
register_replacement
(
pattern_split3
,
replacement
,
self
.
get_inputs
(),
pm
.
fwd_only
,
pm_pass
)
class
MiniMaxQKNormPass
(
VllmPatternMatcherPass
):
"""
Replace forward_qk allreduce+norm with the Lamport fused kernel.
Only applied for decode-size compile ranges (small token counts).
"""
def
__init__
(
self
,
config
:
VllmConfig
)
->
None
:
super
().
__init__
(
config
)
self
.
disabled
=
True
if
_MINIMAX_QK_NORM_FUSED_OP
is
None
:
logger
.
warning_once
(
"minimax_allreduce_rms_qk op not found, MiniMaxQKNormPass disabled."
)
return
tp_world
=
get_tensor_model_parallel_world_size
()
if
tp_world
<=
1
:
logger
.
warning_once
(
"MiniMaxQKNormPass disabled: tp_size <= 1."
)
return
if
config
.
model_config
is
None
:
logger
.
warning_once
(
"MiniMaxQKNormPass disabled: no model_config."
)
return
hf_cfg
=
config
.
model_config
.
hf_config
model_name
=
getattr
(
hf_cfg
,
"architectures"
,
""
)[
0
]
if
model_name
!=
"MiniMaxM2ForCausalLM"
:
return
num_attention_heads
=
getattr
(
hf_cfg
,
"num_attention_heads"
,
0
)
num_key_value_heads
=
getattr
(
hf_cfg
,
"num_key_value_heads"
,
0
)
hidden_size
=
getattr
(
hf_cfg
,
"hidden_size"
,
0
)
head_dim
=
getattr
(
hf_cfg
,
"head_dim"
,
0
)
eps
:
float
=
getattr
(
hf_cfg
,
"rms_norm_eps"
,
1e-6
)
if
(
num_attention_heads
!=
48
or
num_key_value_heads
!=
8
or
hidden_size
!=
3072
or
head_dim
!=
128
):
logger
.
warning_once
(
"MiniMaxQKNormPass disabled: cannot infer model info from hf_config."
)
return
num_heads_per_rank
=
num_attention_heads
//
tp_world
num_kv_heads_per_rank
=
max
(
1
,
num_key_value_heads
//
tp_world
)
q_size
=
num_heads_per_rank
*
head_dim
kv_size
=
num_kv_heads_per_rank
*
head_dim
self
.
max_token_num
=
min
(
MAX_TOKEN_NUM
,
config
.
scheduler_config
.
max_num_batched_tokens
)
tp_rank
=
get_tensor_model_parallel_rank
()
# Allocate Lamport workspace first.
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.model_executor.layers.mamba.lamport_workspace
import
(
get_allreduce_workspace
,
)
get_allreduce_workspace
(
rank
=
tp_rank
,
world_size
=
tp_world
,
max_tokens
=
self
.
max_token_num
,
process_group
=
get_tp_group
().
cpu_group
,
)
self
.
patterns
:
PatternMatcherPass
=
PatternMatcherPass
(
pass_name
=
"minimax_qk_norm_pass"
)
self
.
_register_patterns
(
q_size
,
kv_size
,
eps
,
tp_world
,
tp_rank
)
self
.
dump_patterns
(
config
,
self
.
patterns
)
self
.
disabled
=
False
@
enable_fake_mode
def
_register_patterns
(
self
,
q_size
:
int
,
kv_size
:
int
,
eps
:
float
,
tp_world
:
int
,
tp_rank
:
int
,
)
->
None
:
MiniMaxQKNormPattern
(
q_size
=
q_size
,
kv_size
=
kv_size
,
eps
=
eps
,
tp_world
=
tp_world
,
tp_rank
=
tp_rank
,
max_tokens
=
self
.
max_token_num
,
dtype
=
self
.
model_dtype
,
device
=
self
.
device
,
).
register
(
self
.
patterns
)
def
is_applicable_for_range
(
self
,
compile_range
:
Range
)
->
bool
:
if
self
.
disabled
:
return
False
return
bool
(
compile_range
.
end
<=
self
.
max_token_num
)
@
VllmInductorPass
.
time_and_log
def
__call__
(
self
,
graph
:
fx
.
Graph
)
->
None
:
if
self
.
disabled
:
return
self
.
matched_count
=
self
.
patterns
.
apply
(
graph
)
logger
.
debug
(
"MiniMaxQKNormPass replaced %s patterns"
,
self
.
matched_count
)
def
uuid
(
self
)
->
str
:
return
VllmInductorPass
.
hash_source
(
self
,
MiniMaxQKNormPattern
)
vllm/compilation/passes/pass_manager.py
View file @
ecd1ea13
...
@@ -38,6 +38,7 @@ if current_platform.is_cuda_alike():
...
@@ -38,6 +38,7 @@ if current_platform.is_cuda_alike():
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
from
.fusion.allreduce_rms_fusion
import
AllReduceFusionPass
from
.fusion.allreduce_rms_fusion
import
AllReduceFusionPass
from
.fusion.collective_fusion
import
AsyncTPPass
from
.fusion.collective_fusion
import
AsyncTPPass
from
.fusion.minimax_qk_norm_fusion
import
MiniMaxQKNormPass
from
.inductor_pass
import
(
from
.inductor_pass
import
(
CustomGraphPass
,
CustomGraphPass
,
...
@@ -137,6 +138,9 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
...
@@ -137,6 +138,9 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
if
self
.
pass_config
.
fuse_allreduce_rms
:
if
self
.
pass_config
.
fuse_allreduce_rms
:
self
.
passes
+=
[
AllReduceFusionPass
(
config
)]
self
.
passes
+=
[
AllReduceFusionPass
(
config
)]
if
self
.
pass_config
.
fuse_minimax_qk_norm
:
self
.
passes
+=
[
MiniMaxQKNormPass
(
config
)]
if
self
.
pass_config
.
fuse_norm_quant
:
if
self
.
pass_config
.
fuse_norm_quant
:
self
.
passes
+=
[
RMSNormQuantFusionPass
(
config
)]
self
.
passes
+=
[
RMSNormQuantFusionPass
(
config
)]
if
rocm_aiter_ops
.
is_enabled
():
if
rocm_aiter_ops
.
is_enabled
():
...
...
vllm/config/compilation.py
View file @
ecd1ea13
...
@@ -134,6 +134,8 @@ class PassConfig:
...
@@ -134,6 +134,8 @@ class PassConfig:
"""Enable async TP."""
"""Enable async TP."""
fuse_allreduce_rms
:
bool
=
None
# type: ignore[assignment]
fuse_allreduce_rms
:
bool
=
None
# type: ignore[assignment]
"""Enable flashinfer allreduce fusion."""
"""Enable flashinfer allreduce fusion."""
fuse_minimax_qk_norm
:
bool
=
None
# type: ignore[assignment]
"""Enable fused allreduce+RMSNorm for MiniMax QK norm."""
enable_qk_norm_rope_fusion
:
bool
=
False
enable_qk_norm_rope_fusion
:
bool
=
False
"""Enable fused Q/K RMSNorm + RoPE pass."""
"""Enable fused Q/K RMSNorm + RoPE pass."""
...
...
vllm/config/vllm.py
View file @
ecd1ea13
...
@@ -1627,6 +1627,22 @@ class VllmConfig:
...
@@ -1627,6 +1627,22 @@ class VllmConfig:
compile_range_end
,
compile_range_end
,
)
)
if
compilation_config
.
pass_config
.
fuse_minimax_qk_norm
:
from
vllm.compilation.passes.fusion.minimax_qk_norm_fusion
import
(
MAX_TOKEN_NUM
,
)
max_token_num
=
min
(
MAX_TOKEN_NUM
,
self
.
scheduler_config
.
max_num_batched_tokens
)
if
compile_range_end
is
not
None
and
max_token_num
<
compile_range_end
:
computed_compile_ranges_endpoints
.
append
(
max_token_num
)
else
:
logger
.
debug
(
"Max num batched tokens below MiniMax QK norm fusion threshold, "
"MiniMax QK norm fusion enabled for all num_tokens."
)
if
compilation_config
.
compile_ranges_endpoints
is
not
None
:
if
compilation_config
.
compile_ranges_endpoints
is
not
None
:
for
x
in
compilation_config
.
compile_ranges_endpoints
:
for
x
in
compilation_config
.
compile_ranges_endpoints
:
assert
isinstance
(
x
,
int
)
assert
isinstance
(
x
,
int
)
...
...
vllm/model_executor/layers/mamba/lamport_workspace.py
0 → 100644
View file @
ecd1ea13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
array
import
contextlib
import
struct
import
sys
import
threading
import
torch
try
:
from
cuda.bindings
import
runtime
as
cudart
except
ImportError
:
from
cuda
import
cudart
_ALIGN
=
1
<<
21
# 2 MiB — CUDA IPC allocation alignment
# ---------------------------------------------------------------------------
# CUDA helpers
# ---------------------------------------------------------------------------
def
_check
(
error
):
"""Raise on CUDA runtime error."""
success
=
getattr
(
cudart
.
cudaError_t
,
"cudaSuccess"
,
None
)
or
cudart
.
cudaError_t
(
0
)
if
error
!=
success
:
raise
RuntimeError
(
f
"CUDA runtime error:
{
error
}
"
)
def
_cuda_malloc
(
size
:
int
):
aligned
=
((
size
+
_ALIGN
-
1
)
>>
21
)
<<
21
err
,
ptr
=
cudart
.
cudaMalloc
(
aligned
)
_check
(
err
)
return
ptr
,
aligned
def
_cuda_free
(
ptr
:
int
):
if
ptr
:
_check
(
cudart
.
cudaFree
(
ptr
)[
0
])
def
_cuda_memset_zero
(
ptr
:
int
,
size
:
int
):
_check
(
cudart
.
cudaMemset
(
ptr
,
0
,
size
)[
0
])
def
_cuda_memcpy_d2d
(
dst
:
int
,
src
:
int
,
size
:
int
):
_check
(
cudart
.
cudaMemcpy
(
dst
,
src
,
size
,
cudart
.
cudaMemcpyKind
.
cudaMemcpyDeviceToDevice
)[
0
]
)
# ---------------------------------------------------------------------------
# IPC buffer
# ---------------------------------------------------------------------------
class
IpcBuffer
:
"""
Allocates CUDA device memory and exchanges IPC handles with all ranks
so that every rank holds a valid device pointer to every other rank's buffer.
"""
def
__init__
(
self
,
rank
:
int
,
world_size
:
int
,
size
:
int
,
process_group
=
None
):
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
peer_ptrs
:
list
[
int
]
=
[
0
]
*
world_size
self
.
local_ptr
:
int
=
0
self
.
_alive
=
False
if
size
<=
0
:
return
self
.
local_ptr
,
_
=
_cuda_malloc
(
size
)
_cuda_memset_zero
(
self
.
local_ptr
,
size
)
self
.
_alive
=
True
# --- exchange IPC handles via torch.distributed ---
err
,
local_handle
=
cudart
.
cudaIpcGetMemHandle
(
self
.
local_ptr
)
_check
(
err
)
all_handles
:
list
[
bytes
|
None
]
=
[
None
]
*
world_size
torch
.
distributed
.
all_gather_object
(
all_handles
,
bytes
(
local_handle
.
reserved
),
group
=
process_group
)
for
r
in
range
(
world_size
):
if
r
==
rank
:
self
.
peer_ptrs
[
r
]
=
self
.
local_ptr
else
:
handle
=
cudart
.
cudaIpcMemHandle_t
()
handle
.
reserved
=
all_handles
[
r
]
err
,
ptr
=
cudart
.
cudaIpcOpenMemHandle
(
handle
,
cudart
.
cudaIpcMemLazyEnablePeerAccess
)
_check
(
err
)
self
.
peer_ptrs
[
r
]
=
ptr
def
serialize
(
self
)
->
list
[
int
]:
"""Return peer pointers as a list of int64 values (one per rank)."""
raw
=
b
""
for
ptr
in
self
.
peer_ptrs
:
raw
+=
struct
.
pack
(
"P"
,
ptr
)
return
array
.
array
(
"Q"
,
raw
).
tolist
()
def
cleanup
(
self
):
if
not
self
.
_alive
:
return
self
.
_alive
=
False
for
r
in
range
(
self
.
world_size
):
if
self
.
peer_ptrs
[
r
]
==
0
:
continue
if
r
==
self
.
rank
:
_cuda_free
(
self
.
peer_ptrs
[
r
])
else
:
with
contextlib
.
suppress
(
RuntimeError
):
_check
(
cudart
.
cudaIpcCloseMemHandle
(
self
.
peer_ptrs
[
r
])[
0
])
self
.
peer_ptrs
[
r
]
=
0
self
.
local_ptr
=
0
def
__del__
(
self
):
if
not
sys
.
is_finalizing
():
self
.
cleanup
()
# ---------------------------------------------------------------------------
# Lamport negative-zero initialization
# ---------------------------------------------------------------------------
def
_lamport_fill_neg_zero
(
device_ptr
:
int
,
size_bytes
:
int
):
"""
Fill device memory with IEEE-754 negative zero (-0.0f = 0x80000000).
This is the "slot empty" sentinel for the Lamport protocol: the kernel
spin-waits until a value is *not* negative zero.
"""
if
size_bytes
==
0
or
device_ptr
==
0
:
return
n_floats
=
size_bytes
//
4
# torch preserves -0.0 in IEEE-754
fill
=
torch
.
full
((
n_floats
,),
-
0.0
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
_cuda_memcpy_d2d
(
device_ptr
,
fill
.
data_ptr
(),
size_bytes
)
del
fill
# ---------------------------------------------------------------------------
# LamportWorkspace — the main class
# ---------------------------------------------------------------------------
class
LamportWorkspace
:
"""
Self-contained workspace for Lamport-based cross-GPU AllReduce.
Parameters
----------
rank : int
Local rank (0-based).
world_size : int
Total number of ranks in the TP group.
comm_size : int
Size in bytes of *one* Lamport buffer slot. The total IPC allocation
per rank is ``3 * comm_size`` (triple-buffering). Must be large enough
to hold the per-slot data written by the kernel. Use
``compute_comm_size_for_minimax()`` for a safe default.
process_group : optional
``torch.distributed`` process group for IPC handle exchange.
``None`` uses the default group.
"""
def
__init__
(
self
,
rank
:
int
,
world_size
:
int
,
comm_size
:
int
,
process_group
=
None
):
assert
world_size
>=
2
,
"Lamport workspace requires at least 2 ranks"
assert
comm_size
>
0
,
"comm_size must be positive"
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
comm_size
=
comm_size
# 1) Lamport triple-buffer (the only IPC memory the kernel reads/writes)
lamport_total
=
3
*
comm_size
self
.
_lamport
=
IpcBuffer
(
rank
,
world_size
,
lamport_total
,
process_group
)
_lamport_fill_neg_zero
(
self
.
_lamport
.
local_ptr
,
lamport_total
)
# 2) flag_buffer on device: int32[3] = {counter, unused, lamport_flag}
# counter — used for block-level sync inside the kernel
# unused — reserved (index 1)
# lamport_flag — triple-buffer rotation index (0 → 1 → 2 → 0 …)
self
.
_flag_buf
=
torch
.
zeros
(
3
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# 3) layout_buffer on device: int64[2] = {clear_size, comm_size}
# clear_size — bytes to clear from *previous* slot (set by kernel)
# comm_size — size of one triple-buffer slot
self
.
_layout_buf
=
torch
.
tensor
(
[
0
,
comm_size
],
dtype
=
torch
.
int64
,
device
=
"cuda"
)
# 4) Assemble device-side void* pointer array
N
=
world_size
ptrs
:
list
[
int
]
=
[]
ptrs
+=
[
0
]
*
N
# [0 .. N-1] ipc_buffers (placeholder)
ptrs
+=
[
0
]
*
N
# [N .. 2N-1] ipc_barriers (placeholder)
ptrs
+=
self
.
_lamport
.
serialize
()
# [2N .. 3N-1] lamport peer ptrs
ptrs
.
append
(
self
.
_flag_buf
.
data_ptr
())
# [3N] flag_buffer
ptrs
.
append
(
self
.
_layout_buf
.
data_ptr
())
# [3N+1] layout_buffer
self
.
_workspace
=
torch
.
tensor
(
ptrs
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
@
property
def
workspace
(
self
)
->
torch
.
Tensor
:
"""Device tensor (int64) that can be passed to the kernel
as ``void** workspace``."""
return
self
.
_workspace
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@
staticmethod
def
compute_comm_size_for_minimax
(
max_tokens
:
int
,
world_size
:
int
,
fused_qk
:
bool
=
True
,
)
->
int
:
"""
Return a safe ``comm_size`` (in bytes) for MiniMaxReduceRMSKernel.
The kernel stores per-token variance scalars in the Lamport buffer:
- single-matrix path: ``world_size × max_tokens × 4`` bytes per slot
- fused Q+K path: ``world_size × 2 × ceil(max_tokens/4) × 16`` bytes per slot
The returned value is rounded up to 2 MiB alignment.
"""
if
fused_qk
:
groups
=
(
max_tokens
+
3
)
//
4
slot_bytes
=
world_size
*
2
*
groups
*
16
# 16 = sizeof(float4)
else
:
slot_bytes
=
world_size
*
max_tokens
*
4
# 4 = sizeof(float)
return
((
slot_bytes
+
_ALIGN
-
1
)
>>
21
)
<<
21
def
cleanup
(
self
):
if
hasattr
(
self
,
"_lamport"
):
self
.
_lamport
.
cleanup
()
def
__del__
(
self
):
if
not
sys
.
is_finalizing
():
self
.
cleanup
()
def
__repr__
(
self
):
return
(
f
"LamportWorkspace(rank=
{
self
.
rank
}
, world_size=
{
self
.
world_size
}
, "
f
"comm_size=
{
self
.
comm_size
}
)"
)
# ---------------------------------------------------------------------------
# Cached convenience function (mirrors TRT-LLM's get_allreduce_workspace)
# ---------------------------------------------------------------------------
_cache_lock
=
threading
.
Lock
()
_workspace_cache
:
dict
=
{}
def
get_allreduce_workspace
(
rank
:
int
,
world_size
:
int
,
comm_size
:
int
|
None
=
None
,
max_tokens
:
int
=
16384
,
process_group
=
None
,
)
->
torch
.
Tensor
:
"""
Return a cached workspace tensor for the given (rank, world_size) pair.
On first call the workspace is allocated and IPC handles are exchanged;
subsequent calls with the same arguments return the cached tensor.
Parameters
----------
rank, world_size : int
TP rank and TP size.
comm_size : int, optional
Explicit slot size in bytes. If ``None``, computed automatically
from ``max_tokens`` and ``world_size`` (fused Q+K path).
max_tokens : int
Maximum number of tokens per batch (used when ``comm_size is None``).
process_group : optional
``torch.distributed`` process group.
"""
if
comm_size
is
None
:
comm_size
=
LamportWorkspace
.
compute_comm_size_for_minimax
(
max_tokens
,
world_size
,
fused_qk
=
True
)
pg_id
=
id
(
process_group
)
if
process_group
is
not
None
else
0
key
=
(
rank
,
world_size
,
comm_size
,
pg_id
)
with
_cache_lock
:
if
key
not
in
_workspace_cache
:
ws
=
LamportWorkspace
(
rank
,
world_size
,
comm_size
,
process_group
)
_workspace_cache
[
key
]
=
ws
return
_workspace_cache
[
key
].
workspace
vllm/model_executor/models/minimax_m2.py
View file @
ecd1ea13
...
@@ -233,9 +233,7 @@ class MiniMaxM2Attention(nn.Module):
...
@@ -233,9 +233,7 @@ class MiniMaxM2Attention(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
MiniMaxText01RMSNormTP
.
forward_qk
(
q
,
k
=
MiniMaxText01RMSNormTP
.
forward_qk
(
self
.
q_norm
,
self
.
k_norm
,
q
,
k
)
self
.
q_norm
,
self
.
k_norm
,
q
.
contiguous
(),
k
.
contiguous
()
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment