change / sglang · Commits

Commit 8d4ed42a (unverified)
Authored Sep 24, 2024 by Ke Bao; committed by GitHub, Sep 24, 2024

MoE torch compile (#1497)

parent 2854a5ea
Showing 2 changed files with 126 additions and 5 deletions.
python/sglang/srt/layers/fused_moe/patch.py (+117, −0)
python/sglang/srt/model_executor/cuda_graph_runner.py (+9, −5)
python/sglang/srt/layers/fused_moe/patch.py (new file, mode 0 → 100644)
from typing import Optional

import torch
from torch.nn import functional as F


def fused_topk_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    M, _ = hidden_states.shape
    topk_weights = torch.empty(M, topk, dtype=torch.float32, device=hidden_states.device)
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    topk_weights = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
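
# Worked example (illustrative numbers, not from the commit): for one token
# with gating logits [2.0, 1.0, 0.5, 0.1] over 4 experts and topk=2, softmax
# gives roughly [0.57, 0.21, 0.13, 0.09]; torch.topk keeps (~0.57, ~0.21)
# with ids (0, 1), and renormalize=True rescales them to sum to 1,
# giving (~0.73, ~0.27).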

# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
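
# Group-limited routing, step by step (illustrative): with 8 experts split
# into num_expert_group=4 groups of 2, each group is scored by its best
# expert (max over the group); only the topk_group=2 highest-scoring groups
# survive, every expert outside them is masked to 0.0, and the final
# torch.topk then picks topk experts from the surviving groups only.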

def select_experts_native(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
):
    # DeepSeek-V2 uses grouped_topk
    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
    else:
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )
    return topk_weights, topk_ids

def fused_moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
) -> torch.Tensor:
    topk_weights, topk_ids = select_experts_native(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
    )
    w13_weights = layer.w13_weight[topk_ids]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
    w2_weights = layer.w2_weight[topk_ids]
    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
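
To make the expected weight layout concrete, here is a minimal, self-contained usage sketch. ToyMoELayer and the shape constants are hypothetical stand-ins (not part of this commit); they only mirror the two attributes the function indexes: w13_weight packs the per-expert gate (w1) and up (w3) projections as [E, 2*I, H], and w2_weight holds the per-expert down projection as [E, H, I].

import torch
import torch.nn as nn

from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native

E, H, I, M, TOP_K = 8, 16, 32, 4, 2  # experts, hidden, intermediate, tokens, top-k


class ToyMoELayer(nn.Module):
    """Hypothetical stand-in exposing the two attributes the function reads."""

    def __init__(self):
        super().__init__()
        # Gate (w1) and up (w3) projections packed per expert: [E, 2*I, H].
        self.w13_weight = nn.Parameter(torch.randn(E, 2 * I, H))
        # Down projection per expert: [E, H, I].
        self.w2_weight = nn.Parameter(torch.randn(E, H, I))


layer = ToyMoELayer()
x = torch.randn(M, H)              # token activations
router_logits = torch.randn(M, E)  # one logit per expert per token

out = fused_moe_forward_native(
    layer,
    x,
    use_grouped_topk=False,
    top_k=TOP_K,
    router_logits=router_logits,
    renormalize=True,
)
assert out.shape == (M, H)  # each token keeps its hidden size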
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -25,6 +25,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
@@ -41,14 +42,15 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: FusedMoE torch native implementation is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-                sub._forward_method = sub.forward_native
+                # NOTE: Temporary workaround for MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
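
This hunk re-points each CustomOp's dispatch target so that, under torch.compile, FusedMoE modules route through the new pure-PyTorch fused_moe_forward_native instead of being skipped entirely. A minimal sketch of that dispatch-swap pattern, using a hypothetical ToyOp rather than vllm's actual CustomOp class:

class ToyOp:
    """Hypothetical stand-in for a CustomOp-style module (not vllm's API)."""

    def __init__(self):
        self._forward_method = self.forward_cuda  # default dispatch target

    def forward_native(self, x):
        return x + 1  # pure-PyTorch path, traceable by torch.compile

    def forward_cuda(self, x):
        return x + 1  # stand-in for a hand-written CUDA kernel

    def __call__(self, x):
        return self._forward_method(x)


op = ToyOp()
op._forward_method = op.forward_native  # what _to_torch(model) does
assert op(1) == 2
op._forward_method = op.forward_cuda    # what _to_torch(model, reverse=True) restores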
@@ -67,7 +69,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
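
The other change wraps model.forward in torch.no_grad() before compiling, so autograd tracking is disabled inside the compiled region itself. The same pattern on a toy function (f is illustrative, not part of the commit):

import torch


def f(x):
    return x * x + x


# no_grad() also works as a callable wrapper, so the compiled region
# never records gradient state.
compiled = torch.compile(torch.no_grad()(f), mode="max-autotune-no-cudagraphs")

x = torch.randn(8, requires_grad=True)
y = compiled(x)
assert not y.requires_grad  # no gradient graph was built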