Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
25e8b412
Commit
25e8b412
authored
Dec 23, 2025
by
yangql
Browse files
适配gptq/awq的triton moe算子
parent
10349d37
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
109 additions
and
76 deletions
+109
-76
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+108
-76
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+1
-0
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
25e8b412
...
...
@@ -414,7 +414,8 @@ def fused_moe_kernel_gptq_awq(
compute_type
:
tl
.
constexpr
,
has_zp
:
tl
.
constexpr
,
use_int4_w4a16
:
tl
.
constexpr
,
use_int8_w8a16
:
tl
.
constexpr
):
use_int8_w8a16
:
tl
.
constexpr
,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
...
...
@@ -463,26 +464,50 @@ def fused_moe_kernel_gptq_awq(
num_tokens_post_padded
=
tl
.
load
(
num_tokens_post_padded_ptr
)
if
pid_m
*
BLOCK_SIZE_M
>=
num_tokens_post_padded
:
return
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int64
)
offs_token_id
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
).
to
(
tl
.
int64
)
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
,
)
return
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
)
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
if
use_int4_w4a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
b_ptrs
=
(
b_ptr
+
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
b_ptrs
=
(
b_ptr
+
off_experts
*
stride_be
+
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
if
not
has_zp
and
use_int4_w4a16
:
b_zp_num
=
8
...
...
@@ -508,33 +533,43 @@ def fused_moe_kernel_gptq_awq(
k_mask
=
None
k_other
=
None
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
token_mask
[:,
None
]
&
(
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
),
other
=
0.0
,
)
b
=
tl
.
load
(
b_ptrs
)
if
use_int4_w4a16
:
b
=
(
b
>>
b_shifter
)
&
0xF
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsn
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsk
b_scale_ptrs
=
(
b_scale_ptr
+
off_experts
*
stride_bse
+
offs_bn
[
None
,
:]
*
stride_bsn
+
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsk
)
b_scale
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_scale
=
b_scale
.
to
(
tl
.
float32
)
if
has_zp
and
use_int4_w4a16
:
offs_k_true
=
(
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
b_zp_ptrs
=
b_zp_ptr
+
off_experts
*
stride_bze
+
\
(
offs_bn
[
None
,
:]
//
2
)
*
stride_bzn
+
\
offs_k_true
*
stride_bzk
b_zp_ptrs
=
(
b_zp_ptr
+
off_experts
*
stride_bze
+
(
offs_bn
[
None
,
:]
//
2
)
*
stride_bzn
+
offs_k_true
*
stride_bzk
)
b_zp
=
tl
.
load
(
b_zp_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_zp
=
(
(
b_zp
>>
b_zp_shifter
)
&
0xF
)
b_zp
=
(
b_zp
>>
b_zp_shifter
)
&
0xF
b_zp
=
b_zp
.
to
(
tl
.
float32
)
elif
has_zp
and
use_int8_w8a16
:
offs_k_true
=
(
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
b_zp_ptrs
=
b_zp_ptr
+
off_experts
*
stride_bze
+
\
offs_bn
[
None
,
:]
*
stride_bzn
+
\
offs_k_true
*
stride_bzk
b_zp_ptrs
=
(
b_zp_ptr
+
off_experts
*
stride_bze
+
offs_bn
[
None
,
:]
*
stride_bzn
+
offs_k_true
*
stride_bzk
)
b_zp
=
tl
.
load
(
b_zp_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_zp
=
b_zp
.
to
(
tl
.
float32
)
...
...
@@ -553,17 +588,14 @@ def fused_moe_kernel_gptq_awq(
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
MUL_ROUTED_WEIGHT
:
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
moe_weight
=
tl
.
load
(
topk_weights_ptr
+
offs_token
,
mask
=
token_mask
,
other
=
0
)
accumulator
=
accumulator
*
moe_weight
[:,
None
]
accumulator
=
accumulator
.
to
(
compute_type
)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_ptrs
=
c_ptr
+
stride_cm
*
offs_token
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
token_mask
[:,
None
]
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
accumulator
,
mask
=
c_mask
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
25e8b412
...
...
@@ -1787,6 +1787,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
**
_
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
assert
self
.
fused_experts
is
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment