Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f9a784a7
Commit
f9a784a7
authored
Apr 23, 2025
by
yangql
Browse files
更新curr_topk_ids
parent
cd87548a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
14 deletions
+7
-14
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+7
-14
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
f9a784a7
...
...
@@ -384,26 +384,17 @@ def fused_moe_kernel_gptq_awq(
offs_token
=
tl
.
load
(
sorted_token_ids_ptr
+
offs_token_id
)
token_mask
=
offs_token
<
num_valid_tokens
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
if
off_experts
==
-
1
:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output
(
c_ptr
,
stride_cm
,
stride_cn
,
pid_n
,
N
,
offs_token
,
token_mask
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
compute_type
)
return
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
).
to
(
tl
.
int64
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
a_ptr
+
(
offs_token
[:,
None
]
//
top_k
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
off_experts
=
tl
.
load
(
expert_ids_ptr
+
pid_m
).
to
(
tl
.
int64
)
if
use_int4_w4a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
\
stride_bn
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
off_experts
*
stride_be
+
\
...
...
@@ -443,8 +434,7 @@ def fused_moe_kernel_gptq_awq(
b_scale_ptrs
=
b_scale_ptr
+
off_experts
*
stride_bse
+
\
offs_bn
[
None
,
:]
*
stride_bsn
+
\
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
\
stride_bsk
((
offs_k
[:,
None
]
+
BLOCK_SIZE_K
*
k
)
//
group_size
)
*
stride_bsk
b_scale
=
tl
.
load
(
b_scale_ptrs
,
mask
=
k_mask
,
other
=
k_other
)
b_scale
=
b_scale
.
to
(
tl
.
float32
)
...
...
@@ -716,6 +706,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B_scale
:
Optional
[
torch
.
Tensor
],
B_zp
:
Optional
[
torch
.
Tensor
],
topk_weights
:
Optional
[
torch
.
Tensor
],
topk_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
...
...
@@ -1709,6 +1700,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w1_scale
,
w1_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
...
...
@@ -1769,6 +1761,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_scale
,
w2_zp
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment