Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db5d0719
Unverified
Commit
db5d0719
authored
Apr 01, 2026
by
Michael Goin
Committed by
GitHub
Apr 01, 2026
Browse files
[Kernel] Add MXFP8 to Marlin GEMM/MoE and refactor Mxfp8LinearOp (#34664)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
dc0428eb
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
481 additions
and
129 deletions
+481
-129
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+9
-0
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+10
-8
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+3
-0
csrc/quantization/marlin/generate_kernels.py
csrc/quantization/marlin/generate_kernels.py
+9
-0
csrc/quantization/marlin/marlin.cu
csrc/quantization/marlin/marlin.cu
+3
-0
csrc/quantization/marlin/marlin_template.h
csrc/quantization/marlin/marlin_template.h
+10
-7
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+6
-0
tests/models/quantization/test_mxfp8.py
tests/models/quantization/test_mxfp8.py
+1
-1
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+2
-1
vllm/model_executor/layers/fused_moe/oracle/fp8.py
vllm/model_executor/layers/fused_moe/oracle/fp8.py
+21
-7
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
+9
-5
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+7
-55
vllm/model_executor/layers/quantization/mxfp8.py
vllm/model_executor/layers/quantization/mxfp8.py
+23
-33
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+196
-0
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
+172
-12
No files found.
csrc/moe/marlin_moe_wna16/generate_kernels.py
View file @
db5d0719
...
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
...
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
2
],
"group_blocks"
:
[
2
],
},
},
# MXFP8
{
"a_type"
:
[
"kBFloat16"
],
"b_type"
:
"kFE4M3fn"
,
"s_type"
:
"kFE8M0fnu"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
2
],
},
# AWQ-INT4 with INT8 activation
# AWQ-INT4 with INT8 activation
{
{
"a_type"
:
[
"kS8"
],
"a_type"
:
[
"kS8"
],
...
...
csrc/moe/marlin_moe_wna16/marlin_template.h
View file @
db5d0719
...
@@ -343,6 +343,8 @@ __global__ void Marlin(
...
@@ -343,6 +343,8 @@ __global__ void Marlin(
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
)
{
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
)
{
static_assert
(
s_type
==
vllm
::
kFE4M3fn
&&
group_blocks
==
1
||
static_assert
(
s_type
==
vllm
::
kFE4M3fn
&&
group_blocks
==
1
||
s_type
==
vllm
::
kFE8M0fnu
&&
group_blocks
==
2
);
s_type
==
vllm
::
kFE8M0fnu
&&
group_blocks
==
2
);
}
else
if
constexpr
(
b_type
==
vllm
::
kFE4M3fn
&&
s_type
==
vllm
::
kFE8M0fnu
)
{
static_assert
(
group_blocks
==
2
);
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
static_assert
(
s_type
==
vllm
::
kBFloat16
);
static_assert
(
s_type
==
vllm
::
kBFloat16
);
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
...
@@ -357,9 +359,10 @@ __global__ void Marlin(
...
@@ -357,9 +359,10 @@ __global__ void Marlin(
constexpr
bool
is_int_type
=
b_type
==
vllm
::
kU4
||
b_type
==
vllm
::
kU8
||
constexpr
bool
is_int_type
=
b_type
==
vllm
::
kU4
||
b_type
==
vllm
::
kU8
||
b_type
==
vllm
::
kS4
||
b_type
==
vllm
::
kS8
||
b_type
==
vllm
::
kS4
||
b_type
==
vllm
::
kS8
||
b_type
==
vllm
::
kU4B8
||
b_type
==
vllm
::
kU8B128
;
b_type
==
vllm
::
kU4B8
||
b_type
==
vllm
::
kU8B128
;
constexpr
bool
is_8bit_scale
=
s_type
.
size_bits
()
==
8
;
// see comments of dequant.h for more details
// see comments of dequant.h for more details
constexpr
bool
dequant_skip_flop
=
constexpr
bool
dequant_skip_flop
=
is_a_8bit
||
b_type
==
vllm
::
kFE4M3fn
||
is_a_8bit
||
(
b_type
==
vllm
::
kFE4M3fn
&&
!
(
s_type
==
vllm
::
kFE8M0fnu
))
||
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
||
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
b_type
==
vllm
::
kU8
);
has_zp
&&
!
is_zp_float
&&
!
(
b_type
==
vllm
::
kU8
);
...
@@ -373,7 +376,7 @@ __global__ void Marlin(
...
@@ -373,7 +376,7 @@ __global__ void Marlin(
const
int
group_size
=
const
int
group_size
=
(
!
has_act_order
&&
group_blocks
==
-
1
)
?
prob_k
:
prob_k
/
num_groups
;
(
!
has_act_order
&&
group_blocks
==
-
1
)
?
prob_k
:
prob_k
/
num_groups
;
const
int
scales_expert_stride
=
const
int
scales_expert_stride
=
prob_n
*
prob_k
/
group_size
/
(
b_type
==
vllm
::
kFE2M1f
?
16
:
8
);
prob_n
*
prob_k
/
group_size
/
(
is_8bit_scale
?
16
:
8
);
const
int
zp_expert_stride
=
const
int
zp_expert_stride
=
is_zp_float
?
prob_n
*
prob_k
/
group_size
/
8
is_zp_float
?
prob_n
*
prob_k
/
group_size
/
8
:
prob_n
*
prob_k
/
group_size
/
(
pack_factor
*
4
);
:
prob_n
*
prob_k
/
group_size
/
(
pack_factor
*
4
);
...
@@ -692,9 +695,8 @@ __global__ void Marlin(
...
@@ -692,9 +695,8 @@ __global__ void Marlin(
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
(
b_type
==
vllm
::
kFE2M1f
?
16
:
8
);
int
s_gl_stride
=
prob_n
/
(
is_8bit_scale
?
16
:
8
);
constexpr
int
s_sh_stride
=
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
(
is_8bit_scale
?
16
:
8
);
16
*
thread_n_blocks
/
(
b_type
==
vllm
::
kFE2M1f
?
16
:
8
);
constexpr
int
s_tb_groups
=
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
?
thread_k_blocks
/
group_blocks
...
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
...
@@ -1131,7 +1133,7 @@ __global__ void Marlin(
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
b_type_id
!=
vllm
::
kFE2M1f
.
id
()
)
{
if
constexpr
(
!
is_8bit_scale
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
else
{
}
else
{
...
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
...
@@ -1140,7 +1142,7 @@ __global__ void Marlin(
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
}
}
}
else
if
(
group_blocks
>=
b_sh_wr_iters
)
{
}
else
if
(
group_blocks
>=
b_sh_wr_iters
)
{
if
constexpr
(
b_type_id
!=
vllm
::
kFE2M1f
.
id
()
)
{
if
constexpr
(
!
is_8bit_scale
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
}
else
{
}
else
{
...
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
...
@@ -1341,7 +1343,7 @@ __global__ void Marlin(
}
}
}
}
if
constexpr
(
b
_type
==
vllm
::
kFE
2M1f
)
{
if
constexpr
(
s
_type
==
vllm
::
kFE
4M3fn
||
s_type
==
vllm
::
kFE8M0fnu
)
{
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
...
...
csrc/moe/marlin_moe_wna16/ops.cu
View file @
db5d0719
...
@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
...
@@ -599,6 +599,9 @@ torch::Tensor moe_wna16_marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be"
,
"When b_type = float4_e2m1f, b_scale scalar type must be"
,
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."
);
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."
);
}
}
}
else
if
(
b_type_id
==
vllm
::
kFE4M3fn
.
id
()
&&
b_scales
.
scalar_type
()
==
at
::
ScalarType
::
Float8_e8m0fnu
)
{
s_type_id
=
vllm
::
kFE8M0fnu
.
id
();
}
}
vllm
::
ScalarType
a_type
=
vllm
::
ScalarType
::
from_id
(
a_type_id
);
vllm
::
ScalarType
a_type
=
vllm
::
ScalarType
::
from_id
(
a_type_id
);
...
...
csrc/quantization/marlin/generate_kernels.py
View file @
db5d0719
...
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
...
@@ -108,6 +108,15 @@ QUANT_CONFIGS = [
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
2
],
"group_blocks"
:
[
2
],
},
},
# MXFP8
{
"a_type"
:
[
"kBFloat16"
],
"b_type"
:
"kFE4M3fn"
,
"s_type"
:
"kFE8M0fnu"
,
"thread_configs"
:
THREAD_CONFIGS
,
"thread_m_blocks"
:
THREAD_M_BLOCKS
,
"group_blocks"
:
[
2
],
},
# AWQ-INT4 with INT8 activation
# AWQ-INT4 with INT8 activation
{
{
"a_type"
:
[
"kS8"
],
"a_type"
:
[
"kS8"
],
...
...
csrc/quantization/marlin/marlin.cu
View file @
db5d0719
...
@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
...
@@ -591,6 +591,9 @@ torch::Tensor marlin_gemm(
"When b_type = float4_e2m1f, b_scale scalar type must be"
,
"When b_type = float4_e2m1f, b_scale scalar type must be"
,
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."
);
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4)."
);
}
}
}
else
if
(
b_type_id
==
vllm
::
kFE4M3fn
.
id
()
&&
b_scales
.
scalar_type
()
==
at
::
ScalarType
::
Float8_e8m0fnu
)
{
s_type_id
=
vllm
::
kFE8M0fnu
.
id
();
}
}
vllm
::
ScalarType
a_type
=
vllm
::
ScalarType
::
from_id
(
a_type_id
);
vllm
::
ScalarType
a_type
=
vllm
::
ScalarType
::
from_id
(
a_type_id
);
...
...
csrc/quantization/marlin/marlin_template.h
View file @
db5d0719
...
@@ -327,6 +327,9 @@ __global__ void Marlin(
...
@@ -327,6 +327,9 @@ __global__ void Marlin(
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
)
{
if
constexpr
(
b_type
==
vllm
::
kFE2M1f
)
{
static_assert
(
s_type
==
vllm
::
kFE4M3fn
&&
group_blocks
==
1
||
static_assert
(
s_type
==
vllm
::
kFE4M3fn
&&
group_blocks
==
1
||
s_type
==
vllm
::
kFE8M0fnu
&&
group_blocks
==
2
);
s_type
==
vllm
::
kFE8M0fnu
&&
group_blocks
==
2
);
}
else
if
constexpr
(
s_type
==
vllm
::
kFE8M0fnu
)
{
// MXFP8: FP8 weights with e8m0 microscaling block scales
static_assert
(
b_type
==
vllm
::
kFE4M3fn
&&
group_blocks
==
2
);
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
static_assert
(
s_type
==
vllm
::
kBFloat16
);
static_assert
(
s_type
==
vllm
::
kBFloat16
);
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
...
@@ -334,6 +337,7 @@ __global__ void Marlin(
...
@@ -334,6 +337,7 @@ __global__ void Marlin(
}
}
constexpr
bool
is_a_8bit
=
a_type
.
size_bits
()
==
8
;
constexpr
bool
is_a_8bit
=
a_type
.
size_bits
()
==
8
;
constexpr
bool
is_8bit_scale
=
s_type
.
size_bits
()
==
8
;
if
constexpr
(
!
is_a_8bit
)
{
if
constexpr
(
!
is_a_8bit
)
{
static_assert
(
std
::
is_same
<
scalar_t
,
c_scalar_t
>::
value
);
static_assert
(
std
::
is_same
<
scalar_t
,
c_scalar_t
>::
value
);
}
}
...
@@ -343,7 +347,7 @@ __global__ void Marlin(
...
@@ -343,7 +347,7 @@ __global__ void Marlin(
b_type
==
vllm
::
kU4B8
||
b_type
==
vllm
::
kU8B128
;
b_type
==
vllm
::
kU4B8
||
b_type
==
vllm
::
kU8B128
;
// see comments of dequant.h for more details
// see comments of dequant.h for more details
constexpr
bool
dequant_skip_flop
=
constexpr
bool
dequant_skip_flop
=
is_a_8bit
||
b_type
==
vllm
::
kFE4M3fn
||
is_a_8bit
||
(
b_type
==
vllm
::
kFE4M3fn
&&
!
(
s_type
==
vllm
::
kFE8M0fnu
))
||
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
||
b_type
==
vllm
::
kFE2M1f
&&
s_type
==
vllm
::
kFE4M3fn
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
||
has_zp
&&
!
is_zp_float
&&
!
(
b_type
==
vllm
::
kU8
);
has_zp
&&
!
is_zp_float
&&
!
(
b_type
==
vllm
::
kU8
);
...
@@ -555,9 +559,8 @@ __global__ void Marlin(
...
@@ -555,9 +559,8 @@ __global__ void Marlin(
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
constexpr
int
b_sh_wr_iters
=
b_sh_stage
/
b_sh_wr_delta
;
// Scale sizes/strides without act_order
// Scale sizes/strides without act_order
int
s_gl_stride
=
prob_n
/
(
b_type
==
vllm
::
kFE2M1f
?
16
:
8
);
int
s_gl_stride
=
prob_n
/
(
is_8bit_scale
?
16
:
8
);
constexpr
int
s_sh_stride
=
constexpr
int
s_sh_stride
=
16
*
thread_n_blocks
/
(
is_8bit_scale
?
16
:
8
);
16
*
thread_n_blocks
/
(
b_type
==
vllm
::
kFE2M1f
?
16
:
8
);
constexpr
int
s_tb_groups
=
constexpr
int
s_tb_groups
=
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
!
has_act_order
&&
group_blocks
!=
-
1
&&
group_blocks
<
thread_k_blocks
?
thread_k_blocks
/
group_blocks
?
thread_k_blocks
/
group_blocks
...
@@ -997,7 +1000,7 @@ __global__ void Marlin(
...
@@ -997,7 +1000,7 @@ __global__ void Marlin(
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
if
constexpr
(
b_type_id
!=
vllm
::
kFE2M1f
.
id
()
)
{
if
constexpr
(
!
is_8bit_scale
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
sh_s_stage
[
s_sh_rd
+
cur_group_id
*
s_sh_stride
];
}
else
{
}
else
{
...
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
...
@@ -1006,7 +1009,7 @@ __global__ void Marlin(
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
sh_s_stage
)[
s_sh_rd
+
cur_group_id
*
(
2
*
s_sh_stride
)];
}
}
}
else
if
(
group_blocks
>=
b_sh_wr_iters
)
{
}
else
if
(
group_blocks
>=
b_sh_wr_iters
)
{
if
constexpr
(
b_type_id
!=
vllm
::
kFE2M1f
.
id
()
)
{
if
constexpr
(
!
is_8bit_scale
)
{
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
1
])[
0
]
=
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
reinterpret_cast
<
int4
*>
(
&
frag_s
[
0
])[
0
];
}
else
{
}
else
{
...
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
...
@@ -1207,7 +1210,7 @@ __global__ void Marlin(
}
}
}
}
if
constexpr
(
b
_type
==
vllm
::
kFE
2M1f
)
{
if
constexpr
(
s
_type
==
vllm
::
kFE
4M3fn
||
s_type
==
vllm
::
kFE8M0fnu
)
{
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_0
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
0
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
int
s_quant_1
=
reinterpret_cast
<
int
*>
(
frag_s
[
k2
])[
1
];
...
...
tests/kernels/moe/test_moe.py
View file @
db5d0719
...
@@ -151,6 +151,12 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [
...
@@ -151,6 +151,12 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [
"b_type"
:
scalar_types
.
float4_e2m1f
,
"b_type"
:
scalar_types
.
float4_e2m1f
,
"group_blocks"
:
[
2
],
"group_blocks"
:
[
2
],
},
},
# MXFP8
{
"a_type"
:
[
scalar_types
.
bfloat16
],
"b_type"
:
scalar_types
.
float8_e4m3fn
,
"group_blocks"
:
[
2
],
},
# AWQ-INT4 with INT8 activation
# AWQ-INT4 with INT8 activation
{
{
"a_type"
:
[
scalar_types
.
int8
],
"a_type"
:
[
scalar_types
.
int8
],
...
...
tests/models/quantization/test_mxfp8.py
View file @
db5d0719
...
@@ -23,7 +23,7 @@ from tests.quantization.utils import is_quant_method_supported
...
@@ -23,7 +23,7 @@ from tests.quantization.utils import is_quant_method_supported
from
..utils
import
check_logprobs_close
from
..utils
import
check_logprobs_close
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
# A small MoE model that fits on a single GPU and has both linear + MoE layers.
MOE_MODEL
=
"
Qwen/Qwen3-30B-A3B
"
MOE_MODEL
=
"
allenai/OLMoE-1B-7B-0125-Instruct
"
# A small dense model (no MoE) to validate the linear-only path.
# A small dense model (no MoE) to validate the linear-only path.
DENSE_MODEL
=
"Qwen/Qwen3-0.6B"
DENSE_MODEL
=
"Qwen/Qwen3-0.6B"
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
db5d0719
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticChannelSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kMxfp4Static
,
kMxfp4Static
,
kMxfp8Static
,
kNvfp4Static
,
kNvfp4Static
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -582,6 +583,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
...
@@ -582,6 +583,7 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
kFp8StaticChannelSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kMxfp4Static
,
kMxfp4Static
,
kMxfp8Static
,
kNvfp4Static
,
kNvfp4Static
,
]
]
return
weight_key
in
SUPPORTED_W
return
weight_key
in
SUPPORTED_W
...
@@ -609,7 +611,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
...
@@ -609,7 +611,6 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
@
property
@
property
def
quant_type_id
(
self
)
->
int
:
def
quant_type_id
(
self
)
->
int
:
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
if
self
.
quant_config
.
use_int4_w4a16
:
if
self
.
quant_config
.
use_int4_w4a16
:
return
scalar_types
.
uint4b8
.
id
return
scalar_types
.
uint4b8
.
id
elif
self
.
quant_config
.
use_mxfp4_w4a16
or
self
.
quant_config
.
use_nvfp4_w4a16
:
elif
self
.
quant_config
.
use_mxfp4_w4a16
or
self
.
quant_config
.
use_nvfp4_w4a16
:
...
...
vllm/model_executor/layers/fused_moe/oracle/fp8.py
View file @
db5d0719
...
@@ -436,13 +436,27 @@ def convert_to_fp8_moe_kernel_format(
...
@@ -436,13 +436,27 @@ def convert_to_fp8_moe_kernel_format(
elif
fp8_backend
==
Fp8MoeBackend
.
AITER
:
elif
fp8_backend
==
Fp8MoeBackend
.
AITER
:
w13
,
w2
=
rocm_aiter_ops
.
shuffle_weights
(
w13
,
w2
)
w13
,
w2
=
rocm_aiter_ops
.
shuffle_weights
(
w13
,
w2
)
elif
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
elif
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
w13
,
w2
,
w13_scale
,
w2_scale
=
prepare_fp8_moe_layer_for_marlin
(
weight_block_size
=
getattr
(
layer
,
"weight_block_size"
,
None
)
layer
,
if
weight_block_size
==
[
1
,
32
]:
w13
,
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
w2
,
prepare_mxfp8_moe_layer_for_marlin
,
w13_scale
,
)
w2_scale
,
)
w13
,
w2
,
w13_scale
,
w2_scale
=
prepare_mxfp8_moe_layer_for_marlin
(
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
)
else
:
w13
,
w2
,
w13_scale
,
w2_scale
=
prepare_fp8_moe_layer_for_marlin
(
layer
,
w13
,
w2
,
w13_scale
,
w2_scale
,
)
elif
fp8_backend
in
[
elif
fp8_backend
in
[
Fp8MoeBackend
.
FLASHINFER_CUTLASS
,
Fp8MoeBackend
.
FLASHINFER_CUTLASS
,
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
...
...
vllm/model_executor/layers/fused_moe/oracle/mxfp8.py
View file @
db5d0719
...
@@ -15,14 +15,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -15,14 +15,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_SUPPORTED_BACKENDS
:
frozenset
[
Fp8MoeBackend
]
=
frozenset
(
_SUPPORTED_BACKENDS
=
(
{
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
Fp8MoeBackend
.
MARLIN
,
}
)
)
_BACKEND_NAME_MAP
:
dict
[
str
,
Fp8MoeBackend
]
=
{
_BACKEND_NAME_MAP
:
dict
[
str
,
Fp8MoeBackend
]
=
{
"flashinfer_trtllm"
:
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
"flashinfer_trtllm"
:
Fp8MoeBackend
.
FLASHINFER_TRTLLM
,
"marlin"
:
Fp8MoeBackend
.
MARLIN
,
}
}
...
@@ -81,7 +81,11 @@ def select_mxfp8_moe_backend(
...
@@ -81,7 +81,11 @@ def select_mxfp8_moe_backend(
# Auto-select: pick the first supported backend.
# Auto-select: pick the first supported backend.
for
backend
in
_SUPPORTED_BACKENDS
:
for
backend
in
_SUPPORTED_BACKENDS
:
try
:
experts_cls
=
_select_kernel_cls
(
backend
,
config
)
except
ValueError
:
continue
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend."
,
backend
.
value
)
logger
.
info_once
(
"Using '%s' MxFp8 MoE backend."
,
backend
.
value
)
return
backend
,
_select_kernel_cls
(
backend
,
config
)
return
backend
,
experts_cls
raise
ValueError
(
"No MXFP8 MoE backends available."
)
raise
ValueError
(
"No MXFP8 MoE backends available."
)
vllm/model_executor/layers/quantization/modelopt.py
View file @
db5d0719
...
@@ -67,10 +67,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
...
@@ -67,10 +67,8 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE
,
MXFP8_BLOCK_SIZE
,
MXFP8_SCALE_DTYPE
,
MXFP8_SCALE_DTYPE
,
MXFP8_VALUE_DTYPE
,
MXFP8_VALUE_DTYPE
,
Mxfp8LinearBackend
,
Mxfp8LinearOp
,
Mxfp8LinearOp
,
mxfp8_e4m3_quantize
,
mxfp8_e4m3_quantize
,
swizzle_mxfp8_scale
,
)
)
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
apply_nvfp4_linear
,
apply_nvfp4_linear
,
...
@@ -1499,8 +1497,8 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
...
@@ -1499,8 +1497,8 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
# M
XFP8 hardware acceleration requires Blackwell (SM100) or newer
# M
arlin kernel supports MXFP8 on SM80+
return
10
0
return
8
0
@
classmethod
@
classmethod
def
override_quantization_method
(
def
override_quantization_method
(
...
@@ -1555,9 +1553,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1555,9 +1553,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
"Dynamic quantization is not supported."
"Dynamic quantization is not supported."
)
)
self
.
backend
:
Mxfp8LinearBackend
=
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
self
.
mxfp8_linear_op
=
Mxfp8LinearOp
()
self
.
mxfp8_linear_op
=
Mxfp8LinearOp
(
backend
=
self
.
backend
)
logger
.
info_once
(
"Using %s backend for MXFP8 GEMM"
,
self
.
backend
.
value
)
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -1615,36 +1611,6 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1615,36 +1611,6 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
)
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
def
_process_weights_after_loading_scale_2d
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Not swizzled - MXFP8 GEMM emulation"""
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
scale_k
=
K
//
MXFP8_BLOCK_SIZE
# Slice weight_scale to match weight dimensions (handles padding)
weight_scale
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
def
_process_weights_after_loading_scale_1d
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Swizzled - MXFP8 GEMM Flashinfer CUTLASS"""
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
# 2D weight scale
weight_scale
=
layer
.
weight_scale
.
data
# Swizzle the weight scales
scale_k
=
K
//
MXFP8_BLOCK_SIZE
weight_scale_2d
=
weight_scale
[:
N
,
:
scale_k
].
contiguous
()
weight_scale_swizzled
=
swizzle_mxfp8_scale
(
weight_scale_2d
,
M
=
N
,
K
=
K
)
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale_swizzled
.
contiguous
(),
requires_grad
=
False
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Validate weight tensor
# Validate weight tensor
if
layer
.
weight
.
ndim
!=
2
:
if
layer
.
weight
.
ndim
!=
2
:
...
@@ -1669,14 +1635,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1669,14 +1635,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
f
" got
{
layer
.
weight_scale
.
dtype
}
"
f
" got
{
layer
.
weight_scale
.
dtype
}
"
)
)
if
self
.
backend
==
Mxfp8LinearBackend
.
EMULATION
:
self
.
mxfp8_linear_op
.
process_weights
(
layer
)
# Swizzled layout is not used
self
.
_process_weights_after_loading_scale_2d
(
layer
)
return
assert
self
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
# Swizzled layout is required for Flashinfer CUTLASS
self
.
_process_weights_after_loading_scale_1d
(
layer
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -1684,22 +1643,15 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
...
@@ -1684,22 +1643,15 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
layer
.
weight
.
dtype
!=
MXFP8_VALUE_DTYPE
:
raise
ValueError
(
f
"Weight dtype
{
layer
.
weight
.
dtype
}
!= expected
{
MXFP8_VALUE_DTYPE
}
"
)
if
layer
.
weight_scale
.
dtype
!=
MXFP8_SCALE_DTYPE
:
raise
ValueError
(
f
"Weight scale dtype
{
layer
.
weight_scale
.
dtype
}
!= "
f
"expected
{
MXFP8_SCALE_DTYPE
}
"
)
return
self
.
mxfp8_linear_op
.
apply
(
return
self
.
mxfp8_linear_op
.
apply
(
input
=
x
,
input
=
x
,
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
bias
=
bias
,
workspace
=
getattr
(
layer
,
"workspace"
,
None
),
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
)
...
...
vllm/model_executor/layers/quantization/mxfp8.py
View file @
db5d0719
...
@@ -34,10 +34,8 @@ from vllm.model_executor.layers.quantization.fp8 import (
...
@@ -34,10 +34,8 @@ from vllm.model_executor.layers.quantization.fp8 import (
)
)
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
MXFP8_BLOCK_SIZE
,
MXFP8_BLOCK_SIZE
,
Mxfp8LinearBackend
,
Mxfp8LinearOp
,
Mxfp8LinearOp
,
mxfp8_e4m3_quantize
,
mxfp8_e4m3_quantize
,
swizzle_mxfp8_scale
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
,
is_layer_skipped
,
...
@@ -71,7 +69,8 @@ class Mxfp8Config(Fp8Config):
...
@@ -71,7 +69,8 @@ class Mxfp8Config(Fp8Config):
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
100
# Marlin kernel supports MXFP8 on SM80+
return
80
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"Mxfp8Config"
:
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"Mxfp8Config"
:
...
@@ -128,24 +127,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
...
@@ -128,24 +127,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
def
__init__
(
self
,
quant_config
:
"Mxfp8Config"
):
def
__init__
(
self
,
quant_config
:
"Mxfp8Config"
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
mxfp8_linear
=
Mxfp8LinearOp
(
self
.
_select_backend
())
self
.
mxfp8_linear
=
Mxfp8LinearOp
()
logger
.
info_once
(
"Using %s backend for MXFP8 GEMM"
,
self
.
mxfp8_linear
.
backend
.
value
)
@
staticmethod
def
_select_backend
()
->
Mxfp8LinearBackend
:
try
:
from
vllm.utils
import
flashinfer
as
fi
_
=
fi
.
mm_mxfp8
return
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
except
Exception
:
logger
.
warning
(
"FlashInfer mm_mxfp8 not available, "
"falling back to MXFP8 emulation backend."
)
return
Mxfp8LinearBackend
.
EMULATION
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -180,14 +162,12 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
...
@@ -180,14 +162,12 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
weight_fp8
,
weight_scale
=
mxfp8_e4m3_quantize
(
layer
.
weight
.
contiguous
())
weight_fp8
,
weight_scale
=
mxfp8_e4m3_quantize
(
layer
.
weight
.
contiguous
())
if
self
.
mxfp8_linear
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
:
N
,
K
=
layer
.
weight
.
shape
[
0
],
layer
.
weight
.
shape
[
1
]
weight_scale
=
swizzle_mxfp8_scale
(
weight_scale
,
N
,
K
)
layer
.
input_scale
=
None
layer
.
input_scale
=
None
replace_parameter
(
layer
,
"weight"
,
weight_fp8
.
data
)
replace_parameter
(
layer
,
"weight"
,
weight_fp8
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
self
.
mxfp8_linear
.
process_weights
(
layer
)
layer
.
_already_called_process_weights_after_loading
=
True
layer
.
_already_called_process_weights_after_loading
=
True
def
apply
(
def
apply
(
...
@@ -202,6 +182,9 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
...
@@ -202,6 +182,9 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
out_dtype
=
self
.
out_dtype
,
bias
=
bias
,
bias
=
bias
,
workspace
=
getattr
(
layer
,
"workspace"
,
None
),
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
)
)
...
@@ -255,17 +238,24 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
...
@@ -255,17 +238,24 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
self
,
weight
:
torch
.
Tensor
self
,
weight
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales)."""
"""Batch quantization: bf16/fp16 weights -> MXFP8 (fp8 + uint8 scales)."""
num_batches
=
weight
.
size
(
0
)
E
=
weight
.
size
(
0
)
w_quant
=
[]
first_q
,
first_s
=
mxfp8_e4m3_quantize
(
weight
[
0
],
is_sf_swizzled_layout
=
False
)
w_scales
=
[]
# Pre-allocate the output tensors rather than stacking.
for
i
in
range
(
num_batches
):
# This is important for consistent memory layout.
mx_fp8_quant
,
mx_fp8_scale
=
mxfp8_e4m3_quantize
(
w_quant
=
torch
.
empty
(
(
E
,
*
first_q
.
shape
),
dtype
=
first_q
.
dtype
,
device
=
weight
.
device
)
w_scales
=
torch
.
empty
(
(
E
,
*
first_s
.
shape
),
dtype
=
first_s
.
dtype
,
device
=
weight
.
device
)
w_quant
[
0
]
=
first_q
w_scales
[
0
]
=
first_s
for
i
in
range
(
1
,
E
):
w_quant
[
i
],
w_scales
[
i
]
=
mxfp8_e4m3_quantize
(
weight
[
i
],
is_sf_swizzled_layout
=
False
weight
[
i
],
is_sf_swizzled_layout
=
False
)
)
w_quant
.
append
(
mx_fp8_quant
)
w_scales
.
append
(
mx_fp8_scale
)
return
torch
.
stack
(
w_quant
)
,
torch
.
stack
(
w_scales
)
return
w_quant
,
w_scales
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
db5d0719
...
@@ -336,6 +336,202 @@ def pack_fp8_to_int32(
...
@@ -336,6 +336,202 @@ def pack_fp8_to_int32(
return
int32_tensor
.
T
.
contiguous
()
if
size_k_first
else
int32_tensor
return
int32_tensor
.
T
.
contiguous
()
if
size_k_first
else
int32_tensor
def mxfp8_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor:
    """Shuffle scales into the e8m0 kernel order and cast to float8_e8m0fnu.

    The input is viewed as consecutive groups of four values; inside every
    group the two middle entries trade places (order ``0, 2, 1, 3``). The
    result keeps the leading dimension of the input.

    Args:
        marlin_scales: Scale tensor whose element count is a multiple of 4.

    Returns:
        The reordered scales converted to ``torch.float8_e8m0fnu``.
    """
    num_rows = marlin_scales.size(0)
    # Swap the middle pair of each quad to fit the layout of fp8 dequantization.
    quads = marlin_scales.view(-1, 4)
    reordered = quads[:, [0, 2, 1, 3]]
    return reordered.view(num_rows, -1).to(torch.float8_e8m0fnu)
def apply_mxfp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: torch.Tensor | None = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run ``ops.marlin_gemm`` for a linear layer with float8_e4m3fn weights.

    Args:
        input: Activations; every leading dimension is folded into the row
            dimension before the GEMM and restored afterwards.
        weight: Quantized weight passed through as ``b_q_weight``.
        weight_scale: Weight scales passed through as ``b_scales``.
        workspace: Marlin workspace tensor.
        size_n: Output feature count (last dimension of the result).
        size_k: Reduction dimension handed to the kernel.
        bias: Optional bias passed through as ``b_bias``.
        use_fp32_reduce: Forwarded to the kernel unchanged.

    Returns:
        Tensor shaped ``input.shape[:-1] + (size_n,)``.
    """
    # Fold all leading dimensions into one so the kernel sees a 2D matrix.
    x_2d = input.reshape(-1, input.shape[-1])
    num_rows = x_2d.size(0)
    result_shape = input.shape[:-1] + (size_n,)

    atomic_add = should_use_atomic_add_reduce(
        m=num_rows,
        n=size_n,
        k=size_k,
        device=input.device,
        dtype=input.dtype,
    )

    gemm_out = ops.marlin_gemm(
        a=x_2d,
        c=None,
        b_q_weight=weight,
        b_bias=bias,
        b_scales=weight_scale,
        a_scales=None,
        global_scale=None,
        b_zeros=None,
        g_idx=None,
        perm=None,
        workspace=workspace,
        b_q_type=scalar_types.float8_e4m3fn,
        size_m=num_rows,
        size_n=size_n,
        size_k=size_k,
        use_atomic_add=atomic_add,
        use_fp32_reduce=use_fp32_reduce,
    )
    return gemm_out.reshape(result_shape)
def prepare_mxfp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Convert a layer's MXFP8 weight, scales and bias to the Marlin layout in place.

    Expects the layer to have:
        - weight: [N, K] float8_e4m3fn
        - weight_scale: [N, K//32] uint8 (e8m0 encoded)
        - input_size_per_partition / output_size_per_partition

    Side effects on ``layer``: ``workspace`` is (re)assigned, and the
    ``weight`` and ``weight_scale`` parameters (plus ``bias`` when present
    and not None) are replaced via ``replace_parameter``.
    """
    mx_block = 32  # MX standard block size
    n_part = layer.output_size_per_partition
    k_part = layer.input_size_per_partition
    dev = layer.weight.device

    # Workspace
    layer.workspace = marlin_make_workspace_new(dev)

    # Weight: pack fp8 into int32, transpose, then repack for Marlin.
    # An empty permutation tensor is passed to the repack op.
    no_perm = torch.empty(0, dtype=torch.int, device=dev)
    packed = pack_fp8_to_int32(layer.weight, size_k_first=False).T.contiguous()
    replace_parameter(
        layer,
        "weight",
        ops.gptq_marlin_repack(
            b_q_weight=packed,
            perm=no_perm,
            size_k=k_part,
            size_n=n_part,
            num_bits=8,
        ),
    )

    # Weight scales: uint8 -> e8m0fnu -> default dtype, transposed from
    # [N, K//32] to [K//32, N] as required by marlin_permute_scales.
    # NOTE(review): the round trip through the default dtype is presumably
    # lossless only if that dtype spans the e8m0 exponent range (e.g.
    # bfloat16) - confirm behaviour when the default dtype is float16.
    interim_dtype = torch.get_default_dtype()
    raw_scales = layer.weight_scale.data[:n_part, : k_part // mx_block].contiguous()
    float_scales = raw_scales.view(torch.float8_e8m0fnu).to(interim_dtype)
    float_scales = float_scales.T.contiguous()

    permuted = marlin_permute_scales(
        s=float_scales,
        size_k=k_part,
        size_n=n_part,
        group_size=mx_block,
    )
    # Final reorder for the e8m0 kernel layout; converts back to e8m0fnu.
    replace_parameter(layer, "weight_scale", mxfp8_marlin_process_scales(permuted))

    # Bias
    if getattr(layer, "bias", None) is not None:
        assert layer.bias.shape == (n_part,)
        replace_parameter(layer, "bias", marlin_permute_bias(layer.bias))
def prepare_mxfp8_moe_layer_for_marlin(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert MXFP8 MoE expert weights and scales to the Marlin layout.

    Args:
        layer: MoE layer; a Marlin workspace is attached as ``layer.workspace``.
        w13: [E, 2*N, K] float8_e4m3fn weights.
        w2: [E, K, N] float8_e4m3fn weights.
        w13_scale: [E, 2*N, K//32] uint8 e8m0 scales.
        w2_scale: [E, K, N//32] uint8 e8m0 scales.

    Returns:
        (w13, w2, w13_scale, w2_scale) in Marlin format, each stacked along
        a leading expert dimension.
    """
    group_size = 32
    num_experts = w13.shape[0]
    w13_n = w13.shape[1]
    k = w13.shape[2]
    n = w2.shape[2]
    device = w13.device
    # Scales are widened to the default dtype only for the permutation step.
    interim_dtype = torch.get_default_dtype()
    no_perm = torch.empty(0, dtype=torch.int, device=device)

    layer.workspace = marlin_make_workspace_new(device, 4)

    # (size_n, size_k) as seen by the Marlin kernels for each projection.
    w13_dims = (w13_n, k)
    w2_dims = (k, n)

    def _repack_experts(weight: torch.Tensor, size_n: int, size_k: int) -> torch.Tensor:
        # Pack, transpose and Marlin-repack every expert, then stack them.
        assert weight.shape == (num_experts, size_n, size_k)
        repacked = []
        for idx in range(num_experts):
            packed = pack_fp8_to_int32(weight[idx], size_k_first=False)
            packed = packed.T.contiguous()
            repacked.append(
                ops.gptq_marlin_repack(
                    b_q_weight=packed,
                    perm=no_perm,
                    size_k=size_k,
                    size_n=size_n,
                    num_bits=8,
                )
            )
        return torch.stack(repacked, dim=0)

    w13 = _repack_experts(w13, *w13_dims)
    w2 = _repack_experts(w2, *w2_dims)

    def _permute_experts(scales: torch.Tensor, size_n: int, size_k: int) -> torch.Tensor:
        # Per expert: crop, reinterpret as e8m0, widen, transpose, permute,
        # then apply the e8m0 kernel reorder.
        processed = []
        for idx in range(num_experts):
            cropped = scales[idx][:size_n, : size_k // group_size].contiguous()
            widened = cropped.view(torch.float8_e8m0fnu).to(interim_dtype)
            widened = widened.T.contiguous()
            permuted = marlin_permute_scales(
                s=widened,
                size_k=size_k,
                size_n=size_n,
                group_size=group_size,
            )
            processed.append(mxfp8_marlin_process_scales(permuted))
        return torch.stack(processed, dim=0)

    w13_scale = _permute_experts(w13_scale, *w13_dims)
    w2_scale = _permute_experts(w2_scale, *w2_dims)

    return w13, w2, w13_scale, w2_scale
def
marlin_quant_fp8_torch
(
weight
,
group_size
,
input_dtype
=
None
):
def
marlin_quant_fp8_torch
(
weight
,
group_size
,
input_dtype
=
None
):
is_a_8bit
=
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
is_a_8bit
=
input_dtype
is
not
None
and
input_dtype
.
itemsize
==
1
if
is_a_8bit
:
if
is_a_8bit
:
...
...
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
View file @
db5d0719
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
enum
import
Enum
from
enum
import
Enum
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
flashinfer
as
vllm_flashinfer
from
vllm.utils
import
flashinfer
as
vllm_flashinfer
...
@@ -15,6 +16,7 @@ logger = init_logger(__name__)
...
@@ -15,6 +16,7 @@ logger = init_logger(__name__)
class
Mxfp8LinearBackend
(
Enum
):
class
Mxfp8LinearBackend
(
Enum
):
EMULATION
=
"emulation"
EMULATION
=
"emulation"
FLASHINFER_CUTLASS
=
"flashinfer-cutlass"
FLASHINFER_CUTLASS
=
"flashinfer-cutlass"
MARLIN
=
"marlin"
# MXFP8 constants
# MXFP8 constants
...
@@ -23,6 +25,28 @@ MXFP8_SCALE_DTYPE = torch.uint8
...
@@ -23,6 +25,28 @@ MXFP8_SCALE_DTYPE = torch.uint8
MXFP8_BLOCK_SIZE
=
32
MXFP8_BLOCK_SIZE
=
32
def select_mxfp8_linear_backend() -> Mxfp8LinearBackend:
    """Pick the MXFP8 linear backend for the current device.

    Preference order:
        1. FLASHINFER_CUTLASS when the platform reports device capability
           100 or higher.
        2. MARLIN when ``is_fp8_marlin_supported()`` returns true.
        3. EMULATION otherwise.
    """
    from vllm.platforms import current_platform

    if not current_platform.has_device_capability(100):
        # Imported lazily, only when the Marlin path is actually considered.
        from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
            is_fp8_marlin_supported,
        )

        marlin_ok = is_fp8_marlin_supported()
        return (
            Mxfp8LinearBackend.MARLIN if marlin_ok else Mxfp8LinearBackend.EMULATION
        )

    return Mxfp8LinearBackend.FLASHINFER_CUTLASS
def
swizzle_mxfp8_scale
(
sf
:
torch
.
Tensor
,
M
:
int
,
K
:
int
)
->
torch
.
Tensor
:
def
swizzle_mxfp8_scale
(
sf
:
torch
.
Tensor
,
M
:
int
,
K
:
int
)
->
torch
.
Tensor
:
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
scaling_vector_size
=
MXFP8_BLOCK_SIZE
# 32 for MXFP8
scaling_vector_size
=
MXFP8_BLOCK_SIZE
# 32 for MXFP8
...
@@ -47,17 +71,71 @@ def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
...
@@ -47,17 +71,71 @@ def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
return
sf_swizzled
.
contiguous
().
view
(
-
1
)
return
sf_swizzled
.
contiguous
().
view
(
-
1
)
def _mxfp8_e4m3_quantize_torch(
    x: torch.Tensor,
    is_sf_swizzled_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Naive pure-torch MXFP8 quantization.

    For each block of ``MXFP8_BLOCK_SIZE`` elements along the last dimension,
    a shared e8m0 scale is computed as the biased exponent of the block-wise
    amax (``floor(log2(amax)) + 127``, clamped to ``[0, 254]``), and every
    element is divided by ``2 ** (scale - 127)`` before the cast to
    ``MXFP8_VALUE_DTYPE``.

    Args:
        x: Input tensor; its last dimension must be a multiple of
            ``MXFP8_BLOCK_SIZE``. Non-contiguous inputs are accepted.
        is_sf_swizzled_layout: When true (and ``x`` is 2D or 3D), the scales
            are passed through ``swizzle_mxfp8_scale``.

    Returns:
        ``(quantized_values, scales)`` where the values keep the shape of
        ``x`` and the scales are uint8:
            False -> [..., K//32] (row-major)
            True  -> flat swizzled 1D (3D inputs: per-batch results
                     concatenated)

    Raises:
        AssertionError: If the last dimension is not a multiple of the
            block size.
    """
    orig_shape = x.shape
    last_dim = orig_shape[-1]
    assert last_dim % MXFP8_BLOCK_SIZE == 0, (
        f"last dimension ({last_dim}) must be a multiple of {MXFP8_BLOCK_SIZE}"
    )
    num_blocks = last_dim // MXFP8_BLOCK_SIZE

    x_fp32 = x.to(torch.float32)
    # reshape rather than view: .to() keeps the input strides (or returns the
    # same tensor), so view() would raise on non-contiguous inputs.
    x_blocked = x_fp32.reshape(*orig_shape[:-1], num_blocks, MXFP8_BLOCK_SIZE)

    amax = x_blocked.abs().amax(dim=-1)
    # Avoid log2(0) for all-zero blocks.
    amax = amax.clamp(min=torch.finfo(torch.float32).tiny)

    scale_biased = torch.floor(torch.log2(amax)) + 127.0
    scale_biased = scale_biased.clamp(0, 254)
    scales_uint8 = scale_biased.to(torch.uint8)

    descale = torch.exp2(scale_biased - 127.0)
    x_scaled = x_blocked / descale.unsqueeze(-1)
    x_fp8 = x_scaled.reshape(orig_shape).to(MXFP8_VALUE_DTYPE)

    if x.ndim == 2:
        M, K = x.shape
        scales_uint8 = scales_uint8.reshape(M, -1)
        if is_sf_swizzled_layout:
            scales_uint8 = swizzle_mxfp8_scale(scales_uint8, M=M, K=K)
    elif x.ndim == 3:
        B, M, K = x.shape
        scales_uint8 = scales_uint8.reshape(B, M, -1)
        if is_sf_swizzled_layout:
            # Swizzle each batch entry independently, then concatenate.
            scales_uint8 = torch.cat(
                [swizzle_mxfp8_scale(scales_uint8[i], M=M, K=K) for i in range(B)]
            )

    return x_fp8, scales_uint8
def
_mxfp8_e4m3_quantize_impl
(
def
_mxfp8_e4m3_quantize_impl
(
x
:
torch
.
Tensor
,
is_sf_swizzled_layout
:
bool
=
False
x
:
torch
.
Tensor
,
is_sf_swizzled_layout
:
bool
=
False
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
flashinfer
import
mxfp8_quantize
as
flashinfer_mxfp8_quantize
from
vllm.platforms
import
current_platform
x_q
,
x_scales
=
flashinfer_mxfp8_quantize
(
if
current_platform
.
has_device_capability
(
100
):
x
,
is_sf_swizzled_layout
=
is_sf_swizzled_layout
from
flashinfer
import
mxfp8_quantize
as
flashinfer_mxfp8_quantize
)
if
x_scales
.
ndim
==
1
and
x
.
ndim
==
2
and
not
is_sf_swizzled_layout
:
x_q
,
x_scales
=
flashinfer_mxfp8_quantize
(
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
)
x
,
is_sf_swizzled_layout
=
is_sf_swizzled_layout
return
x_q
,
x_scales
)
if
x_scales
.
ndim
==
1
and
x
.
ndim
==
2
and
not
is_sf_swizzled_layout
:
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
)
return
x_q
,
x_scales
return
_mxfp8_e4m3_quantize_torch
(
x
,
is_sf_swizzled_layout
)
def
mxfp8_e4m3_quantize
(
def
mxfp8_e4m3_quantize
(
...
@@ -128,11 +206,51 @@ direct_register_custom_op(
...
@@ -128,11 +206,51 @@ direct_register_custom_op(
class
Mxfp8LinearOp
:
class
Mxfp8LinearOp
:
def
__init__
(
self
,
backend
:
Mxfp8LinearBackend
):
def
__init__
(
self
):
if
backend
not
in
Mxfp8LinearBackend
:
self
.
backend
=
select_mxfp8_linear_backend
()
raise
ValueError
(
f
"Unsupported backend:
{
backend
}
"
)
logger
.
info_once
(
"Using %s backend for MXFP8 GEMM"
,
self
.
backend
)
def process_weights(self, layer: torch.nn.Module) -> None:
    """Convert the loaded MXFP8 weights to the format of ``self.backend``.

    MARLIN and FLASHINFER_CUTLASS have dedicated handlers; any other
    backend value falls through to the emulation handler.
    """
    backend = self.backend
    if backend == Mxfp8LinearBackend.MARLIN:
        self._process_weights_marlin(layer)
        return
    if backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
        self._process_weights_flashinfer_cutlass(layer)
        return
    self._process_weights_emulation(layer)
def _process_weights_emulation(self, layer: torch.nn.Module) -> None:
    """Prepare weights for the emulation backend.

    The scales stay a 2D uint8 tensor cropped to ``[N, K // MXFP8_BLOCK_SIZE]``;
    both weight and scales are re-wrapped as contiguous, non-trainable
    parameters on ``layer``.
    """
    w = layer.weight.data  # [N, K]
    rows, cols = w.shape
    blocks_per_row = cols // MXFP8_BLOCK_SIZE

    cropped_scale = layer.weight_scale.data[:rows, :blocks_per_row].contiguous()

    layer.weight = Parameter(w.contiguous(), requires_grad=False)
    layer.weight_scale = Parameter(cropped_scale, requires_grad=False)
def
_process_weights_flashinfer_cutlass
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Swizzle scales to F8_128x4 layout for flashinfer CUTLASS."""
weight
=
layer
.
weight
.
data
# [N, K]
N
,
K
=
weight
.
shape
self
.
backend
=
backend
scale_k
=
K
//
MXFP8_BLOCK_SIZE
weight_scale_2d
=
layer
.
weight_scale
.
data
[:
N
,
:
scale_k
].
contiguous
()
weight_scale_swizzled
=
swizzle_mxfp8_scale
(
weight_scale_2d
,
M
=
N
,
K
=
K
)
layer
.
weight
=
Parameter
(
weight
.
contiguous
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale_swizzled
.
contiguous
(),
requires_grad
=
False
)
def _process_weights_marlin(self, layer: torch.nn.Module) -> None:
    """Delegate to ``prepare_mxfp8_layer_for_marlin`` to repack ``layer`` for Marlin."""
    # Imported inside the method, as in the rest of this class's Marlin paths.
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        prepare_mxfp8_layer_for_marlin as _prepare_for_marlin,
    )

    _prepare_for_marlin(layer)
def
_apply_emulation
(
def
_apply_emulation
(
self
,
self
,
...
@@ -142,7 +260,6 @@ class Mxfp8LinearOp:
...
@@ -142,7 +260,6 @@ class Mxfp8LinearOp:
out_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Validate weight_scale dtype and shape (must be 2D for TORCH backend)
if
weight_scale
.
dtype
!=
MXFP8_SCALE_DTYPE
:
if
weight_scale
.
dtype
!=
MXFP8_SCALE_DTYPE
:
raise
ValueError
(
raise
ValueError
(
f
"TORCH backend requires
{
MXFP8_SCALE_DTYPE
}
weight_scale dtype, "
f
"TORCH backend requires
{
MXFP8_SCALE_DTYPE
}
weight_scale dtype, "
...
@@ -219,6 +336,32 @@ class Mxfp8LinearOp:
...
@@ -219,6 +336,32 @@ class Mxfp8LinearOp:
output_shape
=
(
*
input_shape
[:
-
1
],
N
)
output_shape
=
(
*
input_shape
[:
-
1
],
N
)
return
output
.
view
(
output_shape
)
return
output
.
view
(
output_shape
)
def _apply_marlin(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
    *,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """Forward the call to ``apply_mxfp8_marlin_linear``.

    ``out_dtype`` is accepted for signature parity with the other
    ``_apply_*`` methods but is not passed to the Marlin helper.
    """
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_mxfp8_marlin_linear as _marlin_linear,
    )

    marlin_kwargs = {
        "input": input,
        "weight": weight,
        "weight_scale": weight_scale,
        "workspace": workspace,
        "size_n": size_n,
        "size_k": size_k,
        "bias": bias,
    }
    return _marlin_linear(**marlin_kwargs)
def
apply
(
def
apply
(
self
,
self
,
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
...
@@ -226,10 +369,27 @@ class Mxfp8LinearOp:
...
@@ -226,10 +369,27 @@ class Mxfp8LinearOp:
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
*
,
workspace
:
torch
.
Tensor
|
None
=
None
,
size_n
:
int
=
0
,
size_k
:
int
=
0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
backend
==
Mxfp8LinearBackend
.
EMULATION
:
if
self
.
backend
==
Mxfp8LinearBackend
.
EMULATION
:
return
self
.
_apply_emulation
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
)
return
self
.
_apply_emulation
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
)
if
self
.
backend
==
Mxfp8LinearBackend
.
MARLIN
:
assert
workspace
is
not
None
return
self
.
_apply_marlin
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
,
workspace
=
workspace
,
size_n
=
size_n
,
size_k
=
size_k
,
)
assert
self
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
assert
self
.
backend
==
Mxfp8LinearBackend
.
FLASHINFER_CUTLASS
return
self
.
_apply_flashinfer_cutlass
(
return
self
.
_apply_flashinfer_cutlass
(
input
,
weight
,
weight_scale
,
out_dtype
,
bias
input
,
weight
,
weight_scale
,
out_dtype
,
bias
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment