Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b471092d
Unverified
Commit
b471092d
authored
Dec 21, 2025
by
Robert Shaw
Committed by
GitHub
Dec 21, 2025
Browse files
[MoE Refactor][4/N] Marlin Fp8 Mk (#31036)
parent
93cabc41
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
85 additions
and
63 deletions
+85
-63
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+4
-0
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+30
-2
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+17
-21
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+34
-40
No files found.
tests/quantization/test_fp8.py
View file @
b471092d
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
...
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config
,
Fp8Config
,
Fp8KVCacheMethod
,
Fp8KVCacheMethod
,
Fp8LinearMethod
,
Fp8LinearMethod
,
Fp8MoeBackend
,
Fp8MoEMethod
,
Fp8MoEMethod
,
)
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
@@ -324,7 +325,10 @@ def test_fp8_reloading(
...
@@ -324,7 +325,10 @@ def test_fp8_reloading(
weight_loader
=
default_weight_loader
,
weight_loader
=
default_weight_loader
,
)
)
# Fp8LinearMethod uses use_marlin
# Fp8MoEMethod uses fp8_backend
method
.
use_marlin
=
use_marlin
method
.
use_marlin
=
use_marlin
method
.
fp8_backend
=
Fp8MoeBackend
.
MARLIN
if
use_marlin
else
None
# capture weights format during loading
# capture weights format during loading
original_metadata
=
[
original_metadata
=
[
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
b471092d
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_Scheme
,
OCP_MX_Scheme
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.import_utils
import
has_triton_kernels
from
vllm.utils.import_utils
import
has_triton_kernels
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
...
@@ -39,6 +40,7 @@ if has_triton_kernels():
...
@@ -39,6 +40,7 @@ if has_triton_kernels():
def
_get_config_dtype_str
(
def
_get_config_dtype_str
(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
use_int4_w4a16
:
bool
=
False
,
ocp_mx_scheme
:
str
|
None
=
None
,
ocp_mx_scheme
:
str
|
None
=
None
,
...
@@ -50,6 +52,8 @@ def _get_config_dtype_str(
...
@@ -50,6 +52,8 @@ def _get_config_dtype_str(
"""
"""
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
return
"fp8_w8a8"
return
"fp8_w8a8"
elif
use_fp8_w8a16
:
return
"fp8_w8a16"
elif
use_int8_w8a16
:
elif
use_int8_w8a16
:
return
"int8_w8a16"
return
"int8_w8a16"
elif
use_int4_w4a16
:
elif
use_int4_w4a16
:
...
@@ -319,6 +323,10 @@ class FusedMoEQuantConfig:
...
@@ -319,6 +323,10 @@ class FusedMoEQuantConfig:
def
use_int8_w8a16
(
self
)
->
bool
:
def
use_int8_w8a16
(
self
)
->
bool
:
return
self
.
_a1
.
dtype
is
None
and
self
.
_w1
.
dtype
==
torch
.
int8
return
self
.
_a1
.
dtype
is
None
and
self
.
_w1
.
dtype
==
torch
.
int8
@
property
def
use_fp8_w8a16
(
self
)
->
bool
:
return
self
.
_a1
.
dtype
is
None
and
self
.
_w1
.
dtype
==
current_platform
.
fp8_dtype
()
@
property
@
property
def
use_int4_w4a16
(
self
)
->
bool
:
def
use_int4_w4a16
(
self
)
->
bool
:
return
self
.
_a1
.
dtype
is
None
and
self
.
_w1
.
dtype
==
"int4"
return
self
.
_a1
.
dtype
is
None
and
self
.
_w1
.
dtype
==
"int4"
...
@@ -362,6 +370,7 @@ class FusedMoEQuantConfig:
...
@@ -362,6 +370,7 @@ class FusedMoEQuantConfig:
"""
"""
return
_get_config_dtype_str
(
return
_get_config_dtype_str
(
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_fp8_w8a8
=
self
.
use_fp8_w8a8
,
use_fp8_w8a16
=
self
.
use_fp8_w8a16
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int8_w8a16
=
self
.
use_int8_w8a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
use_int4_w4a16
=
self
.
use_int4_w4a16
,
ocp_mx_scheme
=
self
.
ocp_mx_scheme
,
ocp_mx_scheme
=
self
.
ocp_mx_scheme
,
...
@@ -680,7 +689,6 @@ def int4_w4a16_moe_quant_config(
...
@@ -680,7 +689,6 @@ def int4_w4a16_moe_quant_config(
)
->
FusedMoEQuantConfig
:
)
->
FusedMoEQuantConfig
:
"""
"""
Construct a quant config for 16-bit float activations and int4 weights.
Construct a quant config for 16-bit float activations and int4 weights.
Note: Activations are pre-quantized.
"""
"""
group_shape
=
GroupShape
(
*
block_shape
)
if
block_shape
is
not
None
else
None
group_shape
=
GroupShape
(
*
block_shape
)
if
block_shape
is
not
None
else
None
return
FusedMoEQuantConfig
(
return
FusedMoEQuantConfig
(
...
@@ -691,6 +699,27 @@ def int4_w4a16_moe_quant_config(
...
@@ -691,6 +699,27 @@ def int4_w4a16_moe_quant_config(
)
)
def
fp8_w8a16_moe_quant_config
(
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
block_shape
:
list
[
int
]
|
None
=
None
,
)
->
FusedMoEQuantConfig
:
"""
Construct a quant config for 16-bit float activations and fp8 weights.
"""
group_shape
=
GroupShape
(
*
block_shape
)
if
block_shape
is
not
None
else
None
return
FusedMoEQuantConfig
(
_a1
=
FusedMoEQuantDesc
(),
_a2
=
FusedMoEQuantDesc
(),
_w1
=
FusedMoEQuantDesc
(
current_platform
.
fp8_dtype
(),
group_shape
,
w1_scale
,
None
,
None
),
_w2
=
FusedMoEQuantDesc
(
current_platform
.
fp8_dtype
(),
group_shape
,
w2_scale
,
None
,
None
),
)
def
int8_w8a16_moe_quant_config
(
def
int8_w8a16_moe_quant_config
(
w1_scale
:
torch
.
Tensor
,
w1_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
w2_scale
:
torch
.
Tensor
,
...
@@ -700,7 +729,6 @@ def int8_w8a16_moe_quant_config(
...
@@ -700,7 +729,6 @@ def int8_w8a16_moe_quant_config(
)
->
FusedMoEQuantConfig
:
)
->
FusedMoEQuantConfig
:
"""
"""
Construct a quant config for 16-bit float activations and int8 weights.
Construct a quant config for 16-bit float activations and int8 weights.
Note: Activations are pre-quantized.
"""
"""
group_shape
=
GroupShape
(
*
block_shape
)
if
block_shape
is
not
None
else
None
group_shape
=
GroupShape
(
*
block_shape
)
if
block_shape
is
not
None
else
None
return
FusedMoEQuantConfig
(
return
FusedMoEQuantConfig
(
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
b471092d
...
@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
...
@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size
,
batched_moe_align_block_size
,
moe_align_block_size
,
moe_align_block_size
,
)
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
TopKWeightAndReduceDelegate
,
TopKWeightAndReduceNoOP
,
TopKWeightAndReduceNoOP
,
...
@@ -26,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -26,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_moe_intermediate_size
,
marlin_moe_intermediate_size
,
marlin_quant_input
,
marlin_quant_input
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
...
@@ -542,9 +540,11 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -542,9 +540,11 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
is_k_full
:
bool
=
True
,
is_k_full
:
bool
=
True
,
):
):
# TODO (varun) : Enable activation quantization
# TODO (varun) : Enable activation quantization
assert
quant_config
.
use_mxfp4_w4a16
or
quant_config
.
use_int4_w4a16
,
(
assert
(
"Supports only mxfp4_w4a16 or int4_w4a16"
quant_config
.
use_mxfp4_w4a16
)
or
quant_config
.
use_int4_w4a16
or
quant_config
.
use_fp8_w8a16
),
"Supports only mxfp4_w4a16, int4_w4a16 or fp8_w8a16"
self
.
w13_g_idx
=
w13_g_idx
self
.
w13_g_idx
=
w13_g_idx
self
.
w2_g_idx
=
w2_g_idx
self
.
w2_g_idx
=
w2_g_idx
self
.
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
self
.
w13_g_idx_sort_indices
=
w13_g_idx_sort_indices
...
@@ -555,11 +555,17 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -555,11 +555,17 @@ class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
@
property
@
property
def
quant_type_id
(
self
)
->
int
:
def
quant_type_id
(
self
)
->
int
:
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
# uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
return
(
if
self
.
quant_config
.
use_int4_w4a16
:
scalar_types
.
uint4b8
.
id
return
scalar_types
.
uint4b8
.
id
if
self
.
quant_config
.
use_int4_w4a16
elif
self
.
quant_config
.
use_mxfp4_w4a16
:
else
scalar_types
.
float4_e2m1f
.
id
return
scalar_types
.
float4_e2m1f
.
id
)
elif
(
self
.
quant_config
.
use_fp8_w8a16
and
current_platform
.
fp8_dtype
()
==
torch
.
float8_e4m3fn
):
return
scalar_types
.
float8_e4m3fn
.
id
else
:
raise
NotImplementedError
(
"Unsupported quantization type."
)
def
moe_problem_size
(
def
moe_problem_size
(
self
,
self
,
...
@@ -711,16 +717,6 @@ class MarlinExperts(MarlinExpertsBase):
...
@@ -711,16 +717,6 @@ class MarlinExperts(MarlinExpertsBase):
ops
.
moe_sum
(
input
,
output
)
ops
.
moe_sum
(
input
,
output
)
def
modular_marlin_fused_moe
(
quant_config
:
FusedMoEQuantConfig
,
shared_experts
:
torch
.
nn
.
Module
|
None
=
None
)
->
mk
.
FusedMoEModularKernel
:
return
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
MarlinExperts
(
quant_config
),
shared_experts
,
)
class
BatchedMarlinExperts
(
MarlinExpertsBase
):
class
BatchedMarlinExperts
(
MarlinExpertsBase
):
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
b471092d
...
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -32,8 +32,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
FusedMoEQuantConfig
,
RoutingMethodType
,
RoutingMethodType
,
fp8_w8a8_moe_quant_config
,
fp8_w8a8_moe_quant_config
,
fp8_w8a16_moe_quant_config
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
fused_marlin_moe
from
vllm.model_executor.layers.fused_moe.layer
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe.layer
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.linear
import
(
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearBase
,
...
@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
...
@@ -95,7 +95,6 @@ from vllm.model_executor.parameter import (
)
)
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.model_executor.utils
import
replace_parameter
,
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.deep_gemm
import
(
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_e8m0_used
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
is_deep_gemm_supported
,
...
@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -729,7 +728,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
)
self
.
marlin_input_dtype
=
None
self
.
marlin_input_dtype
=
None
self
.
use_marlin
=
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
self
.
flashinfer_moe_backend
:
FlashinferMoeBackend
|
None
=
None
self
.
flashinfer_moe_backend
:
FlashinferMoeBackend
|
None
=
None
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
FLASHINFER_TRTLLM
:
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
TENSORRT_LLM
self
.
flashinfer_moe_backend
=
FlashinferMoeBackend
.
TENSORRT_LLM
...
@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1048,7 +1046,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
rotate_flashinfer_fp8_moe_weights
(
w13_weight
,
w2_weight
)
rotate_flashinfer_fp8_moe_weights
(
w13_weight
,
w2_weight
)
layer
.
w13_weight
.
data
=
w13_weight
.
data
layer
.
w13_weight
.
data
=
w13_weight
.
data
if
self
.
use_marlin
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
prepare_moe_fp8_layer_for_marlin
(
prepare_moe_fp8_layer_for_marlin
(
layer
,
False
,
input_dtype
=
self
.
marlin_input_dtype
layer
,
False
,
input_dtype
=
self
.
marlin_input_dtype
)
)
...
@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1091,10 +1089,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
)
self
.
use_inplace
=
False
self
.
use_inplace
=
False
elif
self
.
fp8_backend
in
[
Fp8MoeBackend
.
DEEPGEMM
,
Fp8MoeBackend
.
TRITON
]:
elif
self
.
fp8_backend
in
[
Fp8MoeBackend
.
DEEPGEMM
,
Fp8MoeBackend
.
TRITON
,
Fp8MoeBackend
.
MARLIN
,
]:
from
vllm.model_executor.layers.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe
import
(
TritonOrDeepGemmExperts
,
TritonOrDeepGemmExperts
,
)
)
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
MarlinExperts
,
)
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
from
vllm.model_executor.layers.fused_moe.prepare_finalize
import
(
MoEPrepareAndFinalizeNoEP
,
MoEPrepareAndFinalizeNoEP
,
)
)
...
@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1102,12 +1107,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
config
=
self
.
get_fused_moe_quant_config
(
layer
)
config
=
self
.
get_fused_moe_quant_config
(
layer
)
assert
config
is
not
None
assert
config
is
not
None
self
.
moe_quant_config
=
config
self
.
moe_quant_config
=
config
self
.
kernel
=
mk
.
FusedMoEModularKernel
(
use_marlin
=
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
MoEPrepareAndFinalizeNoEP
(),
allow_deep_gemm
=
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
TritonOrDeepGemmExperts
(
moe_kernel
=
(
MarlinExperts
(
quant_config
=
self
.
moe_quant_config
)
if
use_marlin
else
TritonOrDeepGemmExperts
(
quant_config
=
self
.
moe_quant_config
,
quant_config
=
self
.
moe_quant_config
,
allow_deep_gemm
=
(
self
.
fp8_backend
==
Fp8MoeBackend
.
DEEPGEMM
),
allow_deep_gemm
=
allow_deep_gemm
,
),
)
)
self
.
kernel
=
mk
.
FusedMoEModularKernel
(
MoEPrepareAndFinalizeNoEP
(),
moe_kernel
)
)
self
.
use_inplace
=
True
self
.
use_inplace
=
True
...
@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1116,9 +1128,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
routing_tables
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
,
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
)
->
mk
.
FusedMoEPrepareAndFinalize
|
None
:
if
(
if
(
current_platform
.
is_xpu
()
self
.
rocm_aiter_moe_enabled
or
self
.
rocm_aiter_moe_enabled
or
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
or
self
.
use_marlin
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
or
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
):
):
return
None
return
None
...
@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1150,7 +1161,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
TritonOrDeepGemmExperts
,
TritonOrDeepGemmExperts
,
)
)
assert
not
self
.
use_marlin
and
not
self
.
rocm_aiter_moe_enabled
,
(
assert
(
self
.
fp8_backend
!=
Fp8MoeBackend
.
MARLIN
)
and
not
self
.
rocm_aiter_moe_enabled
,
(
"Marlin and ROCm AITER are not supported with all2all yet."
"Marlin and ROCm AITER are not supported with all2all yet."
)
)
...
@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1207,8 +1220,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def
get_fused_moe_quant_config
(
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
|
None
:
)
->
FusedMoEQuantConfig
|
None
:
if
self
.
use_marlin
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
return
None
return
fp8_w8a16_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
block_shape
=
self
.
weight_block_size
,
)
return
fp8_w8a8_moe_quant_config
(
return
fp8_w8a8_moe_quant_config
(
w1_scale
=
(
w1_scale
=
(
...
@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -1314,29 +1331,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
expert_map
=
layer
.
expert_map
,
expert_map
=
layer
.
expert_map
,
quant_config
=
self
.
moe_quant_config
,
quant_config
=
self
.
moe_quant_config
,
)
)
elif
self
.
use_marlin
:
# TODO(rob): convert this to MK.
assert
layer
.
activation
==
"silu"
,
(
f
"
{
layer
.
activation
}
not supported for Marlin MoE."
)
result
=
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
None
,
None
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_ids
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
input_dtype
=
self
.
marlin_input_dtype
,
workspace
=
layer
.
workspace
,
)
else
:
else
:
result
=
self
.
kernel
(
result
=
self
.
kernel
(
x
,
x
,
...
@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
...
@@ -1495,7 +1489,7 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
replace_parameter
(
layer
,
"w2_weight"
,
shuffled_w2
)
replace_parameter
(
layer
,
"w2_weight"
,
shuffled_w2
)
# Rushuffle weights for MARLIN if needed.
# Rushuffle weights for MARLIN if needed.
if
self
.
use_marlin
:
if
self
.
fp8_backend
==
Fp8MoeBackend
.
MARLIN
:
prepare_moe_fp8_layer_for_marlin
(
prepare_moe_fp8_layer_for_marlin
(
layer
,
False
,
input_dtype
=
self
.
marlin_input_dtype
layer
,
False
,
input_dtype
=
self
.
marlin_input_dtype
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment