change / sglang · Commit e05e29d1 (Unverified)

Refactor CustomOp to avoid confusing bugs (#5382)

Authored Jun 03, 2025 by fzyzcjy; committed by GitHub Jun 02, 2025
Parent: a2cb5913
Showing 2 changed files with 22 additions and 15 deletions:

  python/sglang/srt/custom_op.py                           +20   -3
  python/sglang/srt/model_executor/cuda_graph_runner.py     +2  -12
python/sglang/srt/custom_op.py  (+20, -3)

-from typing import Optional
-
-import torch
 from torch import nn

 from sglang.srt.utils import is_cuda, is_hip
...
@@ -14,6 +11,26 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()

+    def enter_torch_compile(self, num_tokens: int):
+        # NOTE: Temporarily workaround MoE
+        if "FusedMoE" in self.__class__.__name__:
+            if num_tokens == 1:
+                from sglang.srt.layers.moe.fused_moe_native import (
+                    fused_moe_forward_native,
+                )
+
+                # The performance of torch.compile on this layer is not always good when bs > 1,
+                # so we decide to only use torch.compile when bs =1
+                self._forward_method = fused_moe_forward_native
+        else:
+            self._forward_method = self.forward_native
+        self.is_torch_compile = True
+
+    def leave_torch_compile(self):
+        self._forward_method = self.forward_cuda
+        self.is_torch_compile = False
+
+    # Please do not override this method, because `self._forward_method` can change when in torch compile mode
     def forward(self, *args, **kwargs):
         return self._forward_method(*args, **kwargs)
...
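
For context, here is a minimal, hypothetical sketch of the dispatch pattern this file uses. It is not sglang's actual class: the real CustomOp chooses its default via dispatch_forward() across backends, and the MoE special case is omitted. The ScaleOp subclass is invented for illustration. The point it shows is that forward() always delegates to self._forward_method, which enter_torch_compile() / leave_torch_compile() swap at runtime, so a subclass that overrides forward() would silently bypass that switch.

import torch
from torch import nn


# Hypothetical, simplified stand-in for the dispatch pattern above.
class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda  # assumed default for this sketch
        self.is_torch_compile = False

    def enter_torch_compile(self, num_tokens: int):
        # Switch to the torch.compile-friendly native path (MoE workaround omitted).
        self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        self._forward_method = self.forward_cuda
        self.is_torch_compile = False

    # Do not override: a subclass redefining forward() would bypass the swap above.
    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)


class ScaleOp(CustomOp):
    # Illustrative op; both backends compute the same thing here.
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


op = ScaleOp()
x = torch.ones(4)
op(x)                                  # dispatches to forward_cuda by default
op.enter_torch_compile(num_tokens=1)   # now dispatches to forward_native
assert op._forward_method == op.forward_native
op.leave_torch_compile()               # restores forward_cuda
assert op._forward_method == op.forward_cuda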
python/sglang/srt/model_executor/cuda_graph_runner.py  (+2, -12)

@@ -28,7 +28,6 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
@@ -60,18 +59,9 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
-                sub._forward_method = sub.forward_cuda
-                setattr(sub, "is_torch_compile", False)
+                sub.leave_torch_compile()
             else:
-                # NOTE: Temporarily workaround MoE
-                if "FusedMoE" in sub.__class__.__name__:
-                    if num_tokens == 1:
-                        # The performance of torch.compile on this layer is not always good when bs > 1,
-                        # so we decide to only use torch.compile when bs =1
-                        sub._forward_method = fused_moe_forward_native
-                    else:
-                        sub._forward_method = sub.forward_native
-                setattr(sub, "is_torch_compile", True)
+                sub.enter_torch_compile(num_tokens=num_tokens)
         if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, num_tokens)
...
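
To see what the refactor buys the runner, here is a hedged, self-contained sketch of a _to_torch-style recursive walk after the change: each CustomOp-derived module is flipped through the new public methods instead of having its _forward_method and is_torch_compile attributes poked from outside. The CustomOp stand-in, NormLike, to_torch, and the torch.compile call at the bottom are illustrative assumptions, not the runner's actual code path.

import torch
from torch import nn


class CustomOp(nn.Module):
    # Minimal stand-in so the walk below runs; not sglang's actual class.
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda
        self.is_torch_compile = False

    def enter_torch_compile(self, num_tokens: int):
        self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        self._forward_method = self.forward_cuda
        self.is_torch_compile = False

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)


class NormLike(CustomOp):
    # Illustrative op; both paths share one implementation in this sketch.
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x / (x.norm() + 1e-6)

    forward_cuda = forward_native


def to_torch(model: nn.Module, reverse: bool, num_tokens: int) -> None:
    # Same shape as the refactored _to_torch: recurse over submodules and flip
    # CustomOp instances only through enter_torch_compile / leave_torch_compile.
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            if reverse:
                sub.leave_torch_compile()
            else:
                sub.enter_torch_compile(num_tokens=num_tokens)
        if isinstance(sub, nn.Module):
            to_torch(sub, reverse, num_tokens)


model = nn.Sequential(nn.Linear(8, 8), NormLike())
to_torch(model, reverse=False, num_tokens=1)  # native forwards for compilation
compiled = torch.compile(model)               # illustrative only; not how the runner invokes it
to_torch(model, reverse=True, num_tokens=1)   # restore the CUDA forwards afterwards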