sglang · Commits · c3eac1b0

Unverified commit c3eac1b0, authored Nov 14, 2024 by Lianmin Zheng, committed via GitHub on Nov 14, 2024

Fix torch.compile for MoE (#2033)

Parent: b275ce00

Showing 10 changed files with 89 additions and 12 deletions (+89, -12)
| File | Additions | Deletions |
| --- | --- | --- |
| python/sglang/srt/layers/fused_moe/patch.py | +4 | -2 |
| python/sglang/test/test_utils.py | +3 | -2 |
| test/srt/run_suite.py | +1 | -0 |
| test/srt/test_data_parallelism.py | +1 | -1 |
| test/srt/test_double_sparsity.py | +1 | -1 |
| test/srt/test_eval_accuracy_mini.py | +1 | -1 |
| test/srt/test_retract_decode.py | +1 | -1 |
| test/srt/test_torch_compile.py | +3 | -3 |
| test/srt/test_torch_compile_moe.py | +73 | -0 |
| test/srt/test_triton_attention_backend.py | +1 | -1 |
python/sglang/srt/layers/fused_moe/patch.py

```diff
-from typing import Optional
+from typing import Callable, Optional

 import torch
 from torch.nn import functional as F

@@ -98,7 +98,9 @@ def fused_moe_forward_native(
     renormalize: bool,
     topk_group: Optional[int] = None,
     num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
+    assert custom_routing_function is None
     topk_weights, topk_ids = select_experts_native(
         hidden_states=x,
         router_logits=router_logits,
```
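The new `custom_routing_function` keyword keeps the native path's signature in line with callers that may pass one; the `assert` makes explicit that this fallback only implements softmax top-k routing. For orientation, here is a rough sketch of what a softmax top-k selection step computes; the name mirrors the diff, but this is an illustrative stand-in, not sglang's actual `select_experts_native`:

```python
# Hedged sketch of softmax top-k expert selection (illustrative stand-in,
# not sglang's actual select_experts_native implementation).
import torch

def select_experts_native_sketch(
    hidden_states: torch.Tensor,  # [tokens, hidden]; unused here, kept for signature parity
    router_logits: torch.Tensor,  # [tokens, num_experts]
    top_k: int = 2,
    renormalize: bool = True,
):
    # Route in float32 for numerical stability.
    scores = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Note: the weights come out float32, which is why the combine step
    # in the next hunk needs a dtype cast.
    return topk_weights, topk_ids

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
weights, ids = select_experts_native_sketch(None, logits)
print(weights.dtype, ids.shape)  # torch.float32 torch.Size([4, 2])
```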
```diff
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
     x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
-    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
```
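The substantive change here is the added cast: the routing weights are typically produced in float32, while `expert_outs` follows the model's compute dtype, and PyTorch's einsum does not promote across mismatched operand dtypes. A minimal sketch of the failure and the fix, with assumed shapes and bfloat16 as a stand-in model dtype (not taken from the commit):

```python
# Minimal sketch: why the cast to expert_outs.dtype is needed in the
# final combine. Shapes and dtypes are assumptions for illustration.
import torch

t, a, i = 4, 2, 8  # tokens, top-k experts per token, hidden size
expert_outs = torch.randn(t, a, i, dtype=torch.bfloat16)  # model dtype
topk_weights = torch.rand(t, a, dtype=torch.float32)      # float32 routing weights

# torch.einsum("tai,ta -> ti", expert_outs, topk_weights) raises a
# dtype-mismatch RuntimeError; casting the weights fixes it:
out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
print(out.shape, out.dtype)  # torch.Size([4, 8]) torch.bfloat16
```

Since `fused_moe_forward_native` is the MoE path used when torch.compile is enabled, the mismatch surfaced there, which matches the commit title.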
python/sglang/test/test_utils.py

```diff
@@ -28,8 +28,9 @@ from sglang.utils import get_exception_traceback
 DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
-DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
+DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600

@@ -740,7 +741,7 @@ def run_mmlu_test(
     try:
         metrics = run_eval(args)
         print(f"{metrics=}")
-        assert metrics["score"] >= 0.65
+        self.assertGreaterEqual(metrics["score"], 0.65)
     finally:
         pass
```
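The same mechanical change (bare `assert` swapped for a `unittest` assertion) repeats across this file and the test files below. `assertGreaterEqual` reports both operands in the failure message and is not stripped under `python -O`; a toy comparison of the two failure modes (illustrative, not from the commit):

```python
# Toy comparison: both tests fail, but with different diagnostics.
import unittest

class Demo(unittest.TestCase):
    def test_bare_assert(self):
        score = 0.42
        assert score >= 0.65  # bare AssertionError, no values shown

    def test_unittest_assert(self):
        score = 0.42
        # fails with "0.42 not greater than or equal to 0.65"
        self.assertGreaterEqual(score, 0.65)

if __name__ == "__main__":
    unittest.main()
```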
test/srt/run_suite.py

```diff
@@ -27,6 +27,7 @@ suites = {
         "test_srt_engine.py",
         "test_srt_endpoint.py",
         "test_torch_compile.py",
+        "test_torch_compile_moe.py",
         "test_torchao.py",
         "test_triton_attention_kernels.py",
         "test_triton_attention_backend.py",
```
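The hunk context shows these filenames live in the `suites` dict that run_suite.py consumes, so adding `test_torch_compile_moe.py` here is what enrolls the new test in the suite's default runs.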
test/srt/test_data_parallelism.py

```diff
@@ -40,7 +40,7 @@ class TestDataParallelism(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.65
+        self.assertGreaterEqual(metrics["score"], 0.65)

     def test_update_weight(self):
         response = requests.post(
```
test/srt/test_double_sparsity.py

```diff
@@ -55,7 +55,7 @@ class TestDoubleSparsity(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.65
+        self.assertGreaterEqual(metrics["score"], 0.65)


 if __name__ == "__main__":
```
test/srt/test_eval_accuracy_mini.py

```diff
@@ -35,7 +35,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.65
+        self.assertGreaterEqual(metrics["score"], 0.65)


 if __name__ == "__main__":
```
test/srt/test_retract_decode.py

```diff
@@ -34,7 +34,7 @@ class TestRetractDecode(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.65
+        self.assertGreaterEqual(metrics["score"], 0.65)


 if __name__ == "__main__":
```
test/srt/test_torch_compile.py

```diff
@@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.65
+        self.assertGreaterEqual(metrics["score"], 0.65)

     def run_decode(self, max_new_tokens):
         response = requests.post(

@@ -49,8 +49,8 @@ class TestTorchCompile(unittest.TestCase):
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": max_new_tokens,
                     "ignore_eos": True,
                 },
             },
         )
         return response.json()

@@ -66,7 +66,7 @@ class TestTorchCompile(unittest.TestCase):
         print(res["text"])
         throughput = max_tokens / (tok - tic)
         print(f"Throughput: {throughput} tokens/s")
-        assert throughput >= 152
+        self.assertGreaterEqual(throughput, 152)


 if __name__ == "__main__":
```
...
test/srt/test_torch_compile_moe.py
0 → 100644
View file @
c3eac1b0
import
unittest
from
types
import
SimpleNamespace
import
requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
class
TestTorchCompile
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
"--enable-torch-compile"
,
"--torch-compile-max-bs"
,
"1"
],
)
@
classmethod
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
,
include_self
=
True
)
def
test_mmlu
(
self
):
args
=
SimpleNamespace
(
base_url
=
self
.
base_url
,
model
=
self
.
model
,
eval_name
=
"mmlu"
,
num_examples
=
64
,
num_threads
=
32
,
)
metrics
=
run_eval
(
args
)
self
.
assertGreaterEqual
(
metrics
[
"score"
],
0.50
)
def
run_decode
(
self
,
max_new_tokens
):
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
"The capital of France is"
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
max_new_tokens
,
"ignore_eos"
:
True
,
},
},
)
return
response
.
json
()
def
test_throughput
(
self
):
import
time
max_tokens
=
256
tic
=
time
.
time
()
res
=
self
.
run_decode
(
max_tokens
)
tok
=
time
.
time
()
print
(
f
"
{
res
=
}
"
)
throughput
=
max_tokens
/
(
tok
-
tic
)
print
(
f
"Throughput:
{
throughput
}
tokens/s"
)
self
.
assertGreaterEqual
(
throughput
,
290
)
if
__name__
==
"__main__"
:
unittest
.
main
()
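For quick manual checks against a server launched the way this test does, the same `/generate` request can be issued standalone. A hedged sketch, assuming a server is already listening at `http://127.0.0.1:30000` (the actual value of `DEFAULT_URL_FOR_TEST` may differ):

```python
# Standalone version of the /generate call the test issues; the payload
# is taken from the test, the URL is an assumption.
import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32, "ignore_eos": True},
    },
)
print(response.json()["text"])
```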
test/srt/test_triton_attention_backend.py

```diff
@@ -48,7 +48,7 @@ class TestTritonAttnBackend(unittest.TestCase):
             )

             metrics = run_eval(args)
-            assert metrics["score"] >= 0.65
+            self.assertGreaterEqual(metrics["score"], 0.65)
         finally:
             kill_child_process(process.pid, include_self=True)
```