Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3c283de3
Commit
3c283de3
authored
Feb 02, 2026
by
SAC_fanth
Browse files
fuse_moe_fp8接入marlin算子
parent
a3fb334b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
497 additions
and
2 deletions
+497
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+497
-2
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
3c283de3
...
...
@@ -29,6 +29,7 @@ from vllm.utils import round_up
try
:
from
lightop
import
m_grouped_w8a8_gemm_nt_masked
,
fuse_silu_mul_quant_ep
,
fuse_silu_mul_quant
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
from
lmslim.layers.fused_moe.fuse_moe_fp8_marlin
import
fused_experts_impl_fp8_marlin
from
lightop
import
m_grouped_w8a8_gemm_nt_contig_asm
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
...
@@ -37,8 +38,27 @@ logger = init_logger(__name__)
# Names exported by ``from <this module> import *``: the two Marlin MoE
# quantization method classes (INT8 and FP8 variants).
__all__ = [
    "CompressedTensorsW8A8Int8MarlinMoEMethod",
    "CompressedTensorsW8A8FP8MarlinMoEMethod",
]
def
fp32_to_fp8_e4m3fn
(
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""更合理的FP32到Float8_e4m3fn转换,使用最近值而不是简单舍弃尾数"""
# torch.float8_e4m3fn的数值范围约[-448, 448]
fp8_min
,
fp8_max
=
-
448.0
,
448.0
t_clamped
=
t
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
# 保证不会下溢到0
# 转换前到float16再转fp8可能提升精度(float8实现本身通常通过float16做rounding)
t_fp16
=
t_clamped
.
to
(
torch
.
float16
)
return
t_fp16
.
to
(
torch
.
float8_e4m3fn
)
def w8a8_fp8_nt_kpack2_marlin_weight(
    w8a8_w,
    k_tile=16,
    n_tile=16,
):
    """Re-tile a 2-D weight so each ``n_tile x k_tile`` block is contiguous.

    The weight is viewed as a grid of ``(n_tile, k_tile)`` tiles. Tiles are
    laid out row-major within themselves, and the tiles of one tile-row are
    concatenated along the second dimension of the result.

    Args:
        w8a8_w: 2-D tensor of shape ``(size_n, size_k)``. The original
            author annotated it as ``[size_n, size_k // 2]`` -- NOTE(review):
            presumably a packed K dimension; confirm with the kernel.
        k_tile: Tile extent along the second (K) dimension.
        n_tile: Tile extent along the first (N) dimension.

    Returns:
        Tensor of shape ``(size_n // n_tile, size_k * n_tile)`` holding the
        same elements in tiled order.

    Raises:
        AssertionError: If ``size_n`` is not a multiple of ``n_tile`` or
            ``size_k`` is not a multiple of ``k_tile``.
    """
    size_n, size_k = w8a8_w.shape
    # Each dimension must be divisible by the tile that actually splits it.
    # The previous code checked the tiles against the opposite dimensions,
    # which was only correct when k_tile == n_tile.
    assert size_n % n_tile == 0 and size_k % k_tile == 0, "k_tile / n_tile 必须能整除对应维度"

    # (n_blocks, n_tile, k_blocks, k_tile) -> (n_blocks, k_blocks, n_tile, k_tile)
    tiled = w8a8_w.reshape(
        (size_n // n_tile, n_tile, size_k // k_tile, k_tile))
    tiled = tiled.permute((0, 2, 1, 3)).contiguous()

    # One output row per tile-row; identical to the former
    # (size_n // k_tile, size_k * k_tile) whenever k_tile == n_tile.
    return tiled.reshape((size_n // n_tile, size_k * n_tile))
class
CompressedTensorsMarlinMoEMethod
(
FusedMoEMethodBase
):
@
staticmethod
...
...
@@ -46,17 +66,492 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
quant_config
:
"SlimQuantCompressedTensorsMarlinConfig"
,
# type: ignore # noqa E501
layer
:
torch
.
nn
.
Module
,
)
->
"CompressedTensorsMarlinMoEMethod"
:
# are supported + check if the layer is being ignored.
weight_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
input_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
if
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8FP8MarlinMoEMethod
(
quant_config
)
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8MarlinMoEMethod
(
quant_config
)
else
:
raise
RuntimeError
(
f
"Slimquant_marlin does not support the FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
    """Fused-MoE method for FP8 weights with dynamic per-token activations.

    Only channel-wise weight quantization combined with dynamic per-token
    input quantization is accepted; anything else is rejected in
    ``__init__``. Depending on the expert-parallel configuration the expert
    computation is dispatched either to ``fused_experts_impl_fp8_marlin`` or
    to the masked / contiguous grouped-GEMM paths defined below.
    """

    def __init__(
        self,
        quant_config: "CompressedTensorsMarlinConfig"  # type: ignore # noqa E501
    ):
        """Validate the quantization scheme and read the parallel setup.

        Raises:
            ValueError: If weights are not channel-wise quantized, inputs are
                not per-token quantized, or input scales are static.
        """
        self.quant_config = quant_config
        linear_scheme = self.quant_config.target_scheme_map["Linear"]
        self.weight_quant = linear_scheme.get("weights")
        self.input_quant = linear_scheme.get("input_activations")

        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
        if not per_channel:
            raise ValueError(
                "For FP8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For FP8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")

        # Default expert implementation; select_gemm_impl() may hand out a
        # different forward for modular-kernel execution.
        self.fused_experts = self.fused_moe_forward

        # These are filled in later (DeepEP branch below / select_gemm_impl).
        # They are initialised here so the explicit ``is None`` checks in the
        # workspace-shape helpers are reachable instead of raising
        # AttributeError.
        self.max_num_tokens_per_rank = None
        self.block_shape = None

        vllm_config = get_current_vllm_config()
        parallel_config = vllm_config.parallel_config
        self.dp_size = get_dp_group().world_size
        self.ep_size = get_ep_group().world_size

        deepep_backends = (
            "deepep_high_throughput",
            "deepep_low_latency",
            "deepep_auto",
        )
        self.use_deepep = (
            self.dp_size > 1
            and parallel_config.enable_expert_parallel
            and envs.VLLM_ALL2ALL_BACKEND in deepep_backends)

        self.use_deepgemm = False
        if self.use_deepep:
            all2all_manager = get_ep_group().device_communicator.all2all_manager
            assert all2all_manager is not None
            self.num_dispatchers = all2all_manager.world_size
            self.block_shape = [256, 256]
            self.use_deepgemm = (
                envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
                or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM
                or envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Register FP8 expert weights and float32 per-channel scales.

        ``params_dtype`` is ignored: weights are always allocated as
        ``torch.float8_e4m3fn``. Input scales are set to ``None`` because
        activations are quantized dynamically.
        """
        if self.use_deepep:
            # GEMM problem sizes read by the grouped-GEMM forward paths.
            self.N = 2 * intermediate_size_per_partition
            self.K = hidden_size

        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size,
                        dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        hidden_size,
                        intermediate_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts,
                       2 * intermediate_size_per_partition,
                       1,
                       dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def _repack_expert_weights(self, weight: torch.Tensor) -> torch.Tensor:
        """Repack every expert slice of ``weight`` and return an FP8 stack.

        Each ``weight[i]`` goes through ``w8a8_fp8_nt_kpack2_marlin_weight``
        when ``use_deepgemm`` is set, otherwise through
        ``get_w8a8_int8_marlin_weights``.
        """
        repack = (w8a8_fp8_nt_kpack2_marlin_weight
                  if self.use_deepgemm else get_w8a8_int8_marlin_weights)
        repacked = []
        for expert_idx in range(weight.shape[0]):
            expert_weight = repack(weight[expert_idx])
            # FP8 slices are widened to float32 before stacking --
            # presumably because stacking FP8 tensors is not supported on
            # the target backend; TODO confirm.
            if expert_weight.dtype == torch.float8_e4m3fn:
                expert_weight = expert_weight.float()
            repacked.append(expert_weight)
        return fp32_to_fp8_e4m3fn(torch.stack(repacked, dim=0))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace ``w13_weight`` / ``w2_weight`` with their repacked form."""
        w13_repacked = self._repack_expert_weights(layer.w13_weight)
        w2_repacked = self._repack_expert_weights(layer.w2_weight)
        layer.w13_weight = torch.nn.Parameter(w13_repacked,
                                              requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_repacked, requires_grad=False)

    def masked_groupgemm_workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
    ):
        """Return buffer shapes for the masked grouped-GEMM path.

        Returns:
            ``(workspace13_shape, workspace2_shape, output_shape, dtype)``
            where every shape is ``(local_num_experts,
            max_num_tokens * num_dispatchers, width)`` and ``dtype`` is
            ``a.dtype``.
        """
        assert a.dim() == 2
        # FIXME (varun): We should be able to dispatch only from the leader
        # DP ranks in the case of TP > 1. At the moment, all the Ranks
        # end up sending their tokens. This needs to be fixed.
        if self.max_num_tokens_per_rank is None:
            max_num_tokens = a.size(0)
        else:
            max_num_tokens = self.max_num_tokens_per_rank
        rows = max_num_tokens * self.num_dispatchers

        workspace13 = (local_num_experts, rows, max(K, N))
        workspace2 = (local_num_experts, rows, (N // 2))
        output = (local_num_experts, rows, K)
        return (workspace13, workspace2, output, a.dtype)

    def contiguous_groupgemm_workspace_shapes(
        self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
        topk: int, global_num_experts: int, local_num_experts: int,
        expert_num_tokens_cpu: torch.Tensor
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype,
               int]:
        """Return buffer shapes for the contiguous grouped-GEMM path.

        Returns:
            ``(workspace1_shape, workspace2_shape, output_shape, dtype,
            M_sum)`` where ``M_sum`` is the block-aligned row count computed
            by ``compute_aligned_M`` and ``dtype`` is ``a.dtype``.
        """
        assert self.block_shape is not None
        # We use global_num_experts due to how moe_align_block_size handles
        # expert_maps.
        block_m = self.block_shape[0]
        M_sum = compute_aligned_M(M, topk, local_num_experts, block_m,
                                  expert_num_tokens_cpu)
        assert M_sum % block_m == 0

        workspace1 = (M_sum, max(N, K))
        workspace2 = (M_sum, max(N // 2, K))
        output = (M, K)
        return (workspace1, workspace2, output, a.dtype, M_sum)

    def w8a8_groupgemm_masked_forward(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        expert_num_tokens: Optional[torch.Tensor] = None,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        shared_output: Optional[torch.Tensor] = None,
        q_x: Optional[torch.Tensor] = None,
        expert_num_tokens_cpu: Optional[torch.Tensor] = None,
        **_
    ):
        """Run both expert GEMMs with the masked grouped-GEMM kernel.

        Pipeline: ``m_grouped_w8a8_gemm_nt_masked`` on ``(q_x, a1_scale)`` x
        ``(w1, w1_scale)``, then ``fuse_silu_mul_quant_ep``, then a second
        masked grouped GEMM against ``(w2, w2_scale)``.

        Returns:
            The output buffer of the second GEMM.
        """
        import vllm.model_executor.layers.fused_moe.modular_kernel as mk

        # Validate before the first kernel launch; previously this was only
        # checked after expert_num_tokens had already been consumed.
        assert expert_num_tokens is not None

        local_num_experts = w1.size(0)
        E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
            q_x, w1, w2, topk_ids)
        N, K = self.N, self.K

        (workspace13_shape, _workspace2_shape, fused_out_shape,
         workspace_dtype) = self.masked_groupgemm_workspace_shapes(
             x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
             local_num_experts)

        workspace13 = torch.empty(prod(workspace13_shape),
                                  device=x.device,
                                  dtype=workspace_dtype)
        # Both views alias workspace13: the first GEMM output is fully
        # consumed by the activation step before fused_out is written.
        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
        fused_out = _resize_cache(workspace13, fused_out_shape)

        # (from deepgemm docs) : A value hint (which is a value on CPU)
        # for the M expectation of each batch, correctly setting this value
        # may lead to better performance.
        expected_m = x.shape[0] * self.ep_size

        m_grouped_w8a8_gemm_nt_masked(
            (q_x, a1_scale),
            (w1, w1_scale),
            workspace1,
            expert_num_tokens,
            expected_m,
        )
        a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
        m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
                                      fused_out, expert_num_tokens,
                                      expected_m)
        return fused_out

    def w8a8_groupgemm_contiguous_forward(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        expert_num_tokens: Optional[torch.Tensor] = None,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        shared_output: Optional[torch.Tensor] = None,
        q_x: Optional[torch.Tensor] = None,
        expert_num_tokens_cpu: Optional[torch.Tensor] = None,
        **_
    ):
        """Run both expert GEMMs with the contiguous grouped-GEMM kernel.

        Pipeline: permute quantized activations with
        ``deepgemm_moe_permute``, GEMM against ``w1``,
        ``fuse_silu_mul_quant``, GEMM against ``w2``, then
        ``deepgemm_unpermute_and_reduce`` into the output buffer.

        Returns:
            The reduced output buffer of shape ``(topk_ids.size(0), K)``.
        """
        local_num_experts = w1.size(0)
        a1q = q_x
        N, K = self.N, self.K

        (workspace13_shape, workspace2_shape, fused_out_shape,
         workspace_dtype,
         M_sum) = self.contiguous_groupgemm_workspace_shapes(
             x, q_x, topk_ids.size(0), N, K, topk_ids.size(1),
             global_num_experts, local_num_experts, expert_num_tokens_cpu)

        workspace13 = torch.empty(prod(workspace13_shape),
                                  device=x.device,
                                  dtype=workspace_dtype)
        workspace2 = torch.empty(prod(workspace2_shape),
                                 device=x.device,
                                 dtype=workspace_dtype)

        # mm1_out / fused_out alias workspace13; mm2_out / a1q_perm alias
        # workspace2. Each alias is consumed before its sibling is written.
        mm1_out = _resize_cache(workspace13, (M_sum, N))
        mm2_out = _resize_cache(workspace2, (M_sum, K))
        fused_out = _resize_cache(workspace13, fused_out_shape)
        a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))

        a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
            aq=a1q,
            aq_scale=a1_scale,
            topk_ids=topk_ids,
            local_num_experts=local_num_experts,
            expert_map=expert_map,
            block_shape=self.block_shape,
            expert_num_tokens=expert_num_tokens,
            expert_num_tokens_cpu=expert_num_tokens_cpu,
            aq_out=a1q_perm,
            M_sum=M_sum)

        m_grouped_w8a8_gemm_nt_contig_asm((a1q, a1q_scale), (w1, w1_scale),
                                          mm1_out, expert_ids)
        a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)
        m_grouped_w8a8_gemm_nt_contig_asm((a2q, a2q_scale), (w2, w2_scale),
                                          mm2_out, expert_ids)

        if apply_router_weight_on_input:
            # Router weights are replaced by ones for the final reduction
            # when this flag is set.
            topk_weights = torch.ones_like(topk_weights)

        deepgemm_unpermute_and_reduce(
            a=mm2_out,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            inv_perm=inv_perm,
            expert_map=expert_map,
            output=fused_out,
        )
        return fused_out

    def fused_moe_forward(
        self,
        x: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        expert_num_tokens: Optional[torch.Tensor] = None,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        shared_output: Optional[torch.Tensor] = None,
        i_q: Optional[torch.Tensor] = None,
        i_s: Optional[torch.Tensor] = None,
        q_x: Optional[torch.Tensor] = None,
        expert_num_tokens_cpu: Optional[torch.Tensor] = None,
        **_
    ):
        """Delegate to ``fused_experts_impl_fp8_marlin``.

        ``q_x`` is used as the hidden states when provided, otherwise ``x``.
        The call always runs in place with FP8 W8A8, per-channel
        quantization and ``use_nn_moe=False``; the ``use_nn_moe`` argument of
        this method is ignored.
        """
        hidden_states = x if q_x is None else q_x
        return fused_experts_impl_fp8_marlin(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            use_nn_moe=False,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor,
            i_q=i_q,
            i_s=i_s)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        use_nn_moe: Optional[bool] = False,
        routed_scaling_factor: Optional[float] = None,
        use_fused_gate: Optional[bool] = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
        shared_output: Optional[torch.Tensor] = None,
        i_q: Optional[torch.Tensor] = None,
        i_s: Optional[torch.Tensor] = None,
        **_
    ) -> torch.Tensor:
        """Select experts for ``x`` and run ``self.fused_experts``.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for "
                "`CompressedTensorsW8A8FP8MoEMethod` yet.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            use_fused_gate=use_fused_gate,
            e_score_correction_bias=e_score_correction_bias,
            # The DeepEP path requests int64 expert indices.
            indices_type=torch.int64 if self.use_deepep else None,
        )

        return self.fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=False,
            shared_output=shared_output,
            routed_scaling_factor=routed_scaling_factor,
            i_q=i_q,
            i_s=i_s)

    def select_gemm_impl(
        self,
        prepare_finalize: FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> FusedMoEPermuteExpertsUnpermute:
        """Build the ``TritonOrGroupGemmExperts`` for ``prepare_finalize``.

        For the ``BatchedExperts`` activation format the masked grouped-GEMM
        forward is used and ``max_num_tokens_per_rank`` is recorded on
        ``self``. Otherwise the contiguous grouped-GEMM forward is used when
        ``VLLM_ENABLE_DEEPEP_HT_DEEPGEMM`` is set, else
        ``fused_moe_forward``.
        """
        from vllm.model_executor.layers.fused_moe import (
            TritonOrGroupGemmExperts)

        if (prepare_finalize.activation_format ==
                FusedMoEActivationFormat.BatchedExperts):
            max_num_tokens_per_rank = (
                prepare_finalize.max_num_tokens_per_rank())
            assert max_num_tokens_per_rank is not None
            self.max_num_tokens_per_rank = max_num_tokens_per_rank
            logger.debug(
                "TritonOrGroupGemmExperts(%s): "
                "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
                self.__class__.__name__, max_num_tokens_per_rank, None, True)
            return TritonOrGroupGemmExperts(
                use_fp8_w8a8=True,
                per_act_token_quant=True,
                fused_experts=self.w8a8_groupgemm_masked_forward)

        logger.debug(
            "TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__, None, False)
        use_ht_deepgemm = bool(envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM)
        if use_ht_deepgemm:
            forward_impl = self.w8a8_groupgemm_contiguous_forward
        else:
            forward_impl = self.fused_moe_forward
        return TritonOrGroupGemmExperts(use_fp8_w8a8=use_ht_deepgemm,
                                        per_act_token_quant=use_ht_deepgemm,
                                        fused_experts=forward_impl)
class
CompressedTensorsW8A8Int8MarlinMoEMethod
(
CompressedTensorsMarlinMoEMethod
):
def
__init__
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment