Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f347ac6c
Unverified
Commit
f347ac6c
authored
Jan 07, 2026
by
Michael Goin
Committed by
GitHub
Jan 07, 2026
Browse files
[Perf] Fuse stride preparation for NVFP4 cutlass_moe (#31837)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
05f47bd8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
18 deletions
+30
-18
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
+30
-18
No files found.
csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu
View file @
f347ac6c
...
...
@@ -62,7 +62,9 @@ __global__ void __get_group_gemm_starts(
ElementSF
*
a_scales_base_as_int
,
ElementSF
*
b_scales_base_as_int
,
ElementAccumulator
*
alphas_base_as_int
,
const
int32_t
*
expert_offsets
,
const
int32_t
*
sf_offsets
,
const
int32_t
*
problem_sizes_as_shapes
,
const
int
K
,
const
int
N
)
{
int64_t
*
a_strides
,
int64_t
*
b_strides
,
int64_t
*
c_strides
,
const
int64_t
a_stride_val
,
const
int64_t
b_stride_val
,
const
int64_t
c_stride_val
,
const
int
K
,
const
int
N
)
{
int64_t
expert_id
=
threadIdx
.
x
;
if
(
expert_id
>=
gridDim
.
x
*
blockDim
.
x
)
{
return
;
...
...
@@ -103,6 +105,11 @@ __global__ void __get_group_gemm_starts(
// Shape of alpha = [E]
alpha_offsets
[
expert_id
]
=
alphas_base_as_int
+
expert_id
;
// Initialize strides (constant across all experts, avoids separate kernels)
a_strides
[
expert_id
]
=
a_stride_val
;
b_strides
[
expert_id
]
=
b_stride_val
;
c_strides
[
expert_id
]
=
c_stride_val
;
LayoutSFA
*
layout_sfa_ptr
=
layout_sfa_base_as_int
+
expert_id
;
LayoutSFB
*
layout_sfb_ptr
=
layout_sfb_base_as_int
+
expert_id
;
...
...
@@ -135,7 +142,11 @@ __global__ void __get_group_gemm_starts(
static_cast<float*>(alphas.data_ptr()), \
static_cast<int32_t*>(expert_offsets.data_ptr()), \
static_cast<int32_t*>(sf_offsets.data_ptr()), \
static_cast<int32_t*>(problem_sizes.data_ptr()), K, N); \
static_cast<int32_t*>(problem_sizes.data_ptr()), \
static_cast<int64_t*>(a_strides.data_ptr()), \
static_cast<int64_t*>(b_strides.data_ptr()), \
static_cast<int64_t*>(c_strides.data_ptr()), a_stride_val, \
b_stride_val, c_stride_val, K, N); \
}
template
<
typename
LayoutSFA
,
typename
LayoutSFB
,
typename
ScaleConfig
>
...
...
@@ -144,6 +155,9 @@ void run_get_group_gemm_starts(
const
torch
::
Tensor
&
out_starts
,
const
torch
::
Tensor
&
a_scales_starts
,
const
torch
::
Tensor
&
b_scales_starts
,
const
torch
::
Tensor
&
alpha_starts
,
const
torch
::
Tensor
&
layout_sfa
,
const
torch
::
Tensor
&
layout_sfb
,
const
torch
::
Tensor
&
a_strides
,
const
torch
::
Tensor
&
b_strides
,
const
torch
::
Tensor
&
c_strides
,
int64_t
a_stride_val
,
int64_t
b_stride_val
,
int64_t
c_stride_val
,
/*these are used for their base addresses*/
torch
::
Tensor
const
&
a_tensors
,
torch
::
Tensor
const
&
b_tensors
,
torch
::
Tensor
const
&
out_tensors
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -269,17 +283,16 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
torch
::
Tensor
alpha_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
layout_sfa
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
layout_sfb
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
full
({
num_experts
},
output
.
stride
(
0
),
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
full
({
num_experts
},
a
.
stride
(
0
)
*
2
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
full
({
num_experts
},
b
.
stride
(
1
)
*
2
,
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
empty
(
num_experts
,
options_int
);
run_get_group_gemm_starts
<
LayoutSFA
,
LayoutSFB
,
ScaleConfig
>
(
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
alpha_ptrs
,
layout_sfa
,
layout_sfb
,
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
layout_sfa
,
layout_sfb
,
a_strides1
,
b_strides1
,
c_strides1
,
a
.
stride
(
0
)
*
2
,
b
.
stride
(
1
)
*
2
,
output
.
stride
(
0
),
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
// Create an instance of the GEMM
Gemm
gemm_op
;
...
...
@@ -444,17 +457,16 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
torch
::
Tensor
alpha_ptrs
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
layout_sfa
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
layout_sfb
=
torch
::
empty
({
num_experts
,
5
},
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
full
({
num_experts
},
output
.
stride
(
0
),
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
full
({
num_experts
},
a
.
stride
(
0
)
*
2
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
full
({
num_experts
},
b
.
stride
(
1
)
*
2
,
options_int
);
torch
::
Tensor
a_strides1
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
b_strides1
=
torch
::
empty
(
num_experts
,
options_int
);
torch
::
Tensor
c_strides1
=
torch
::
empty
(
num_experts
,
options_int
);
run_get_group_gemm_starts
<
LayoutSFA
,
LayoutSFB
,
ScaleConfig
>
(
a_ptrs
,
b_ptrs
,
out_ptrs
,
a_scales_ptrs
,
b_scales_ptrs
,
alpha_ptrs
,
layout_sfa
,
layout_sfb
,
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
layout_sfa
,
layout_sfb
,
a_strides1
,
b_strides1
,
c_strides1
,
a
.
stride
(
0
)
*
2
,
b
.
stride
(
1
)
*
2
,
output
.
stride
(
0
),
a
,
b
,
output
,
a_blockscale
,
b_blockscales
,
alphas
,
expert_offsets
,
sf_offsets
,
problem_sizes
,
M
,
N
,
K
);
// Create an instance of the GEMM
Gemm
gemm_op
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment