Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
082cc07e
Unverified
Commit
082cc07e
authored
Aug 27, 2025
by
Yongye Zhu
Committed by
GitHub
Aug 27, 2025
Browse files
DP/EP Support for gpt-oss with deepep-ht comm kernel on SM100 (#23608)
parent
853c371f
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
365 additions
and
12 deletions
+365
-12
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+1
-1
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+6
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+4
-2
vllm/model_executor/layers/fused_moe/trtllm_moe.py
vllm/model_executor/layers/fused_moe/trtllm_moe.py
+197
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+16
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+4
-4
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-0
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+2
-0
vllm/model_executor/layers/quantization/mxfp4.py
vllm/model_executor/layers/quantization/mxfp4.py
+110
-0
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+4
-5
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
+20
-0
No files found.
vllm/distributed/device_communicators/base_device_communicator.py
View file @
082cc07e
...
...
@@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
if
module
.
__class__
.
__name__
==
"FusedMoE"
]
for
module
in
moe_modules
:
module
.
quant_method
.
init_prepare_finalize
()
module
.
quant_method
.
init_prepare_finalize
(
module
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
082cc07e
...
...
@@ -450,6 +450,12 @@ class FusedMoEConfig:
if
quant_dtype
is
None
and
isinstance
(
quant_config
,
Fp8Config
):
quant_dtype
=
torch
.
float8_e4m3fn
from
vllm.model_executor.layers.quantization.mxfp4
import
(
Mxfp4Config
)
if
(
quant_dtype
is
None
and
isinstance
(
quant_config
,
Mxfp4Config
)
and
envs
.
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
quant_dtype
=
"mxfp8"
from
vllm.model_executor.layers.quantization.modelopt
import
(
ModelOptNvFp4Config
)
if
quant_dtype
is
None
and
isinstance
(
quant_config
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
082cc07e
...
...
@@ -200,7 +200,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# Note: init_prepare_finalize should only be called by
# prepare_communication_buffer_for_model.
def
init_prepare_finalize
(
self
):
def
init_prepare_finalize
(
self
,
layer
:
torch
.
nn
.
Module
):
assert
self
.
moe
is
not
None
prepare_finalize
=
self
.
maybe_make_prepare_finalize
(
self
.
moe
)
...
...
@@ -211,7 +211,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
assert
self
.
fused_experts
is
None
,
\
f
"Attempt to override experts for
{
id
(
self
)
}
!"
self
.
topk_indices_dtype
=
prepare_finalize
.
topk_indices_dtype
()
experts
=
self
.
select_gemm_impl
(
prepare_finalize
,
self
.
moe
)
experts
=
self
.
select_gemm_impl
(
prepare_finalize
,
self
.
moe
,
layer
)
self
.
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
...
...
@@ -221,6 +221,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
# based on the all2all implementation, select the appropriate
# gemm implementation
...
...
@@ -273,6 +274,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
# TODO(bnell): Remove. Every layer should have an moe config object.
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
if
(
prepare_finalize
.
activation_format
==
FusedMoEActivationFormat
.
BatchedExperts
):
...
...
vllm/model_executor/layers/fused_moe/trtllm_moe.py
0 → 100644
View file @
082cc07e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEConfig
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
)
from
vllm.utils
import
next_power_of_2
class
TrtLlmGenExperts
(
mk
.
FusedMoEPermuteExpertsUnpermute
):
def
__init__
(
self
,
moe
:
FusedMoEConfig
,
gemm1_alpha
,
gemm1_beta
,
gemm1_clamp_limit
,
w13_bias
,
w2_bias
,
max_capture_size
,
):
super
().
__init__
(
moe
.
quant_config
)
self
.
moe
=
moe
self
.
gemm1_alpha
=
gemm1_alpha
self
.
gemm1_beta
=
gemm1_beta
self
.
gemm1_clamp_limit
=
gemm1_clamp_limit
self
.
w13_bias
=
w13_bias
self
.
w2_bias
=
w2_bias
self
.
max_capture_size
=
max_capture_size
@
property
def
activation_formats
(
self
)
->
tuple
[
mk
.
FusedMoEActivationFormat
,
mk
.
FusedMoEActivationFormat
]:
return
(
mk
.
FusedMoEActivationFormat
.
Standard
,
mk
.
FusedMoEActivationFormat
.
Standard
)
def
supports_chunking
(
self
)
->
bool
:
return
True
def
supports_expert_map
(
self
)
->
bool
:
return
True
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
return
TopKWeightAndReduceNoOP
()
def
workspace_shapes
(
self
,
a
:
torch
.
Tensor
,
aq
:
torch
.
Tensor
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
# The workspaces for this implementation are managed by flashinfer.
# TODO(varun) : workspace1 is could be used as the output tensor. This
# is error-prone. Allow the `workspace_shapes` to return None workspaces
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,
0
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
,
a
.
dtype
)
def
_get_tile_tokens_dim
(
self
,
x
:
torch
.
Tensor
,
top_k
:
int
,
local_num_experts
:
int
):
# Number of tokens in the input tensor.
num_tokens
=
x
.
shape
[
0
]
# Factor to account for the imbalance of the experts.
# factor equals to the
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
# 1.0 means perfect expert distribution.
# > 1.0 means some experts have more tokens than the perfect
# distribution.
# < 1.0 does not make sense.
imbalance_factor
=
1.3
# Calculate the number of tokens per expert assuming perfect
# distribution.
num_tokens_per_expert
=
(
num_tokens
*
top_k
)
//
local_num_experts
# Apply the imbalance factor.
num_tokens_per_expert
=
int
(
num_tokens_per_expert
*
imbalance_factor
)
# And pad the number to the next power of 2.
tile_tokens_dim
=
next_power_of_2
(
num_tokens_per_expert
)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the
# kernel.
tile_tokens_dim
=
min
(
max
(
tile_tokens_dim
,
8
),
64
)
return
tile_tokens_dim
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
str
,
global_num_experts
:
int
,
expert_map
:
Optional
[
torch
.
Tensor
],
w1_scale
:
Optional
[
torch
.
Tensor
],
w2_scale
:
Optional
[
torch
.
Tensor
],
w1_zp
:
Optional
[
torch
.
Tensor
],
w2_zp
:
Optional
[
torch
.
Tensor
],
a1q_scale
:
Optional
[
torch
.
Tensor
],
a2_scale
:
Optional
[
torch
.
Tensor
],
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
Optional
[
mk
.
ExpertTokensMetadata
],
apply_router_weight_on_input
:
bool
,
):
topk
=
topk_ids
.
size
(
-
1
)
local_num_experts
=
w1
.
size
(
0
)
intermediate_size
=
w2
.
size
(
1
)
local_expert_offset
=
self
.
moe
.
ep_rank
*
local_num_experts
x_quant
=
hidden_states
x_scale
=
a1q_scale
if
x_scale
is
not
None
:
x_scale
=
x_scale
.
view
(
torch
.
float8_e4m3fn
).
reshape
(
*
x_quant
.
shape
[:
-
1
],
-
1
)
packed_tensor
=
(
topk_ids
.
to
(
torch
.
int32
)
<<
16
)
|
topk_weights
.
to
(
torch
.
bfloat16
).
view
(
torch
.
int16
)
assert
w1_scale
is
not
None
assert
w2_scale
is
not
None
kwargs
=
{
"topk_ids"
:
packed_tensor
,
"routing_bias"
:
None
,
"hidden_states"
:
x_quant
,
"hidden_states_scale"
:
x_scale
,
"gemm1_weights"
:
w1
,
"gemm1_weights_scale"
:
w1_scale
,
"gemm1_bias"
:
self
.
w13_bias
,
"gemm1_alpha"
:
self
.
gemm1_alpha
,
"gemm1_beta"
:
self
.
gemm1_beta
,
"gemm1_clamp_limit"
:
self
.
gemm1_clamp_limit
,
"gemm2_weights"
:
w2
,
"gemm2_weights_scale"
:
w2_scale
,
"gemm2_bias"
:
self
.
w2_bias
,
"output1_scale_scalar"
:
None
,
"output1_scale_gate_scalar"
:
None
,
"output2_scale_scalar"
:
None
,
"num_experts"
:
global_num_experts
,
"top_k"
:
topk
,
"n_group"
:
None
,
"topk_group"
:
None
,
"intermediate_size"
:
intermediate_size
,
"local_expert_offset"
:
local_expert_offset
,
"local_num_experts"
:
local_num_experts
,
"routed_scaling_factor"
:
None
,
"tile_tokens_dim"
:
self
.
_get_tile_tokens_dim
(
x_quant
,
topk
,
local_num_experts
),
"routing_method_type"
:
1
,
"do_finalize"
:
True
,
"output"
:
output
,
"tune_max_num_tokens"
:
self
.
max_capture_size
,
}
from
flashinfer
import
trtllm_fp4_block_scale_routed_moe
trtllm_fp4_block_scale_routed_moe
(
**
kwargs
)
return
output
vllm/model_executor/layers/fused_moe/utils.py
View file @
082cc07e
...
...
@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8
,
per_token_quant_int8
)
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
quant_dequant_mxfp4
)
from
vllm.model_executor.layers.quantization.utils.mxfp8_utils
import
(
mxfp8_quantize
)
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
...
...
@@ -177,6 +179,18 @@ def _mxfp4_quantize(
return
A
,
None
def
_mxfp8_quantize
(
A
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
per_act_token_quant
:
bool
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
A_scale
is
None
assert
not
per_act_token_quant
assert
block_shape
is
None
return
mxfp8_quantize
(
A
)
def
moe_kernel_quantize_input
(
A
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
...
...
@@ -195,6 +209,8 @@ def moe_kernel_quantize_input(
is_sf_swizzled_layout
=
is_fp4_scale_swizzled
)
elif
quant_dtype
==
"mxfp4"
:
return
_mxfp4_quantize
(
A
,
A_scale
,
per_act_token_quant
,
block_shape
)
elif
quant_dtype
==
"mxfp8"
:
return
_mxfp8_quantize
(
A
,
A_scale
,
per_act_token_quant
,
block_shape
)
else
:
return
A
,
A_scale
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
082cc07e
...
...
@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
"""Return the appropriate GEMM experts implementation."""
experts
=
select_nvfp4_gemm_impl
(
...
...
@@ -719,10 +720,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
dtype
=
torch
.
int64
)
def
select_gemm_impl
(
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
)
->
FusedMoEPermuteExpertsUnpermute
:
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEPermuteExpertsUnpermute
:
# cutlass path
if
self
.
use_cutlass
:
from
vllm.model_executor.layers.fused_moe
import
(
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
082cc07e
...
...
@@ -897,6 +897,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self
,
prepare_finalize
:
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
FusedMoEPermuteExpertsUnpermute
:
from
vllm.model_executor.layers.fused_moe
import
(
BatchedTritonOrDeepGemmExperts
,
TritonOrDeepGemmExperts
)
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
082cc07e
...
...
@@ -311,6 +311,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
experts
=
select_cutlass_fp8_gemm_impl
(
moe
,
...
...
@@ -1032,6 +1033,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
experts
=
select_nvfp4_gemm_impl
(
moe
,
...
...
vllm/model_executor/layers/quantization/mxfp4.py
View file @
082cc07e
...
...
@@ -10,6 +10,8 @@ from vllm.config import get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEConfig
,
FusedMoEMethodBase
)
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
from
vllm.model_executor.layers.fused_moe.trtllm_moe
import
TrtLlmGenExperts
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
...
...
@@ -445,6 +447,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return
tile_tokens_dim
def
select_gemm_impl
(
self
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalize
,
moe
:
FusedMoEConfig
,
layer
:
torch
.
nn
.
Module
,
)
->
mk
.
FusedMoEPermuteExpertsUnpermute
:
if
(
prepare_finalize
.
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
):
raise
NotImplementedError
(
"Mxfp4 does not support batched experts format for EP"
)
else
:
if
should_use_flashinfer_mxfp4
():
# B200 code-path
kwargs
=
{
"gemm1_alpha"
:
layer
.
gemm1_alpha
,
"gemm1_beta"
:
layer
.
gemm1_beta
,
"gemm1_clamp_limit"
:
layer
.
gemm1_clamp_limit
,
"w13_bias"
:
layer
.
w13_bias
,
"w2_bias"
:
layer
.
w2_bias
,
"max_capture_size"
:
self
.
max_capture_size
,
}
return
TrtLlmGenExperts
(
moe
,
**
kwargs
)
else
:
# Use matmul_ogs from triton_kernels here!
raise
NotImplementedError
(
"Mxfp4 does not support non-batched experts format for EP"
)
def
_route_and_experts
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
isinstance
(
self
.
fused_experts
,
mk
.
FusedMoEModularKernel
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
indices_type
=
self
.
topk_indices_dtype
,
enable_eplb
=
enable_eplb
,
expert_map
=
expert_map
,
expert_load_view
=
expert_load_view
,
logical_to_physical_map
=
logical_to_physical_map
,
logical_replica_count
=
logical_replica_count
)
return
self
.
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -503,6 +590,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation
=
activation
,
expert_map
=
expert_map
)
if
self
.
fused_experts
is
not
None
:
return
self
.
_route_and_experts
(
layer
,
x
,
router_logits
,
top_k
,
renormalize
,
use_grouped_topk
,
topk_group
,
num_expert_group
,
global_num_experts
,
expert_map
,
custom_routing_function
,
scoring_func
,
e_score_correction_bias
,
apply_router_weight_on_input
,
activation
,
enable_eplb
,
expert_load_view
,
logical_to_physical_map
,
logical_replica_count
,
)
assert
_can_support_mxfp4
(
use_grouped_topk
,
topk_group
,
num_expert_group
,
expert_map
,
custom_routing_function
,
e_score_correction_bias
,
...
...
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
View file @
082cc07e
...
...
@@ -66,11 +66,10 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
):
return
not
(
use_grouped_topk
or
topk_group
or
num_expert_group
or
expert_map
or
custom_routing_function
or
e_score_correction_bias
or
apply_router_weight_on_input
or
scoring_func
!=
"softmax"
or
activation
!=
"swigluoai"
or
expert_load_view
or
logical_to_physical_map
or
logical_replica_count
)
or
custom_routing_function
or
e_score_correction_bias
or
apply_router_weight_on_input
or
scoring_func
!=
"softmax"
or
activation
!=
"swigluoai"
or
expert_load_view
or
logical_to_physical_map
or
logical_replica_count
)
def
_dequant_mxfp4
(
x
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
0 → 100644
View file @
082cc07e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
mxfp8_quantize
(
x
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
try
:
from
flashinfer
import
mxfp8_quantize
except
ImportError
as
err
:
raise
ImportError
(
"The package `flashinfer` is required to do "
"MX-FP8 quantization. Please install it with"
\
"`pip install flashinfer`"
)
from
err
return
mxfp8_quantize
(
x
,
is_sf_swizzled_layout
=
False
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment