Commit 07ec07ad (unverified)
Authored Dec 03, 2024 by Lianmin Zheng; committed by GitHub on Dec 03, 2024

Improve torch compile for fused moe (#2327)

Parent: 83b340e3

Showing 6 changed files with 45 additions and 24 deletions (+45 / -24)
Changed files:

  benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py   +5  -2
  python/sglang/srt/layers/fused_moe_patch.py                               +20 -11
  python/sglang/srt/model_executor/cuda_graph_runner.py                     +16 -7
  python/sglang/srt/model_executor/model_runner.py                          +1  -1
  test/srt/test_srt_engine.py                                               +1  -1
  test/srt/test_torch_compile_moe.py                                        +2  -2
benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py

@@ -6,6 +6,7 @@ from torch.nn import functional as F
 from transformers import AutoConfig
 from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
+from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config


 def get_model_config(model_name: str, tp_size: int):

@@ -64,7 +65,7 @@ def fused_topk_native(
     return topk_weights, topk_ids


-@torch.compile
+@torch.compile(dynamic=False)
 def fused_moe_torch(
     x,
     w1,

@@ -88,7 +89,8 @@ def fused_moe_torch(
     w13_weights = w1[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = w2[topk_ids]
-    x1 = F.gelu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

@@ -174,6 +176,7 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     print(f"benchmark {provider} with batch_size={batch_size}")
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
+    set_torch_compile_config()
     num_tokens = batch_size
     num_experts = model_config["num_experts"]
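The decorator change above moves the native reference kernel from a bare @torch.compile to @torch.compile(dynamic=False). A minimal, self-contained sketch of that usage pattern (illustrative only, not code from this commit; the function and shapes are made up):

import torch
import torch.nn.functional as F

# With dynamic=False, torch.compile specializes the compiled graph to the
# concrete input shapes it sees instead of tracing with symbolic shapes.
@torch.compile(dynamic=False)
def gated_activation(x1, x3):
    return F.silu(x1) * x3

out = gated_activation(torch.randn(4, 16), torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 16])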
python/sglang/srt/layers/fused_moe_patch.py

@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-    assert custom_routing_function is None
-    topk_weights, topk_ids = select_experts_native(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-    )
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(
+            x, router_logits, top_k, renormalize
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            x, router_logits, top_k, renormalize
+        )
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
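For readers unpacking the einsum notation in fused_moe_forward_native, the following toy sketch shows the tensor shapes involved (the sizes are invented for illustration; t = tokens, a = experts selected per token, i = hidden size, o = intermediate size):

import torch
import torch.nn.functional as F

t, a, i, o = 4, 2, 8, 16
x = torch.randn(t, i)                 # hidden states, one row per token
w1_weights = torch.randn(t, a, o, i)  # gathered gate-proj weights per (token, expert)
w3_weights = torch.randn(t, a, o, i)  # gathered up-proj weights per (token, expert)
w2_weights = torch.randn(t, a, i, o)  # gathered down-proj weights per (token, expert)
topk_weights = torch.softmax(torch.randn(t, a), dim=-1)

x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
print(out.shape)  # torch.Size([4, 8]) -> back to (tokens, hidden)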
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:

@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)


 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.

@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
-            yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
-            )
+            yield torch.compile(
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
+            )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm

@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
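The patch_model context manager above follows a save / patch / compile / restore pattern. A stripped-down sketch of that pattern (this is not the sglang implementation; the CustomOp swapping and custom-allreduce handling are omitted and the names are illustrative):

from contextlib import contextmanager

import torch


@contextmanager
def compile_forward(model: torch.nn.Module, enable_compile: bool):
    original_forward = model.forward  # remember the unpatched forward
    try:
        if enable_compile:
            # Hand the caller a compiled, no-grad wrapper around forward,
            # specialized to fixed shapes (dynamic=False).
            yield torch.compile(
                torch.no_grad()(model.forward),
                mode="max-autotune-no-cudagraphs",
                dynamic=False,
            )
        else:
            yield model.forward
    finally:
        # Undo any patching so later eager runs see the original forward.
        model.forward = original_forward

A caller would then run, e.g., "with compile_forward(model, True) as forward: out = forward(batch)", which mirrors how CudaGraphRunner wraps each capture in patch_model, now also passing the batch size so the MoE layer is only swapped to the native forward when bs == 1.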
python/sglang/srt/model_executor/model_runner.py

@@ -622,7 +622,7 @@ class ModelRunner:
         tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
         logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
test/srt/test_srt_engine.py

@@ -188,7 +188,7 @@ class TestSRTEngine(unittest.TestCase):
         )
         bench_args = BenchArgs(num_prompts=10)
         result = throughput_test(server_args=server_args, bench_args=bench_args)
-        self.assertGreater(result["total_throughput"], 3500)
+        self.assertGreater(result["total_throughput"], 3000)


 if __name__ == "__main__":
test/srt/test_torch_compile_moe.py

@@ -14,7 +14,7 @@ from sglang.test.test_utils import (
 )


-class TestTorchCompile(unittest.TestCase):
+class TestTorchCompileMoe(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST

@@ -23,7 +23,7 @@ class TestTorchCompile(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "1"],
+            other_args=["--enable-torch-compile", "--torch-compile-max-bs", "8"],
         )

     @classmethod
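The updated test launches the server with the same flags a user would pass to enable torch.compile for an MoE model, now allowing compiled batch sizes up to 8. A hedged launch sketch (the model path is a placeholder and the sglang.launch_server entry point is assumed; only the two compile flags come directly from the diff above):

import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "Qwen/Qwen1.5-MoE-A2.7B",  # placeholder MoE model
        "--enable-torch-compile",
        "--torch-compile-max-bs", "8",
    ],
    check=True,
)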