Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
76524b70
"vscode:/vscode.git/clone" did not exist on "1d3b429f40888d935e15608b2c7707f5b028564e"
Unverified
Commit
76524b70
authored
Sep 17, 2024
by
Ke Bao
Committed by
GitHub
Sep 17, 2024
Browse files
Fix torch compile for deepseek-v2 (#1442)
parent
3a6e0418
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
1 deletion
+20
-1
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+12
-1
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+1
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
No files found.
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
76524b70
...
...
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
def
_to_torch
(
model
:
torch
.
nn
.
Module
,
reverse
:
bool
=
False
):
for
sub
in
model
.
_modules
.
values
():
if
isinstance
(
sub
,
CustomOp
):
# NOTE: FusedMoE torch native implementaiton is not efficient
if
"FusedMoE"
in
sub
.
__class__
.
__name__
:
continue
if
reverse
:
sub
.
_forward_method
=
sub
.
forward_cuda
setattr
(
sub
,
"is_torch_compile"
,
False
)
...
...
@@ -105,7 +108,15 @@ class CudaGraphRunner:
self
.
capture_bs
=
list
(
range
(
1
,
32
))
+
[
64
,
128
]
else
:
self
.
capture_bs
=
[
1
,
2
,
4
]
+
[
i
*
8
for
i
in
range
(
1
,
21
)]
self
.
compile_bs
=
[
1
,
2
,
4
,
8
,
16
,
24
,
32
]
if
self
.
use_torch_compile
else
[]
self
.
compile_bs
=
(
[
bs
for
bs
in
self
.
capture_bs
if
bs
<=
self
.
model_runner
.
server_args
.
max_torch_compile_bs
]
if
self
.
use_torch_compile
else
[]
)
# Common inputs
self
.
max_bs
=
max
(
self
.
capture_bs
)
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
76524b70
...
...
@@ -653,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module):
)
self
.
logits_processor
=
LogitsProcessor
(
config
)
@
torch
.
no_grad
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
...
...
python/sglang/srt/server_args.py
View file @
76524b70
...
...
@@ -110,6 +110,7 @@ class ServerArgs:
disable_custom_all_reduce
:
bool
=
False
enable_mixed_chunk
:
bool
=
False
enable_torch_compile
:
bool
=
False
max_torch_compile_bs
:
int
=
32
torchao_config
:
str
=
""
enable_p2p_check
:
bool
=
False
enable_mla
:
bool
=
False
...
...
@@ -523,6 +524,12 @@ class ServerArgs:
action
=
"store_true"
,
help
=
"Optimize the model with torch.compile. Experimental feature."
,
)
parser
.
add_argument
(
"--max-torch-compile-bs"
,
type
=
int
,
default
=
ServerArgs
.
max_torch_compile_bs
,
help
=
"Set the maximum batch size when using torch compile."
,
)
parser
.
add_argument
(
"--torchao-config"
,
type
=
str
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment