sglang · commit 07ec07ad (unverified)
Authored Dec 03, 2024 by Lianmin Zheng; committed via GitHub on Dec 03, 2024
Parent: 83b340e3

Improve torch compile for fused moe (#2327)

Showing 6 changed files with 45 additions and 24 deletions (+45, -24):
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py   +5   -2
python/sglang/srt/layers/fused_moe_patch.py                               +20  -11
python/sglang/srt/model_executor/cuda_graph_runner.py                     +16  -7
python/sglang/srt/model_executor/model_runner.py                          +1   -1
test/srt/test_srt_engine.py                                               +1   -1
test/srt/test_torch_compile_moe.py                                        +2   -2

benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py

@@ -6,6 +6,7 @@ from torch.nn import functional as F
 from transformers import AutoConfig
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config

 def get_model_config(model_name: str, tp_size: int):

@@ -64,7 +65,7 @@ def fused_topk_native(
     return topk_weights, topk_ids

-@torch.compile
+@torch.compile(dynamic=False)
 def fused_moe_torch(
     x,
     w1,

@@ -88,7 +89,8 @@ def fused_moe_torch(
     w13_weights = w1[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = w2[topk_ids]
-    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
+    set_torch_compile_config()
     num_tokens = batch_size
     num_experts = model_config["num_experts"]

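The benchmark change above swaps the bare @torch.compile decorator for @torch.compile(dynamic=False), so the reference MoE path is compiled with static shapes. A minimal standalone sketch of that pattern (toy function, not the benchmark code):

import torch

# With dynamic=False, torch.compile specializes the compiled graph on the exact
# input shapes it sees and recompiles for new shapes, instead of tracing with
# symbolic (dynamic) dimensions.
@torch.compile(dynamic=False)
def scaled_sum(x: torch.Tensor) -> torch.Tensor:
    return (x * 2).sum(dim=-1)

out = scaled_sum(torch.randn(4, 16))  # graph specialized to shape (4, 16)
print(out.shape)  # torch.Size([4])
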
python/sglang/srt/layers/fused_moe_patch.py

@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-    assert custom_routing_function is None
-    topk_weights, topk_ids = select_experts_native(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-    )
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            x, router_logits, top_k, renormalize
+        )
+
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

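In the einsum strings above, the subscripts can be read as t = tokens, a = experts selected per token (top-k), i = hidden size, o = intermediate size. A self-contained sketch of the same computation with made-up shapes and random placeholder weights (not the layer's real parameters):

import torch
import torch.nn.functional as F

t, e, a, i, o = 4, 8, 2, 16, 32           # tokens, experts, top-k, hidden, intermediate
x = torch.randn(t, i)
w13 = torch.randn(e, 2 * o, i)            # stacked gate/up projections per expert
w2 = torch.randn(e, i, o)                 # down projection per expert
topk_weights = torch.softmax(torch.randn(t, a), dim=-1)
topk_ids = torch.randint(0, e, (t, a))

w13_weights = w13[topk_ids]                                  # (t, a, 2*o, i)
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)  # (t, a, o, i) each
w2_weights = w2[topk_ids]                                    # (t, a, i, o)

x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)           # gate projection
x1 = F.silu(x1)
x3 = torch.einsum("ti,taoi -> tao", x, w3_weights)           # up projection
expert_outs = torch.einsum("tao,taio -> tai", x1 * x3, w2_weights)
out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
print(out.shape)  # torch.Size([4, 16])
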
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:

@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)

 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.

@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm

@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (

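In _to_torch above, the new batch_size argument gates the MoE workaround: FusedMoE layers are pointed at fused_moe_forward_native only when the capture batch size is 1, and otherwise keep their existing fused forward. A toy illustration of that _forward_method dispatch pattern (the classes below are hypothetical stand-ins, not sglang's CustomOp or FusedMoE):

import torch

class ToyOp(torch.nn.Module):
    # Stand-in for a CustomOp: forward() dispatches through _forward_method,
    # so a patcher can swap implementations without touching the module graph.
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda

    def forward(self, x):
        return self._forward_method(x)

    def forward_cuda(self, x):
        return x * 2.0  # placeholder for a fused kernel path

    def forward_native(self, x):
        return x + x    # plain-PyTorch path that torch.compile traces well

class ToyFusedMoE(ToyOp):
    pass

def patch_for_compile(model: torch.nn.Module, batch_size: int) -> None:
    # Mirrors the batch-size check added in this commit: only at batch size 1
    # is the MoE-like layer routed to the compile-friendly native path.
    for sub in model.modules():
        if isinstance(sub, ToyOp):
            if "FusedMoE" in sub.__class__.__name__:
                if batch_size == 1:
                    sub._forward_method = sub.forward_native
            else:
                sub._forward_method = sub.forward_native

model = torch.nn.Sequential(ToyOp(), ToyFusedMoE())
patch_for_compile(model, batch_size=8)  # the MoE stand-in keeps its fused path
patch_for_compile(model, batch_size=1)  # now it switches to the native path
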
python/sglang/srt/model_executor/model_runner.py

@@ -622,7 +622,7 @@ class ModelRunner:
         tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")

test/srt/test_srt_engine.py

@@ -188,7 +188,7 @@ class TestSRTEngine(unittest.TestCase):
         )
         bench_args = BenchArgs(num_prompts=10)
         result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
+        self.assertGreater(result["total_throughput"], 3000)

 if __name__ == "__main__":

test/srt/test_torch_compile_moe.py

@@ -14,7 +14,7 @@ from sglang.test.test_utils import (
 )

-class TestTorchCompile(unittest.TestCase):
+class TestTorchCompileMoe(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST

@@ -23,7 +23,7 @@ class TestTorchCompile(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"],
+            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
         )

     @classmethod

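The test now launches the server with --torch-compile-max-bs 8 instead of 1, so compiled graphs cover decode batch sizes up to 8. For reference, a roughly equivalent offline setup could look like the sketch below, assuming the Engine constructor forwards these keyword arguments to the server args (the model path is a placeholder):

import sglang as sgl

llm = sgl.Engine(
    model_path="path/to/a-small-moe-model",  # placeholder
    enable_torch_compile=True,               # --enable-torch-compile
    torch_compile_max_bs=8,                  # --torch-compile-max-bs 8
)
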