sglang · Commit 07ec07ad (unverified)

Improve torch compile for fused moe (#2327)

Authored Dec 03, 2024 by Lianmin Zheng; committed via GitHub on Dec 03, 2024.
Parent: 83b340e3
Showing 6 changed files with 45 additions and 24 deletions (+45 −24):
- benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py (+5 −2)
- python/sglang/srt/layers/fused_moe_patch.py (+20 −11)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+16 −7)
- python/sglang/srt/model_executor/model_runner.py (+1 −1)
- test/srt/test_srt_engine.py (+1 −1)
- test/srt/test_torch_compile_moe.py (+2 −2)
**benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py**

```diff
@@ -6,6 +6,7 @@ from torch.nn import functional as F
 from transformers import AutoConfig

 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config


 def get_model_config(model_name: str, tp_size: int):
@@ -64,7 +65,7 @@ def fused_topk_native(
     return topk_weights, topk_ids


-@torch.compile
+@torch.compile(dynamic=False)
 def fused_moe_torch(
     x,
     w1,
@@ -88,7 +89,8 @@ def fused_moe_torch(
     w13_weights = w1[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = w2[topk_ids]
-    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
+    set_torch_compile_config()

     num_tokens = batch_size
     num_experts = model_config["num_experts"]
```
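For context, the reference path being compiled here gathers each token's routed expert weights and contracts them with einsum; the last hunk's sibling change also fixes the gate activation from F.gelu to F.silu so the reference matches the SiLU-gated experts the Triton kernel implements. Below is a minimal, CPU-runnable sketch of that pattern; the shapes, the `moe_reference` name, and the toy router are illustrative, not the benchmark's exact API:

```python
# A minimal, CPU-runnable sketch of the einsum-based reference MoE above.
# Shapes, names, and the toy router are illustrative, not the benchmark's API.
import torch
import torch.nn.functional as F


def moe_reference(x, w1, w2, topk_weights, topk_ids):
    # x: [tokens, hidden]; w1: [experts, 2*inter, hidden]; w2: [experts, hidden, inter]
    w13_weights = w1[topk_ids]  # gather each token's experts: [tokens, topk, 2*inter, hidden]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
    w2_weights = w2[topk_ids]  # [tokens, topk, hidden, inter]
    x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1_weights))  # gate branch (SiLU, not GELU)
    x3 = torch.einsum("ti,taoi->tao", x, w3_weights)          # up branch
    expert_outs = torch.einsum("tao,taio->tai", x1 * x3, w2_weights)  # down projection
    return torch.einsum("tai,ta->ti", expert_outs, topk_weights.to(expert_outs.dtype))


tokens, hidden, inter, experts, topk = 4, 8, 16, 6, 2
x = torch.randn(tokens, hidden)
w1 = torch.randn(experts, 2 * inter, hidden)
w2 = torch.randn(experts, hidden, inter)
topk_weights, topk_ids = torch.topk(torch.softmax(torch.randn(tokens, experts), -1), topk)
print(moe_reference(x, w1, w2, topk_weights, topk_ids).shape)  # torch.Size([4, 8])
```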
**python/sglang/srt/layers/fused_moe_patch.py**

```diff
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-    assert custom_routing_function is None
-    topk_weights, topk_ids = select_experts_native(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-    )
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(x, router_logits, top_k, renormalize)
+
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
```
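The rewrite replaces a single `select_experts_native` call (which asserted `custom_routing_function is None`) with explicit dispatch: grouped top-k routing when `use_grouped_topk` is set, the native softmax top-k otherwise, or the caller-supplied routing function. As a rough illustration of what the softmax top-k branch computes, here is a hedged sketch in the style of `fused_topk_native`; the real helper lives in sglang and may differ in detail:

```python
# Hedged sketch of the plain softmax top-k routing branch (the
# fused_topk_native case); the real sglang helper may differ in detail.
import torch


def topk_softmax_native(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Rescale so each token's selected expert weights sum to 1 again.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(4, 8)  # [tokens, experts]
weights, ids = topk_softmax_native(logits, top_k=2, renormalize=True)
print(weights.sum(dim=-1))  # ~1.0 per token after renormalization
```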
**python/sglang/srt/model_executor/cuda_graph_runner.py**

```diff
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
+                    else:
+                        sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)


 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
```
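The net effect of these hunks: `patch_model` now receives the capture batch size, substitutes the einsum-based `fused_moe_forward_native` only when `batch_size == 1` (falling back to `forward_native` otherwise), and compiles with `dynamic=False` so torch.compile specializes per shape. A self-contained sketch of that compile-per-batch-size idea, with an illustrative `forward` standing in for the real model's forward:

```python
# Self-contained sketch of the compile-per-batch-size idea, assuming an
# illustrative `forward` in place of the real model's forward. With
# dynamic=False, torch.compile specializes on the exact input shapes it sees,
# so each CUDA-graph batch size ends up with its own specialized graph.
import torch


def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * 2


compiled = torch.compile(
    torch.no_grad()(forward),  # same wrapping as patch_model above
    mode="max-autotune-no-cudagraphs",
    dynamic=False,
)

for bs in (1, 2, 4, 8):  # mirrors iterating over compile_bs during capture
    _ = compiled(torch.randn(bs, 16))  # each new shape triggers a specialization
```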
**python/sglang/srt/model_executor/model_runner.py**

```diff
@@ -622,7 +622,7 @@ class ModelRunner:
         tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
```
**test/srt/test_srt_engine.py**

```diff
@@ -188,7 +188,7 @@ class TestSRTEngine(unittest.TestCase):
         )

         bench_args = BenchArgs(num_prompts=10)
         result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
+        self.assertGreater(result["total_throughput"], 3000)


 if __name__ == "__main__":
```
**test/srt/test_torch_compile_moe.py**

```diff
@@ -14,7 +14,7 @@ from sglang.test.test_utils import (
 )


-class TestTorchCompile(unittest.TestCase):
+class TestTorchCompileMoe(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
@@ -23,7 +23,7 @@ class TestTorchCompile(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"],
+            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
         )

     @classmethod
```
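The renamed test now exercises torch.compile up to batch size 8 rather than 1. A hedged sketch of the launch the test performs, using the helpers the file already imports from sglang.test.test_utils (DEFAULT_URL_FOR_TEST is assumed to come from the same module):

```python
# Hedged sketch of the launch the updated test performs. popen_launch_server,
# the model constant, and the timeout come from the test's own imports;
# DEFAULT_URL_FOR_TEST is assumed to live in the same module.
from sglang.test.test_utils import (
    DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
)
# ... issue requests against the server, then shut it down:
process.terminate()
```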