Unverified commit e05e29d1, authored Jun 03, 2025 by fzyzcjy, committed by GitHub on Jun 02, 2025

Refactor CustomOp to avoid confusing bugs (#5382)

Parent: a2cb5913
Showing 2 changed files with 22 additions and 15 deletions (+22 / -15)
python/sglang/srt/custom_op.py (+20 / -3)
python/sglang/srt/model_executor/cuda_graph_runner.py (+2 / -12)
python/sglang/srt/custom_op.py

@@ -1,6 +1,3 @@
-from typing import Optional
-
-import torch
 from torch import nn
 
 from sglang.srt.utils import is_cuda, is_hip
...
@@ -14,6 +11,26 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
+    def enter_torch_compile(self, num_tokens: int):
+        # NOTE: Temporarily workaround MoE
+        if "FusedMoE" in self.__class__.__name__:
+            if num_tokens == 1:
+                from sglang.srt.layers.moe.fused_moe_native import (
+                    fused_moe_forward_native,
+                )
+
+                # The performance of torch.compile on this layer is not always good when bs > 1,
+                # so we decide to only use torch.compile when bs =1
+                self._forward_method = fused_moe_forward_native
+        else:
+            self._forward_method = self.forward_native
+        self.is_torch_compile = True
+
+    def leave_torch_compile(self):
+        self._forward_method = self.forward_cuda
+        self.is_torch_compile = False
+
+    # Please do not override this method, because `self._forward_method` can change when in torch compile mode
     def forward(self, *args, **kwargs):
         return self._forward_method(*args, **kwargs)
...
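For orientation (not part of the commit), here is a minimal sketch of how a CustomOp subclass interacts with the hooks added above. Only CustomOp, dispatch_forward, forward_native, forward_cuda, enter_torch_compile, and leave_torch_compile come from the code in this diff; the subclass name and its method bodies are made up, and the snippet assumes sglang is installed.

# Hypothetical sketch; only the CustomOp API shown in the diff above is real.
import torch
from sglang.srt.custom_op import CustomOp


class SiluAndMulDemo(CustomOp):  # made-up subclass for illustration
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch path, the one torch.compile traces.
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused kernel here; reuse the native path in this sketch.
        return self.forward_native(x)


op = SiluAndMulDemo()                  # __init__ sets _forward_method via dispatch_forward()
x = torch.randn(4, 8)

op.enter_torch_compile(num_tokens=1)   # not a FusedMoE, so forward() now routes to forward_native
y_compiled_path = op(x)
op.leave_torch_compile()               # forward() routes back to forward_cuda
y_eager_path = op(x)

Because forward() simply delegates to self._forward_method, subclasses should not override forward() itself; the new comment in the diff makes that contract explicit.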
python/sglang/srt/model_executor/cuda_graph_runner.py

...
@@ -28,7 +28,6 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
...
@@ -60,18 +59,9 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
-                sub._forward_method = sub.forward_cuda
-                setattr(sub, "is_torch_compile", False)
+                sub.leave_torch_compile()
             else:
-                # NOTE: Temporarily workaround MoE
-                if "FusedMoE" in sub.__class__.__name__:
-                    if num_tokens == 1:
-                        # The performance of torch.compile on this layer is not always good when bs > 1,
-                        # so we decide to only use torch.compile when bs =1
-                        sub._forward_method = fused_moe_forward_native
-                else:
-                    sub._forward_method = sub.forward_native
-                setattr(sub, "is_torch_compile", True)
+                sub.enter_torch_compile(num_tokens=num_tokens)
         if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, num_tokens)
...
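After this change, the MoE-specific torch.compile policy lives in CustomOp.enter_torch_compile instead of being spelled out inline in _to_torch, so the graph runner only toggles each op. The actual call sites of _to_torch are not shown in this diff; the context manager below is a hypothetical sketch of the usual enter/leave pairing, not sglang's real wrapper.

# Hypothetical pairing of the _to_torch toggle; the context-manager name is made up.
from contextlib import contextmanager

import torch

from sglang.srt.model_executor.cuda_graph_runner import _to_torch


@contextmanager
def compile_friendly_ops(model: torch.nn.Module, num_tokens: int):
    # Switch every CustomOp submodule to its torch.compile-friendly path ...
    _to_torch(model, reverse=False, num_tokens=num_tokens)   # calls sub.enter_torch_compile(...)
    try:
        yield model
    finally:
        # ... and restore the CUDA paths afterwards.
        _to_torch(model, reverse=True, num_tokens=num_tokens)  # calls sub.leave_torch_compile()

A caller would run whatever torch.compile needs to trace inside the with block, so that the native forward paths are the ones captured.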