Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a8bffaa1
Unverified
Commit
a8bffaa1
authored
Apr 17, 2026
by
Michael Goin
Committed by
GitHub
Apr 17, 2026
Browse files
[Kernel] Add MXFP4 W4A4 CUTLASS MoE kernel for SM100 (#37463)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
5cdddddd
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1701 additions
and
14 deletions
+1701
-14
.buildkite/test_areas/kernels.yaml
.buildkite/test_areas/kernels.yaml
+1
-0
CMakeLists.txt
CMakeLists.txt
+3
-1
csrc/libtorch_stable/ops.h
csrc/libtorch_stable/ops.h
+23
-0
csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu
...rch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu
+468
-0
csrc/libtorch_stable/quantization/fp4/mxfp4_experts_quant.cu
csrc/libtorch_stable/quantization/fp4/mxfp4_experts_quant.cu
+422
-0
csrc/libtorch_stable/torch_bindings.cpp
csrc/libtorch_stable/torch_bindings.cpp
+22
-0
tests/kernels/moe/test_mxfp4_moe.py
tests/kernels/moe/test_mxfp4_moe.py
+248
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+135
-0
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+19
-0
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+295
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py
...mpressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py
+65
-13
No files found.
.buildkite/test_areas/kernels.yaml
View file @
a8bffaa1
...
@@ -141,6 +141,7 @@ steps:
...
@@ -141,6 +141,7 @@ steps:
-
pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
-
pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
-
pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
-
pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
-
pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
-
pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
-
pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
-
pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
-
pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
-
pytest -v -s tests/kernels/moe/test_flashinfer.py
-
pytest -v -s tests/kernels/moe/test_flashinfer.py
-
pytest -v -s tests/kernels/moe/test_flashinfer_moe.py
-
pytest -v -s tests/kernels/moe/test_flashinfer_moe.py
...
...
CMakeLists.txt
View file @
a8bffaa1
...
@@ -952,7 +952,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -952,7 +952,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
)
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
"csrc/libtorch_stable/quantization/fp4/mxfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu"
)
set_gencode_flags_for_srcs
(
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
...
...
csrc/libtorch_stable/ops.h
View file @
a8bffaa1
...
@@ -134,4 +134,27 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
...
@@ -134,4 +134,27 @@ void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch
::
stable
::
Tensor
&
input
,
torch
::
stable
::
Tensor
&
input
,
torch
::
stable
::
Tensor
&
input_global_scale
);
torch
::
stable
::
Tensor
&
input_global_scale
);
// Quantizes MoE expert activations to MXFP4 (E2M1 values with E8M0 scale
// factors over 32-element blocks).
//   output                         - [m_topk, k/2] uint8, two FP4 values per
//                                    byte (written in place)
//   output_scale                   - 2-D int32, four E8M0 scales per word
//                                    (written in place)
//   input                          - [m_topk, k] FP16/BF16 activations
//   input_offset_by_experts        - [n_experts + 1] int32 row boundaries
//   output_scale_offset_by_experts - [n_experts + 1] int32 scale-row offsets
//   n_experts                      - number of experts
void mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts);
// Fused SiLU(gate) * up followed by MXFP4 experts quantization.
// Same contract as mxfp4_experts_quant, except that `input` is
// [m_topk, 2 * k] with the gate half followed by the up half in each row;
// `output` is still [m_topk, k/2].
void silu_and_mul_mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts);
// CUTLASS MXFP4 x MXFP4 block-scaled grouped GEMM for MoE; writes `output`.
//   a, b                        - packed MXFP4 activations / weights
//   a_blockscale, b_blockscales - E8M0 block scales for a / b
//   problem_sizes               - per-expert (M, N, K)
//   expert_offsets, sf_offsets  - per-expert boundary indices into the
//                                 activations and their scale factors
// NOTE(review): the definition is not visible here; exact shapes and dtypes
// should be confirmed against mxfp4_blockwise_moe_kernel.cu.
void cutlass_mxfp4_group_mm(torch::stable::Tensor& output,
                            const torch::stable::Tensor& a,
                            const torch::stable::Tensor& b,
                            const torch::stable::Tensor& a_blockscale,
                            const torch::stable::Tensor& b_blockscales,
                            const torch::stable::Tensor& problem_sizes,
                            const torch::stable::Tensor& expert_offsets,
                            const torch::stable::Tensor& sf_offsets);
#endif
#endif
csrc/libtorch_stable/quantization/fp4/mxfp4_blockwise_moe_kernel.cu
0 → 100644
View file @
a8bffaa1
This diff is collapsed.
Click to expand it.
csrc/libtorch_stable/quantization/fp4/mxfp4_experts_quant.cu
0 → 100644
View file @
a8bffaa1
/*
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright contributors to the vLLM project
*
* MXFP4 activation quantization kernel for MoE experts.
* Quantizes BF16/FP16 activations to MXFP4: E2M1 values with E8M0 block scales
* over 32-element groups.
*
* Uses PACK16 E2M1 conversion helpers (nvfp4_utils.cuh) configured for:
* - Block size 32 (2 threads per SF in PACK16 mode)
* - E8M0 (power-of-two) scale factors
* - SF layout: [numMTiles, numKTiles, 32, 4, 4] where numKTiles=ceil(K/128)
*/
// MXFP4 requires PACK16 mode (16 elements per thread) so that
// 2 threads cover 32-element blocks. This requires CUDA >= 12.9.
// Must be defined before any header that (transitively) includes
// nvfp4_utils.cuh.
#define NVFP4_ENABLE_ELTS16 1
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
static_assert
(
CVT_FP4_ELTS_PER_THREAD
==
16
,
"MXFP4 experts quant requires PACK16 mode (CUDA >= 12.9)"
);
#include "launch_bounds_utils.h"
namespace
vllm
{
// MXFP4 block size constants.
// Number of consecutive elements that share one scale factor (SF).
static constexpr int MXFP4_SF_VEC_SIZE = 32;
// Number of threads that together cover one SF block; each thread handles
// CVT_FP4_ELTS_PER_THREAD elements.
// For PACK16 mode (CVT_FP4_ELTS_PER_THREAD=16): 2 threads per SF
// For PACK8 mode (CVT_FP4_ELTS_PER_THREAD=8): 4 threads per SF
static constexpr int MXFP4_NUM_THREADS_PER_SF =
    MXFP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
// MXFP4 quantization kernel for experts.
// Uses 32-element blocks with E8M0 (UE8M0) scale factors.
// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses
// SiLU(gate)*up before quantization.
//
// This overload resolves the owning expert of each row by scanning the
// offset table directly from global memory (no shared memory is used).
//
// Template parameters:
//   Type              - input element type.
//   FUSE_SILU_MUL     - input rows hold 2 * numCols elements (gate, then up).
//   SMALL_NUM_EXPERTS - true: scan offsets one entry at a time;
//                       false: scan in chunks of 16 using 128-bit loads.
// Parameters:
//   numRows, numCols               - shape (in elements) of the quantized
//                                    matrix.
//   in                             - input activations.
//   out                            - packed FP4 output; one fp4_packed_t per
//                                    CVT_FP4_ELTS_PER_THREAD elements.
//   SFout                          - scale-factor buffer (uint32 words).
//   input_offset_by_experts        - expert e owns rows [off[e], off[e + 1]).
//   output_scale_offset_by_experts - per-expert scale row offset; multiplied
//                                    by numKTiles to index SFout.
//   n_experts                      - number of experts.
//   low_latency                    - not read by the body; its presence
//                                    distinguishes this overload from the
//                                    8-argument kernel of the same name.
template <class Type, bool FUSE_SILU_MUL = false,
          bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    mxfp4_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                          fp4_packed_t* out, uint32_t* SFout,
                          uint32_t* input_offset_by_experts,
                          uint32_t* output_scale_offset_by_experts,
                          int n_experts, bool low_latency) {
  using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // MXFP4: numKTiles = ceil(numCols / 128) since block_size=32, 4 SFs/tile
  int32_t const numKTiles = (numCols + 127) / 128;

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;

  // Grid-stride loop: each iteration handles one packed vector.
  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
       globalIdx += gridDim.x * blockDim.x) {
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;

    // Find the expert whose row range contains rowIdx.
    int rowIdx_in_expert = 0;
    int expert_idx = 0;

    if constexpr (SMALL_NUM_EXPERTS) {
      for (int i = 0; i < n_experts; i++) {
        uint32_t current_offset = __ldca(&input_offset_by_experts[i]);
        uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]);
        if (rowIdx >= current_offset && rowIdx < next_offset) {
          rowIdx_in_expert = rowIdx - current_offset;
          expert_idx = i;
          break;
        }
      }
    } else {
      // NOTE(review): every chunk reads entries
      // [chunk_start, chunk_start + 16], which only stays inside an
      // (n_experts + 1)-entry table when n_experts is a multiple of 16 -
      // confirm callers guarantee that or pad the table.
      uint32_t local_offsets[17];
      for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) {
        *reinterpret_cast<int4*>(local_offsets) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start]));
        *reinterpret_cast<int4*>(local_offsets + 4) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start + 4]));
        *reinterpret_cast<int4*>(local_offsets + 8) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start + 8]));
        *reinterpret_cast<int4*>(local_offsets + 12) =
            __ldca(reinterpret_cast<const int4*>(
                &input_offset_by_experts[chunk_start + 12]));
        local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]);

#pragma unroll
        for (int i = 0; i < 16; i++) {
          if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) {
            rowIdx_in_expert = rowIdx - local_offsets[i];
            expert_idx = chunk_start + i;
            break;
          }
        }
      }
    }

    // Load input and optionally apply fused SiLU+Mul.
    // Widen before multiplying so the element offset is computed in 64 bits.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * inColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    PackedVec quant_input;
    if constexpr (FUSE_SILU_MUL) {
      // The "up" half of the row starts colsPerRow vectors after the gate.
      PackedVec in_vec_up =
          reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      quant_input = compute_silu_mul(in_vec, in_vec_up);
    } else {
      quant_input = in_vec;
    }

    // In PACK16 mode, each thread outputs 16 E2M1 values = u32x2
    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];

    uint32_t* SFout_in_expert =
        SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;

    // Use MXFP4_NUM_THREADS_PER_SF (2 for PACK16) for 32-element blocks
    auto sf_out =
        cvt_quant_to_fp4_get_sf_out_offset<uint32_t, MXFP4_NUM_THREADS_PER_SF>(
            rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);

    // Block E8M0 scales only; no extra tensor-level scale in this path
    constexpr float SFScaleVal = 1.0f;

    // UE8M0_SF=true for MXFP4 E8M0 scale factors
    out_pos = cvt_warp_fp16_to_fp4<Type, MXFP4_NUM_THREADS_PER_SF,
                                   /*UE8M0_SF=*/true>(quant_input, SFScaleVal,
                                                      sf_out);
  }
}
// Large M_topk variant using shared memory for expert offsets.
//
// Each block first stages the n_experts + 1 row offsets into dynamic shared
// memory (the launch must provide (n_experts + 1) * sizeof(uint32_t) bytes),
// then every thread binary-searches that table for the expert owning its row.
//
// Template parameters:
//   Type              - input element type.
//   FUSE_SILU_MUL     - input rows hold 2 * numCols elements (gate, then up)
//                       and SiLU(gate) * up is applied before quantization.
//   SMALL_NUM_EXPERTS - true: stage offsets one entry per thread;
//                       false: stage them four at a time with 128-bit copies.
// Parameters:
//   numRows, numCols               - shape (in elements) of the quantized
//                                    matrix.
//   in                             - input activations.
//   out                            - packed FP4 output; one fp4_packed_t per
//                                    CVT_FP4_ELTS_PER_THREAD elements.
//   SFout                          - scale-factor buffer (uint32 words).
//   input_offset_by_experts        - expert e owns rows [off[e], off[e + 1]).
//   output_scale_offset_by_experts - per-expert scale row offset; multiplied
//                                    by numKTiles to index SFout.
//   n_experts                      - number of experts.
template <class Type, bool FUSE_SILU_MUL = false,
          bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
    mxfp4_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                          fp4_packed_t* out, uint32_t* SFout,
                          uint32_t* input_offset_by_experts,
                          uint32_t* output_scale_offset_by_experts,
                          int n_experts) {
  using PackedVec = PackedVec<Type, CVT_FP4_PACK16>;
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // MXFP4: numKTiles = ceil(numCols / 128)
  int32_t const numKTiles = (numCols + 127) / 128;

  extern __shared__ uint32_t shared_input_offsets[];

  if constexpr (SMALL_NUM_EXPERTS) {
    for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) {
      shared_input_offsets[i] = input_offset_by_experts[i];
    }
  } else {
    for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) {
      if (i + 4 <= n_experts) {
        *reinterpret_cast<int4*>(&shared_input_offsets[i]) =
            *reinterpret_cast<const int4*>(&input_offset_by_experts[i]);
      } else {
        // Tail group when n_experts is not a multiple of 4: copy entry by
        // entry so neither the global table nor the shared buffer (both
        // n_experts + 1 entries) is accessed past its end.
        for (int j = i; j < n_experts; ++j) {
          shared_input_offsets[j] = input_offset_by_experts[j];
        }
      }
    }
    if (threadIdx.x == 0) {
      shared_input_offsets[n_experts] = input_offset_by_experts[n_experts];
    }
  }

  __syncthreads();

  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
  int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;

  // Grid-stride loop: each iteration handles one packed vector.
  for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
       globalIdx += gridDim.x * blockDim.x) {
    int rowIdx = globalIdx / colsPerRow;
    int colIdx = globalIdx % colsPerRow;

    int rowIdx_in_expert = 0;
    int expert_idx = 0;

    // Binary search through experts using shared memory
    int left = 0, right = n_experts - 1;
    while (left <= right) {
      int mid = (left + right) / 2;
      uint32_t mid_offset = shared_input_offsets[mid];
      uint32_t next_offset = shared_input_offsets[mid + 1];

      if (rowIdx >= mid_offset && rowIdx < next_offset) {
        rowIdx_in_expert = rowIdx - mid_offset;
        expert_idx = mid;
        break;
      } else if (rowIdx < mid_offset) {
        right = mid - 1;
      } else {
        left = mid + 1;
      }
    }

    // Widen before multiplying so the element offset is computed in 64 bits.
    int64_t inOffset = static_cast<int64_t>(rowIdx) * inColsPerRow + colIdx;
    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
    PackedVec quant_input;
    if constexpr (FUSE_SILU_MUL) {
      // The "up" half of the row starts colsPerRow vectors after the gate.
      PackedVec in_vec_up =
          reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
      quant_input = compute_silu_mul(in_vec, in_vec_up);
    } else {
      quant_input = in_vec;
    }

    int64_t outOffset = static_cast<int64_t>(rowIdx) * colsPerRow + colIdx;
    auto& out_pos = out[outOffset];

    // MXFP4 has no global scale - only block-level E8M0 scale factors
    constexpr float SFScaleVal = 1.0f;

    uint32_t* SFout_in_expert =
        SFout + output_scale_offset_by_experts[expert_idx] * numKTiles;

    auto sf_out =
        cvt_quant_to_fp4_get_sf_out_offset<uint32_t, MXFP4_NUM_THREADS_PER_SF>(
            rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);

    out_pos = cvt_warp_fp16_to_fp4<Type, MXFP4_NUM_THREADS_PER_SF,
                                   /*UE8M0_SF=*/true>(quant_input, SFScaleVal,
                                                      sf_out);
  }
}
// Host-side launcher for the MXFP4 experts quantization kernels.
//
// Picks a launch configuration from the problem size and SM count, then
// dispatches to one of two kernel overloads:
//   - blockRepeat > 1 (each thread loops more than once): the shared-memory
//     overload, with the 128-bit offset staging when n_experts >= 4;
//   - otherwise: the register-lookup overload, with the 16-entry chunked scan
//     when n_experts >= 16.
//
// Parameters:
//   output, output_scale, input      - device buffers, reinterpreted as
//                                      fp4_packed_t*, uint32_t* and T*.
//   input_offset_by_experts,
//   output_scale_offset_by_experts   - device offset tables (uint32_t*).
//   m_topk, k                        - rows and per-row element count of the
//                                      quantized matrix.
//   n_experts                        - number of experts.
//   stream                           - CUDA stream to launch on.
template <typename T, bool FUSE_SILU_MUL = false>
void mxfp4_quant_impl(void* output, void* output_scale, void* input,
                      void* input_offset_by_experts,
                      void* output_scale_offset_by_experts, int m_topk, int k,
                      int n_experts, cudaStream_t stream) {
  // NOTE(review): ELTS_PER_THREAD comes from an included header and is
  // presumably equal to CVT_FP4_ELTS_PER_THREAD used by the kernels - confirm.
  int const workSizePerRow = k / ELTS_PER_THREAD;

  // Nothing to quantize. Returning here also avoids a zero-sized launch and
  // the division by block.x * grid.x == 0 in the blockRepeat computation.
  if (m_topk <= 0 || workSizePerRow <= 0) {
    return;
  }
  int const totalWorkSize = m_topk * workSizePerRow;

  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);

  dim3 block(std::min(workSizePerRow, 512));
  int const numBlocksPerSM =
      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
                     multiProcessorCount * numBlocksPerSM));
  // Trade block size for more blocks while the grid does not yet exceed the
  // SM count.
  while (grid.x <= multiProcessorCount && block.x > 64) {
    grid.x *= 2;
    block.x = (block.x + 1) / 2;
  }

  // Number of loop iterations each thread needs to cover all work items.
  int const blockRepeat =
      (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);

  auto* in_ptr = reinterpret_cast<T*>(input);
  auto* out_ptr = reinterpret_cast<fp4_packed_t*>(output);
  auto* sf_ptr = reinterpret_cast<uint32_t*>(output_scale);
  auto* in_offsets_ptr = reinterpret_cast<uint32_t*>(input_offset_by_experts);
  auto* sf_offsets_ptr =
      reinterpret_cast<uint32_t*>(output_scale_offset_by_experts);

  if (blockRepeat > 1) {
    size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
    if (n_experts >= 4) {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false>
          <<<grid, block, shared_mem_size, stream>>>(
              m_topk, k, in_ptr, out_ptr, sf_ptr, in_offsets_ptr,
              sf_offsets_ptr, n_experts);
    } else {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, true>
          <<<grid, block, shared_mem_size, stream>>>(
              m_topk, k, in_ptr, out_ptr, sf_ptr, in_offsets_ptr,
              sf_offsets_ptr, n_experts);
    }
  } else {
    if (n_experts >= 16) {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false>
          <<<grid, block, 0, stream>>>(m_topk, k, in_ptr, out_ptr, sf_ptr,
                                       in_offsets_ptr, sf_offsets_ptr,
                                       n_experts,
                                       /* bool low_latency */ true);
    } else {
      mxfp4_cvt_fp16_to_fp4<T, FUSE_SILU_MUL, true>
          <<<grid, block, 0, stream>>>(m_topk, k, in_ptr, out_ptr, sf_ptr,
                                       in_offsets_ptr, sf_offsets_ptr,
                                       n_experts,
                                       /* bool low_latency */ true);
    }
  }
}
}
// namespace vllm
/*Quantization entry for mxfp4 experts quantization*/
// Tensor precondition checks. `m` is the tensor's name; it is joined to the
// trailing text without a separator, so that text begins with a space.
#define CHECK_TH_CUDA(x, m) \
  STD_TORCH_CHECK(x.is_cuda(), m, " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
  STD_TORCH_CHECK(x.is_contiguous(), m, " must be contiguous")
// Requires a contiguous CUDA tensor.
#define CHECK_INPUT(x, m) \
  CHECK_TH_CUDA(x, m);    \
  CHECK_CONTIGUOUS(x, m);
// Short aliases for the scalar types compared against during validation.
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;

// Number of elements along K that share one scale factor; k must be a
// multiple of this.
static constexpr int MXFP4_BLOCK_SIZE = 32;
// Validates the tensors passed to the MXFP4 experts quantization entry
// points. Fails via STD_TORCH_CHECK on the first violated precondition.
//
//   output                         - [m_topk, k / 2] uint8, contiguous CUDA.
//   output_scale                   - 2-D int32; its column count times 4 must
//                                    equal k / 32 rounded up to a multiple
//                                    of 4.
//   input                          - 2-D FP16 or BF16, contiguous CUDA.
//   input_offset_by_experts,
//   output_scale_offset_by_experts - 1-D int32 with n_experts + 1 entries.
//   n_experts, m_topk, k           - expert count and the logical shape of
//                                    the quantized matrix; k must be a
//                                    multiple of 32.
static void validate_mxfp4_experts_quant_inputs(
    torch::stable::Tensor const& output,
    torch::stable::Tensor const& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts, int64_t m_topk, int64_t k) {
  CHECK_INPUT(output, "output");
  CHECK_INPUT(output_scale, "output_scale");
  CHECK_INPUT(input, "input");
  CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
  CHECK_INPUT(output_scale_offset_by_experts,
              "output_scale_offset_by_experts");

  STD_TORCH_CHECK(output.dim() == 2, "output must be 2-D");
  STD_TORCH_CHECK(output_scale.dim() == 2, "output_scale must be 2-D");
  STD_TORCH_CHECK(input.dim() == 2, "input must be 2-D");
  STD_TORCH_CHECK(input_offset_by_experts.dim() == 1,
                  "input_offset_by_experts must be 1-D");
  STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1,
                  "output_scale_offset_by_experts must be 1-D");

  STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16,
                  "input must be FP16 or BF16");
  STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT,
                  "input_offset_by_experts must be int32");
  STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT,
                  "output_scale_offset_by_experts must be int32");
  // output is uint8 (two mxfp4 values packed into one uint8)
  // output_scale is int32 (four E8M0 values packed into one int32)
  STD_TORCH_CHECK(output.scalar_type() == UINT8, "output must be uint8");
  STD_TORCH_CHECK(output_scale.scalar_type() == INT,
                  "output_scale must be int32");

  // The launcher and kernels take these sizes as 32-bit ints; reject values
  // that would be silently truncated.
  constexpr int64_t kInt32Max = 2147483647;
  STD_TORCH_CHECK(
      m_topk <= kInt32Max && k <= kInt32Max && n_experts <= kInt32Max,
      "m_topk, k and n_experts must fit in int32");

  STD_TORCH_CHECK(k % MXFP4_BLOCK_SIZE == 0, "k must be a multiple of 32");
  STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1,
                  "input_offset_by_experts must have n_experts + 1 entries");
  STD_TORCH_CHECK(
      output_scale_offset_by_experts.size(0) == n_experts + 1,
      "output_scale_offset_by_experts must have n_experts + 1 entries");
  STD_TORCH_CHECK(output.size(0) == m_topk, "output must have m_topk rows");
  STD_TORCH_CHECK(output.size(1) == k / 2, "output must have k / 2 columns");

  // Kept in 64 bits so large k is not truncated before the comparison.
  int64_t const scales_k = k / MXFP4_BLOCK_SIZE;
  // K-dimension scale columns padded to a multiple of 4 for swizzle layout
  int64_t const padded_k = (scales_k + (4 - 1)) / 4 * 4;
  // 4 = 4 E8M0 values packed into one int32
  STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k,
                  "output_scale has the wrong number of columns for k");
}
// Entry point for the `mxfp4_experts_quant` op.
// Validates the arguments, then launches the MXFP4 quantization of `input`
// ([m_topk, k]) on the current CUDA stream of the input's device, writing the
// packed values to `output` and the block scales to `output_scale`.
void mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts) {
  const auto num_rows = input.size(0);
  const auto num_cols = input.size(1);

  validate_mxfp4_experts_quant_inputs(
      output, output_scale, input, input_offset_by_experts,
      output_scale_offset_by_experts, n_experts, num_rows, num_cols);

  // Run on the device that holds the input.
  const auto device_index = input.get_device_index();
  const torch::stable::accelerator::DeviceGuard device_guard(device_index);
  const cudaStream_t stream = get_current_cuda_stream(device_index);

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      input.scalar_type(), "mxfp4_experts_quant_kernel", [&] {
        using device_elem_t = vllm::CUDATypeConverter<scalar_t>::Type;
        // The launcher takes 32-bit sizes.
        vllm::mxfp4_quant_impl<device_elem_t, /*FUSE_SILU_MUL=*/false>(
            output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
            input_offset_by_experts.data_ptr(),
            output_scale_offset_by_experts.data_ptr(),
            static_cast<int>(num_rows), static_cast<int>(num_cols),
            static_cast<int>(n_experts), stream);
      });
}
// Entry point for the `silu_and_mul_mxfp4_experts_quant` op.
// `input` is [m_topk, 2 * k] with the gate half followed by the up half.
// Validates the arguments against the post-fusion width k, then launches the
// fused SiLU(gate) * up + MXFP4 quantization on the current CUDA stream of
// the input's device.
void silu_and_mul_mxfp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts,
    int64_t n_experts) {
  const auto num_rows = input.size(0);
  const auto fused_width = input.size(1);
  STD_TORCH_CHECK(fused_width % 2 == 0,
                  "input width must be even (gate || up)");
  // Width of the quantized output in elements.
  const auto num_cols = fused_width / 2;

  validate_mxfp4_experts_quant_inputs(
      output, output_scale, input, input_offset_by_experts,
      output_scale_offset_by_experts, n_experts, num_rows, num_cols);

  // Run on the device that holds the input.
  const auto device_index = input.get_device_index();
  const torch::stable::accelerator::DeviceGuard device_guard(device_index);
  const cudaStream_t stream = get_current_cuda_stream(device_index);

  VLLM_STABLE_DISPATCH_HALF_TYPES(
      input.scalar_type(), "silu_mul_mxfp4_experts_quant_kernel", [&] {
        using device_elem_t = vllm::CUDATypeConverter<scalar_t>::Type;
        // The launcher takes 32-bit sizes.
        vllm::mxfp4_quant_impl<device_elem_t, /*FUSE_SILU_MUL=*/true>(
            output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
            input_offset_by_experts.data_ptr(),
            output_scale_offset_by_experts.data_ptr(),
            static_cast<int>(num_rows), static_cast<int>(num_cols),
            static_cast<int>(n_experts), stream);
      });
}
csrc/libtorch_stable/torch_bindings.cpp
View file @
a8bffaa1
...
@@ -116,6 +116,12 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
...
@@ -116,6 +116,12 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"
);
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"
);
// cutlass mxfp4 block scaled group GEMM (MXFP4 x MXFP4 MoE)
ops
.
def
(
"cutlass_mxfp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"
);
// Compute NVFP4 block quantized tensor.
// Compute NVFP4 block quantized tensor.
ops
.
def
(
ops
.
def
(
"scaled_fp4_quant(Tensor input,"
"scaled_fp4_quant(Tensor input,"
...
@@ -149,6 +155,19 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
...
@@ -149,6 +155,19 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()"
);
"Tensor output_scale_offset_by_experts) -> ()"
);
// Compute MXFP4 experts quantization (32-element blocks, E8M0 SFs).
ops
.
def
(
"mxfp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, int n_experts) -> ()"
);
// Fused SiLU+Mul+MXFP4 experts quantization.
ops
.
def
(
"silu_and_mul_mxfp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, int n_experts) -> ()"
);
// Fused SiLU+Mul+NVFP4 quantization.
// Fused SiLU+Mul+NVFP4 quantization.
ops
.
def
(
ops
.
def
(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
...
@@ -233,6 +252,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
...
@@ -233,6 +252,9 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
ops
.
impl
(
"silu_and_mul_scaled_fp4_experts_quant"
,
ops
.
impl
(
"silu_and_mul_scaled_fp4_experts_quant"
,
TORCH_BOX
(
&
silu_and_mul_scaled_fp4_experts_quant
));
TORCH_BOX
(
&
silu_and_mul_scaled_fp4_experts_quant
));
ops
.
impl
(
"silu_and_mul_nvfp4_quant"
,
TORCH_BOX
(
&
silu_and_mul_nvfp4_quant
));
ops
.
impl
(
"silu_and_mul_nvfp4_quant"
,
TORCH_BOX
(
&
silu_and_mul_nvfp4_quant
));
ops
.
impl
(
"mxfp4_experts_quant"
,
TORCH_BOX
(
&
mxfp4_experts_quant
));
ops
.
impl
(
"silu_and_mul_mxfp4_experts_quant"
,
TORCH_BOX
(
&
silu_and_mul_mxfp4_experts_quant
));
// W4A8 ops: impl registrations are in the source files
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
...
...
tests/kernels/moe/test_mxfp4_moe.py
0 → 100644
View file @
a8bffaa1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for SM100 CUTLASS MXFP4 x MXFP4 grouped MoE kernels."""
import
random
import
pytest
import
torch
from
tests.kernels.utils
import
torch_moe_single
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
random
.
seed
(
42
)
set_random_seed
(
42
)
MXFP4_BLOCK_SIZE
=
32
def align(val: int, alignment: int = 128) -> int:
    """Round ``val`` up to the nearest multiple of ``alignment``."""
    num_blocks = (val + alignment - 1) // alignment
    return int(num_blocks * alignment)
def calc_diff(x, y):
    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` computed in float64.

    The result is 0 when ``x`` and ``y`` are identical and grows as they
    diverge.
    """
    lhs = x.double()
    rhs = y.double()
    cross_term = (lhs * rhs).sum()
    energy = (lhs * lhs + rhs * rhs).sum()
    similarity = 2 * cross_term / energy
    return 1 - similarity
def is_sm100_supported() -> bool:
    """Whether the current platform is CUDA with device capability family 100."""
    on_cuda = current_platform.is_cuda()
    return on_cuda and current_platform.is_device_capability_family(100)
def compute_ref_output(
    input_tensor: torch.Tensor,
    weight_list: list[torch.Tensor],
    expert_offsets: list[int],
    expert_offset: int,
    num_experts: int,
) -> torch.Tensor:
    """Reference output from ``torch_moe_single`` with top-1 routing.

    Builds a score matrix of shape ``(expert_offset, num_experts)`` that is
    ``-1e9`` everywhere except for ``0.0`` at each row's own expert, so every
    row is routed to the expert that owns it.

    Args:
        input_tensor: Concatenated per-expert inputs.
        weight_list: One weight matrix per expert.
        expert_offsets: Start row of each expert in ``input_tensor``.
        expert_offset: Total number of rows (end of the last expert).
        num_experts: Number of experts.
    """
    routing_scores = torch.full(
        (expert_offset, num_experts),
        -1e9,
        device=input_tensor.device,
        dtype=torch.float32,
    )
    for expert in range(num_experts):
        row_begin = expert_offsets[expert]
        # The last expert runs to the end of the input.
        if expert + 1 < num_experts:
            row_end = expert_offsets[expert + 1]
        else:
            row_end = expert_offset
        routing_scores[row_begin:row_end, expert] = 0.0

    stacked_weights = torch.stack(weight_list, dim=0)
    return torch_moe_single(input_tensor, stacked_weights, routing_scores, topk=1)
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="cutlass_mxfp4_group_mm requires CUDA SM100",
)
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16])
def test_cutlass_mxfp4_grouped_mm(num_experts, out_dtype):
    """
    Test the MXFP4 grouped GEMM kernel by:
    1. Creating random per-expert inputs and weights
    2. Quantizing both to MXFP4 using the CUDA kernel
    3. Running the CUTLASS grouped GEMM
    4. Comparing against BF16 reference

    N and K are shared by all experts; only the per-expert row count M varies.
    The shapes and data depend on the order of the ``random`` and
    ``torch.normal`` calls below.
    """
    device = "cuda"
    alignment = 128
    # N and K must be multiples of 128 for clean swizzle layout
    n_g = random.randint(1, 16) * alignment
    k_g = random.randint(1, 16) * alignment

    # Running total of rows; after the loop it is the total row count.
    expert_offset = 0
    # Start row of each expert in the concatenated input.
    expert_offsets_input = []
    problem_sizes = []
    input_list = []
    weight_list = []

    for g in range(num_experts):
        m_g = random.randint(1, 256)
        expert_offsets_input.append(expert_offset)
        expert_offset += m_g
        problem_sizes.append([m_g, n_g, k_g])

        input_list.append(
            torch.normal(0.0, std=0.5, size=(m_g, k_g), device=device, dtype=out_dtype)
        )
        weight_list.append(
            torch.normal(0.0, std=0.5, size=(n_g, k_g), device=device, dtype=out_dtype)
        )

    input_tensor = torch.concat(input_list, dim=0)  # [M_total, K]

    # --- Quantize INPUTS via mxfp4_experts_quant ---
    # Scale-factor row offsets: each expert's region is its row count rounded
    # up to a multiple of 128, laid out back to back.
    input_bs_offsets = []
    tot = 0
    for g in range(num_experts):
        input_bs_offsets.append(tot)
        tot += align(problem_sizes[g][0], 128)
    input_bs_offsets.append(tot)

    _inp_expert_offsets = torch.tensor(
        expert_offsets_input + [expert_offset], device=device, dtype=torch.int32
    )
    _inp_bs_offsets = torch.tensor(input_bs_offsets, device=device, dtype=torch.int32)

    input_quant, input_sf = ops.mxfp4_experts_quant(
        input_tensor,
        _inp_expert_offsets,
        _inp_bs_offsets,
        num_experts,
        topk=1,
    )

    # --- Quantize WEIGHTS via mxfp4_experts_quant ---
    # Treat each expert's N weight rows as an "expert" with N tokens
    weight_tensor = torch.concat(weight_list, dim=0)  # [E*N, K]
    weight_expert_offsets = [g * n_g for g in range(num_experts)] + [num_experts * n_g]
    # N is always multiple of 128, so blockscale offsets are clean
    weight_bs_offsets = [g * n_g for g in range(num_experts)] + [num_experts * n_g]

    _wt_expert_offsets = torch.tensor(
        weight_expert_offsets, device=device, dtype=torch.int32
    )
    _wt_bs_offsets = torch.tensor(weight_bs_offsets, device=device, dtype=torch.int32)

    weight_quant, weight_sf = ops.mxfp4_experts_quant(
        weight_tensor,
        _wt_expert_offsets,
        _wt_bs_offsets,
        num_experts,
        topk=1,
    )

    # Reshape weight quantized data to [E, N, K//2]
    weight_quant = weight_quant[: num_experts * n_g].view(num_experts, n_g, k_g // 2)

    # Reshape weight scale factors to [E, N, K//32]
    # The quant kernel produces uint8 SF buffer. Each row has K//32 SFs.
    # NOTE(review): this is a plain reshape of the flat buffer; presumably the
    # GEMM expects the layout produced by the quant kernel - confirm.
    scales_per_row = k_g // MXFP4_BLOCK_SIZE
    weight_sf_flat = weight_sf.view(-1)[: num_experts * n_g * scales_per_row]
    weight_sf_3d = weight_sf_flat.view(num_experts, n_g, scales_per_row)

    # Output
    output = torch.empty((expert_offset, n_g), device=device, dtype=out_dtype)

    _problem_sizes = torch.tensor(problem_sizes, device=device, dtype=torch.int32)
    # The GEMM takes one start offset per expert (no trailing end entry).
    _expert_offsets = torch.tensor(
        expert_offsets_input, device=device, dtype=torch.int32
    )
    _input_bs = torch.tensor(input_bs_offsets[:-1], device=device, dtype=torch.int32)

    # Run the MXFP4 grouped GEMM
    ops.cutlass_mxfp4_moe_mm(
        output,
        input_quant,
        weight_quant,
        input_sf,
        weight_sf_3d,
        _problem_sizes,
        _expert_offsets,
        _input_bs,
    )

    # Reference: BF16 matmul
    ref_output = compute_ref_output(
        input_tensor=input_tensor,
        weight_list=weight_list,
        expert_offsets=expert_offsets_input,
        expert_offset=expert_offset,
        num_experts=num_experts,
    )

    # Compare per-expert
    for g in range(num_experts):
        start = expert_offsets_input[g]
        end = expert_offsets_input[g + 1] if g + 1 < num_experts else expert_offset
        # Skip experts that received no rows.
        if start == end:
            continue
        baseline = ref_output[start:end]
        actual = output[start:end]
        diff = calc_diff(actual, baseline)
        print(
            f"m_g={end - start} n_g={n_g} k_g={k_g} "
            f"num_experts={num_experts}, "
            f"out_dtype={out_dtype}, diff={diff:.5f}"
        )
        # FP4 quantization is very lossy (~4 bits precision)
        # Comparing quantized vs full-precision gives cosine diff of 0.05-0.15
        assert diff < 0.15, f"Expert {g}: diff={diff:.5f} exceeds threshold"
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason="mxfp4_experts_quant requires CUDA SM100",
)
def test_mxfp4_experts_quant_basic():
    """
    Basic smoke test for the MXFP4 experts quantization kernel.

    Quantizes 4 experts x 16 tokens of width 256 and checks only the shape and
    dtype of the results and that the packed output is not all zeros.
    """
    device = "cuda"
    num_experts = 4
    k = 256
    tokens_per_expert = 16
    total_tokens = tokens_per_expert * num_experts

    input_tensor = (
        torch.randn(total_tokens, k, device=device, dtype=torch.bfloat16) / 5
    )

    expert_offsets = [i * tokens_per_expert for i in range(num_experts + 1)]
    # Each expert gets its own scale-factor region, padded to 128 rows.
    # Aligning the cumulative token count instead would place experts 1..3 at
    # the same offset, so their scale factors would overwrite each other.
    padded_rows_per_expert = align(tokens_per_expert, 128)
    blockscale_offsets = [
        i * padded_rows_per_expert for i in range(num_experts + 1)
    ]

    _expert_offsets = torch.tensor(expert_offsets, device=device, dtype=torch.int32)
    _blockscale_offsets = torch.tensor(
        blockscale_offsets, device=device, dtype=torch.int32
    )

    output, output_sf = ops.mxfp4_experts_quant(
        input_tensor,
        _expert_offsets,
        _blockscale_offsets,
        num_experts,
        topk=1,
    )

    assert output.shape == (total_tokens, k // 2)
    assert output.dtype == torch.uint8
    assert output_sf.dtype == torch.uint8
    assert output.any(), "Quantized output is all zeros"
    print(
        f"MXFP4 experts quant: output shape={output.shape}, sf shape={output_sf.shape}"
    )
    print("PASSED")
# Allow running this file directly: runs pytest on it in verbose mode with
# output capture disabled.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
vllm/_custom_ops.py
View file @
a8bffaa1
...
@@ -1150,6 +1150,38 @@ def cutlass_fp4_moe_mm(
...
@@ -1150,6 +1150,38 @@ def cutlass_fp4_moe_mm(
)
)
def cutlass_mxfp4_moe_mm(
    out_tensors: torch.Tensor,
    a_tensors: torch.Tensor,
    b_tensors: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    problem_sizes: torch.Tensor,
    expert_offsets: torch.Tensor,
    sf_offsets: torch.Tensor,
):
    """
    MXFP4 x MXFP4 block-scaled grouped GEMM for MoE.

    Thin wrapper that forwards every argument, in order, to the
    ``_C.cutlass_mxfp4_group_mm`` custom op and returns its result.
    The op uses mx_float4_t types with E8M0 scale factors and 32-element
    blocks.

    Args:
        out_tensors: destination tensor handed to the op.
        a_tensors: MXFP4 packed activations (uint8, 2 E2M1 per byte).
        b_tensors: MXFP4 packed weights (uint8, 2 E2M1 per byte).
        a_scales: E8M0 blockscales for the activations (uint8, stored in
            swizzled layout).
        b_scales: E8M0 blockscales for the weights (uint8, stored in
            swizzled layout).
        problem_sizes: (num_experts, 3) with (M, N, K) per expert.
        expert_offsets: expert boundary indices.
        sf_offsets: expert boundary indices for the scale factors.

    The epilogue uses scalar alpha=1, beta=0 inside the CUDA op, so there are
    no global scales.
    """
    grouped_gemm = torch.ops._C.cutlass_mxfp4_group_mm
    return grouped_gemm(
        out_tensors,
        a_tensors,
        b_tensors,
        a_scales,
        b_scales,
        problem_sizes,
        expert_offsets,
        sf_offsets,
    )
def
mxfp8_experts_quant
(
def
mxfp8_experts_quant
(
input_tensor
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
problem_sizes
:
torch
.
Tensor
,
problem_sizes
:
torch
.
Tensor
,
...
@@ -1848,6 +1880,109 @@ def silu_and_mul_scaled_fp4_experts_quant(
...
@@ -1848,6 +1880,109 @@ def silu_and_mul_scaled_fp4_experts_quant(
return
output
,
output_scales
return
output
,
output_scales
def mxfp4_experts_quant(
    input_tensor: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    n_experts: int,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to MXFP4 for packed MoE inputs.

    Uses 32-element blocks with E8M0 (power-of-two) scale factors.
    MXFP4 has no global scale - only block-level E8M0 scale factors.

    Args:
        input_tensor: [m_topk, k] BF16/FP16 activations
        expert_offsets: [n_experts+1] token boundaries per expert
        blockscale_offsets: [n_experts+1] SF row boundaries per expert
        n_experts: number of experts
        topk: number of top-k experts

    Returns:
        output: [m_topk, k//2] packed E2M1 values (uint8)
        output_scales: E8M0 blockscales in swizzled layout (uint8 view)

    Raises:
        AssertionError: on ROCm, if the input is not 2-D, if it has more than
            VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE * topk rows, or if k is not a
            multiple of the 32-element block size.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2

    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k = input_tensor.shape

    # The scale buffer below is sized for MAX_TOKENS_PER_EXPERT * topk rows,
    # so that (not MAX_TOKENS_PER_EXPERT alone) is the real row limit.
    max_rows = MAX_TOKENS_PER_EXPERT * topk
    assert m_numtopk <= max_rows, (
        f"m_numtopk must be at most MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT}) * topk({topk}) = {max_rows}"
        f" for cutlass_moe_mxfp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )
    # One scale per 32 elements: a ragged tail would be dropped by the
    # integer division below and leave the scale buffer one block short.
    assert k % 32 == 0, (
        f"mxfp4_experts_quant requires k to be a multiple of 32 "
        f"(MXFP4 block size), observed k = {k}."
    )

    scales_k = k // 32
    # Scales are allocated as int32 words, i.e. four one-byte scales per
    # element, so round the scale count up to a whole number of words.
    padded_k = (scales_k + (4 - 1)) // 4

    # Two 4-bit values are packed into each output byte.
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        max_rows,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )

    torch.ops._C.mxfp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        expert_offsets,
        blockscale_offsets,
        n_experts,
    )

    # E8M0 SFs are stored as uint8
    output_scales = output_scales.view(torch.uint8)
    return output, output_scales
def silu_and_mul_mxfp4_experts_quant(
    input_tensor: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    n_experts: int,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused SiLU+Mul+MXFP4 quantization for MoE intermediate activations.

    MXFP4 has no global scale - only block-level E8M0 scale factors.

    Args:
        input_tensor: 2-D tensor whose width is 2*k (gate || up layout).
        expert_offsets: passed through to the custom op.
        blockscale_offsets: passed through to the custom op.
        n_experts: number of experts, passed through to the custom op.
        topk: number of top-k experts; only used to size the scale buffer
            and bound the row count.

    Returns:
        output: [rows, k//2] uint8 tensor filled by the custom op.
        output_scales: scale buffer filled by the custom op, allocated as
            int32 and returned as a uint8 view.

    Raises:
        AssertionError: on ROCm, if the input is not 2-D, if its width is
            odd, if it has more than VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE * topk
            rows, or if k is not a multiple of the 32-element block size.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2

    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k_times_2 = input_tensor.shape
    assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)"
    k = k_times_2 // 2

    # Same row limit (and diagnostic) as the non-fused mxfp4_experts_quant.
    max_rows = MAX_TOKENS_PER_EXPERT * topk
    assert m_numtopk <= max_rows, (
        f"m_numtopk must be at most MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT}) * topk({topk}) = {max_rows}"
        f" for cutlass_moe_mxfp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )
    # One scale per 32 elements: a ragged tail would be dropped by the
    # integer division below and leave the scale buffer one block short.
    assert k % 32 == 0, (
        f"silu_and_mul_mxfp4_experts_quant requires k (half the input width) "
        f"to be a multiple of 32 (MXFP4 block size), observed k = {k}."
    )

    scales_k = k // 32
    # Four one-byte scales per int32 word; round up to whole words.
    padded_k = (scales_k + (4 - 1)) // 4

    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        max_rows,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )

    torch.ops._C.silu_and_mul_mxfp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        expert_offsets,
        blockscale_offsets,
        n_experts,
    )

    # Reinterpret the int32 words as individual uint8 scale bytes.
    output_scales = output_scales.view(torch.uint8)
    return output, output_scales
# fp8
# fp8
def
scaled_fp8_quant
(
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
a8bffaa1
...
@@ -762,6 +762,25 @@ def nvfp4_moe_quant_config(
...
@@ -762,6 +762,25 @@ def nvfp4_moe_quant_config(
)
)
def mxfp4_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """
    Build the quant config used for MXFP4 x MXFP4 MoE.

    MXFP4 uses block scaling only (E8M0 scales, 32-element groups), so the
    config carries just the two weight scale tensors; there are no separate
    alphas / global activation scales here.
    """
    make_kwargs = {
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token_quant": False,
        "per_out_ch_quant": False,
        "block_shape": None,
    }
    return FusedMoEQuantConfig.make("mxfp4", **make_kwargs)
def
nvfp4_w4a16_moe_quant_config
(
def
nvfp4_w4a16_moe_quant_config
(
g1_alphas
:
torch
.
Tensor
,
g1_alphas
:
torch
.
Tensor
,
g2_alphas
:
torch
.
Tensor
,
g2_alphas
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
a8bffaa1
...
@@ -36,6 +36,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -36,6 +36,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTokenSym
,
kFp8DynamicTokenSym
,
kFp8StaticChannelSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kMxfp4Dynamic
,
kMxfp4Static
,
kNvfp4Dynamic
,
kNvfp4Dynamic
,
kNvfp4Static
,
kNvfp4Static
,
)
)
...
@@ -795,6 +797,299 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
...
@@ -795,6 +797,299 @@ class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
)
)
def run_cutlass_moe_mxfp4(
    output: torch.Tensor,
    a: torch.Tensor,
    w1_fp4: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w2_fp4: torch.Tensor,
    w2_blockscale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: MoEActivation,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: torch.device,
    apply_router_weight_on_input: bool = False,
) -> None:
    """MXFP4 x MXFP4 MoE forward pass built on the CUTLASS grouped GEMM.

    Pipeline, in the order executed below:
      1. validate shapes/dtypes of the activations and packed weights;
      2. compute per-expert offsets, problem sizes and row permutations;
      3. permute the activations, quantize them to MXFP4, run GEMM 1 (w1);
      4. apply the activation and re-quantize (fused op for SILU);
      5. run GEMM 2 (w2), undo the permutation and reduce over top-k into
         ``output``.

    The result is written into ``output``; nothing is returned. Note that
    when ``apply_router_weight_on_input`` is set, ``a`` is scaled in place
    (``a.mul_``), so the caller's tensor is modified.
    """
    gated = activation.is_gated
    # w1 produces the gate and up halves together when the activation is gated.
    w1_out_features = 2 * n if gated else n

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
    assert (
        w1_fp4.ndim == 3
        and w2_fp4.ndim == 3
        and w1_blockscale.ndim == 3
        and w2_blockscale.ndim == 3
    ), "All Weights must be of rank 3 for cutlass_moe_mxfp4"

    rows_a, cols_a = a.shape
    w1_experts, w1_rows, w1_packed_cols = w1_fp4.shape
    w2_experts, w2_rows, w2_packed_cols = w2_fp4.shape

    # Packed weights hold two 4-bit values per byte, hence the factors of 2.
    assert w1_experts == w2_experts and w1_experts == e
    assert cols_a == w1_packed_cols * 2 and k == w2_rows
    assert w1_rows == w1_out_features and w2_packed_cols * 2 == n
    assert m == rows_a
    assert 2 * w1_packed_cols == w2_rows
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
    assert topk_weights.size(0) == m and topk_ids.size(0) == m

    out_dtype = a.dtype
    num_topk = topk_ids.size(1)
    expanded_rows = m * num_topk

    def _empty_i32(shape):
        # Uninitialised int32 scratch tensor on the target device.
        return torch.empty(shape, dtype=torch.int32, device=device)

    expert_offsets = _empty_i32(e + 1)
    blockscale_offsets = _empty_i32(e + 1)
    problem_sizes1 = _empty_i32((e, 3))
    problem_sizes2 = _empty_i32((e, 3))
    a_map = _empty_i32(topk_ids.numel())
    c_map = _empty_i32(topk_ids.numel())

    if apply_router_weight_on_input:
        assert num_topk == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        # In-place: scales the caller's activation tensor.
        a.mul_(topk_weights.to(out_dtype))

    # Fills the offset / problem-size / permutation tensors allocated above.
    ops.get_cutlass_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        e,
        n,
        k,
        blockscale_offsets,
        is_gated=gated,
    )

    a = ops.shuffle_rows(a, a_map)
    rep_a_fp4, rep_a_blockscale = ops.mxfp4_experts_quant(
        a,
        expert_offsets,
        blockscale_offsets,
        e,
        num_topk,
    )

    # c1 and c3 are both carved out of workspace13; c3 is only written after
    # c1 has been consumed by the activation / re-quantization step.
    c1 = _resize_cache(workspace13, (expanded_rows, w1_out_features))
    c2 = _resize_cache(workspace2, (expanded_rows, n))
    c3 = _resize_cache(workspace13, (expanded_rows, k))

    # The GEMM takes per-expert start offsets, so drop the trailing entry.
    gemm_expert_offsets = expert_offsets[:-1]
    gemm_sf_offsets = blockscale_offsets[:-1]

    ops.cutlass_mxfp4_moe_mm(
        c1,
        rep_a_fp4,
        w1_fp4,
        rep_a_blockscale,
        w1_blockscale,
        problem_sizes1,
        gemm_expert_offsets,
        gemm_sf_offsets,
    )
    del rep_a_fp4, rep_a_blockscale

    if activation == MoEActivation.SILU:
        # Fused activation + quantization straight from the GEMM 1 output.
        int_fp4, int_blockscale = ops.silu_and_mul_mxfp4_experts_quant(
            c1, expert_offsets, blockscale_offsets, e, num_topk
        )
    else:
        apply_moe_activation(activation, c2, c1)
        int_fp4, int_blockscale = ops.mxfp4_experts_quant(
            c2, expert_offsets, blockscale_offsets, e, num_topk
        )

    ops.cutlass_mxfp4_moe_mm(
        c3,
        int_fp4,
        w2_fp4,
        int_blockscale,
        w2_blockscale,
        problem_sizes2,
        gemm_expert_offsets,
        gemm_sf_offsets,
    )
    del int_fp4, int_blockscale

    c3 = ops.shuffle_rows(c3, c_map)

    assert output.dtype == out_dtype
    per_choice = c3.view(m, num_topk, k)
    if apply_router_weight_on_input:
        # Router weights were already folded into the input activations.
        reduced = per_choice.sum(dim=1)
    else:
        weights = topk_weights.view(m, num_topk, 1).to(out_dtype)
        reduced = (per_choice * weights).sum(dim=1)
    output.copy_(reduced, non_blocking=True)
def swizzle_mxfp4_scales(
    scales: torch.Tensor,
    N: int,
    K: int,
) -> torch.Tensor:
    """Rearrange flat [N, K//32] E8M0 scales into the CUTLASS tiled layout.

    The result is a flat uint8 tensor whose logical shape is
        [numMTiles, numKTiles, 32, 4, 4]
    with numMTiles = ceil(N/128) and numKTiles = ceil((K//32)/4). For a
    scale at row ``mIdx`` and scale-column ``kIdx`` (index into K//32):
        mTileIdx  = mIdx / 128
        outerMIdx = mIdx % 32
        innerMIdx = (mIdx / 32) % 4
        kTileIdx  = kIdx / 4
        innerKIdx = kIdx % 4
    and the element lands at [mTileIdx, kTileIdx, outerMIdx, innerMIdx,
    innerKIdx]. Rows/columns added to reach whole tiles are zero-filled.
    """
    assert scales.dtype == torch.uint8

    scale_cols = K // 32  # one E8M0 scale per 32 elements along K
    m_tiles = (N + 127) // 128
    k_tiles = (scale_cols + 3) // 4

    # Zero canvas covering whole tiles: rows padded to a multiple of 128,
    # scale columns padded to a multiple of 4.
    canvas = torch.zeros(
        (m_tiles * 128, k_tiles * 4),
        dtype=torch.uint8,
        device=scales.device,
    )
    canvas[:N, :scale_cols] = scales

    # Split both axes into (tile, within-tile) coordinates:
    #   [mTileIdx, innerMIdx, outerMIdx, kTileIdx, innerKIdx]
    split = canvas.reshape(m_tiles, 4, 32, k_tiles, 4)

    # Reorder to [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] and
    # materialise that order in memory before flattening.
    swizzled = split.permute(0, 3, 2, 1, 4).contiguous()
    return swizzled.flatten()
class CutlassExpertsMxfp4(mk.FusedMoEExpertsModular):
    """CUTLASS MXFP4 x MXFP4 fused MoE expert implementation."""

    @property
    def expects_unquantized_inputs(self) -> bool:
        # Always True for this implementation.
        return True

    @staticmethod
    def _supports_current_device() -> bool:
        """True on CUDA platforms in the capability-100 device family."""
        return (
            current_platform.is_cuda()
            and current_platform.is_device_capability_family(100)
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Only static-MXFP4 weights with dynamic-MXFP4 activations."""
        supported_pair = (kMxfp4Static, kMxfp4Dynamic)
        return (weight_key, activation_key) == supported_pair

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Whether ``activation`` is one of the activations listed below."""
        supported = [
            MoEActivation.SILU,
            MoEActivation.GELU,
            MoEActivation.SWIGLUOAI,
            MoEActivation.SWIGLUSTEP,
            MoEActivation.SILU_NO_MUL,
            MoEActivation.GELU_NO_MUL,
            MoEActivation.RELU2_NO_MUL,
        ]
        return activation in supported

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        # Only configurations with ep_size == 1 are accepted.
        return moe_parallel_config.ep_size == 1

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def supports_expert_map(self) -> bool:
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # The top-k weighting and reduction happen inside
        # run_cutlass_moe_mxfp4, so nothing is left for the finalize step.
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        # Workspaces use the activation dtype unchanged.
        return act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the (workspace13, workspace2, output) shapes.

        workspace13 is sized to hold either a (M*topk, 2*N) or a (M*topk, K)
        buffer, whichever is wider.
        """
        expanded_rows = M * topk
        shared_width = max(2 * N, K)
        return (
            (expanded_rows, shared_width),
            (expanded_rows, N),
            (M, K),
        )

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the MXFP4 MoE computation, writing the result into ``output``.

        Delegates to ``run_cutlass_moe_mxfp4`` using this object's
        ``w1_scale`` / ``w2_scale`` as the weight block scales.
        ``global_num_experts``, ``expert_map``, ``a1q_scale``, ``a2_scale``
        and ``expert_tokens_meta`` are accepted but not used here.
        """
        num_experts, num_tokens, _, hidden_size, _ = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )
        # The intermediate size is taken from w2's last dimension, doubled,
        # overriding the value reported by moe_problem_size.
        intermediate_size = w2.shape[2] * 2

        run_cutlass_moe_mxfp4(
            output=output,
            a=hidden_states,
            w1_fp4=w1,
            w1_blockscale=self.w1_scale,
            w2_fp4=w2,
            w2_blockscale=self.w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            workspace13=workspace13,
            workspace2=workspace2,
            m=num_tokens,
            n=intermediate_size,
            k=hidden_size,
            e=num_experts,
            device=hidden_states.device,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
# W4A8
# W4A8
def
run_cutlass_moe_w4a8_fp8
(
def
run_cutlass_moe_w4a8_fp8
(
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe/compressed_tensors_moe_w4a4_mxfp4.py
View file @
a8bffaa1
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
torch
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoE
,
...
@@ -11,6 +12,10 @@ from vllm.model_executor.layers.fused_moe import (
...
@@ -11,6 +12,10 @@ from vllm.model_executor.layers.fused_moe import (
)
)
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
mxfp4_moe_quant_config
,
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
CutlassExpertsMxfp4
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
MarlinExperts
,
MarlinExperts
,
...
@@ -36,7 +41,14 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
...
@@ -36,7 +41,14 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
super
().
__init__
(
moe
)
super
().
__init__
(
moe
)
self
.
group_size
=
32
self
.
group_size
=
32
self
.
mxfp4_backend
=
Mxfp4MoeBackend
.
MARLIN
self
.
mxfp4_backend
=
Mxfp4MoeBackend
.
MARLIN
self
.
experts_cls
=
MarlinExperts
self
.
use_cutlass_mxfp4
=
CutlassExpertsMxfp4
.
_supports_current_device
()
self
.
experts_cls
:
type
[
mk
.
FusedMoEExperts
]
if
self
.
use_cutlass_mxfp4
:
logger
.
info_once
(
"Using CutlassExpertsMxfp4 for MXFP4 MoE"
,
scope
=
"local"
)
self
.
experts_cls
=
CutlassExpertsMxfp4
else
:
logger
.
info_once
(
"Using MarlinExperts for MXFP4 MoE"
,
scope
=
"local"
)
self
.
experts_cls
=
MarlinExperts
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -109,11 +121,19 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
...
@@ -109,11 +121,19 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
def
get_fused_moe_quant_config
(
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
)
->
FusedMoEQuantConfig
|
None
:
return
make_mxfp4_moe_quant_config
(
if
self
.
use_cutlass_mxfp4
:
mxfp4_backend
=
self
.
mxfp4_backend
,
# W4A4: both weights and activations quantized to MXFP4
w1_scale
=
layer
.
w13_weight_scale
,
return
mxfp4_moe_quant_config
(
w2_scale
=
layer
.
w2_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
)
w2_scale
=
layer
.
w2_weight_scale
,
)
else
:
# W4A16: weight-only via Marlin
return
make_mxfp4_moe_quant_config
(
mxfp4_backend
=
self
.
mxfp4_backend
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
)
def
process_weights_after_loading
(
self
,
layer
:
FusedMoE
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
FusedMoE
)
->
None
:
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
...
@@ -126,13 +146,45 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
...
@@ -126,13 +146,45 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
)
)
delattr
(
layer
,
"w2_weight_packed"
)
delattr
(
layer
,
"w2_weight_packed"
)
logger
.
warning_once
(
if
self
.
use_cutlass_mxfp4
:
"Your GPU does not have native support for FP4 computation but "
# Swizzle weight scales from flat checkpoint layout [E, N, K//32]
"FP4 quantization is being used. Weight-only FP4 compression "
# to CUTLASS tiled layout [E, numMTiles*numKTiles*512].
"will be used leveraging the Marlin kernel. This may degrade "
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
"performance for compute-heavy workloads."
swizzle_mxfp4_scales
,
)
)
prepare_moe_fp4_layer_for_marlin
(
layer
)
E
=
layer
.
w13_weight_scale
.
shape
[
0
]
w13_N
=
layer
.
w13_weight_scale
.
shape
[
1
]
w13_scale_K
=
layer
.
w13_weight_scale
.
shape
[
2
]
w13_K
=
w13_scale_K
*
32
w2_M
=
layer
.
w2_weight_scale
.
shape
[
1
]
w2_scale_N
=
layer
.
w2_weight_scale
.
shape
[
2
]
w2_N
=
w2_scale_N
*
32
swizzled_w13
=
[]
swizzled_w2
=
[]
for
e_idx
in
range
(
E
):
s13
=
layer
.
w13_weight_scale
[
e_idx
]
sw13
=
swizzle_mxfp4_scales
(
s13
,
w13_N
,
w13_K
)
swizzled_w13
.
append
(
sw13
.
reshape
(
w13_N
,
w13_scale_K
))
s2
=
layer
.
w2_weight_scale
[
e_idx
]
sw2
=
swizzle_mxfp4_scales
(
s2
,
w2_M
,
w2_N
)
swizzled_w2
.
append
(
sw2
.
reshape
(
w2_M
,
w2_scale_N
))
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
stack
(
swizzled_w13
),
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
stack
(
swizzled_w2
),
requires_grad
=
False
)
else
:
logger
.
warning_once
(
"Your GPU does not have native support for FP4 computation "
"but FP4 quantization is being used. Weight-only FP4 "
"compression will be used leveraging the Marlin kernel. "
"This may degrade performance for compute-heavy workloads."
)
prepare_moe_fp4_layer_for_marlin
(
layer
)
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
self
.
moe_quant_config
=
self
.
get_fused_moe_quant_config
(
layer
)
if
self
.
moe_quant_config
is
not
None
:
if
self
.
moe_quant_config
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment