sglang / Commits / 180ff5ee

Unverified commit 180ff5ee, authored Jun 04, 2025 by JieXin Liang, committed via GitHub Jun 03, 2025
[fix] recover auto-dispatch for rmsnorm and rope (#6745)
Parent: 37f15475
Showing 3 changed files with 28 additions and 30 deletions (+28, -30)
python/sglang/srt/custom_op.py (+19, -1)
python/sglang/srt/layers/layernorm.py (+3, -17)
python/sglang/srt/layers/rotary_embedding.py (+6, -12)
python/sglang/srt/custom_op.py

@@ -11,7 +11,20 @@ class CustomOp(nn.Module):
         super().__init__()
         self._forward_method = self.dispatch_forward()
+        # States for torch.compile
+        self._original_forward_method = None
+        self.is_torch_compile = False

     def enter_torch_compile(self, num_tokens: int):
+        # Skip if Op is already entered compile mode.
+        # NOTE(alcanderian): Some Ops(for example RotaryEmbedding) will be reused
+        # among layers and `enter_torch_compile` will be called many times.
+        # We should prevent `self._original_forward_method` from being overridden when
+        # it is not the first time `enter_torch_compile` called.
+        if self.is_torch_compile:
+            return
+
+        self._original_forward_method = self._forward_method
+
         # NOTE: Temporarily workaround MoE
         if "FusedMoE" in self.__class__.__name__:
             if num_tokens == 1:
@@ -27,7 +40,12 @@ class CustomOp(nn.Module):
         self.is_torch_compile = True

     def leave_torch_compile(self):
-        self._forward_method = self.forward_cuda
+        # Skip if Op is already exited compile mode.
+        if not self.is_torch_compile:
+            return
+
+        self._forward_method = self._original_forward_method
+        self._original_forward_method = None
         self.is_torch_compile = False

     # Please do not override this method, because `self._forward_method` can change when in torch compile mode
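The core of the fix is visible in leave_torch_compile: instead of unconditionally re-binding _forward_method to forward_cuda (which broke auto-dispatch on non-CUDA backends), the op now saves the dispatched method when entering compile mode and restores it on exit. Below is a minimal, self-contained sketch of that pattern; class and method bodies are hypothetical stand-ins, not the sglang source.

import torch
import torch.nn as nn


class TinyCustomOp(nn.Module):
    """Minimal stand-in for sglang's CustomOp dispatch pattern (illustrative only)."""

    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()
        # States for torch.compile
        self._original_forward_method = None
        self.is_torch_compile = False

    def dispatch_forward(self):
        # Pick a backend implementation once, at construction time.
        return self.forward_cuda if torch.cuda.is_available() else self.forward_native

    def enter_torch_compile(self, num_tokens: int):
        if self.is_torch_compile:
            return  # may be called many times for ops shared across layers
        # Remember whatever dispatch_forward() chose before switching to the
        # torch.compile-friendly native path.
        self._original_forward_method = self._forward_method
        self._forward_method = self.forward_native
        self.is_torch_compile = True

    def leave_torch_compile(self):
        if not self.is_torch_compile:
            return
        # Restore the originally dispatched backend (CUDA, HIP, CPU, ...),
        # rather than forcing forward_cuda as the old code did.
        self._forward_method = self._original_forward_method
        self._original_forward_method = None
        self.is_torch_compile = False

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x

With this round trip, an op constructed on a non-CUDA backend keeps its originally dispatched implementation after the compile region exits, which is what the commit title means by recovering auto-dispatch.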
python/sglang/srt/layers/layernorm.py

@@ -49,16 +49,6 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        elif _is_hip:
-            return self.forward_hip(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_cuda(
         self,
         x: torch.Tensor,
@@ -117,13 +107,9 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps

-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native

     def forward_native(
         self,
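With the hand-written forward overrides removed, backend selection for RMSNorm and GemmaRMSNorm is left to the CustomOp base class again, and callers are unaffected. A hedged usage sketch, assuming sglang and a recent PyTorch are installed (shapes arbitrary):

import torch
from sglang.srt.layers.layernorm import RMSNorm

norm = RMSNorm(hidden_size=128, eps=1e-6)
x = torch.randn(4, 128)
# nn.Module.__call__ -> CustomOp.forward -> self._forward_method(x),
# where _forward_method was chosen by dispatch_forward() at construction
# (and re-pointed to forward_native on HIP in GemmaRMSNorm's __init__).
out = norm(x)
print(out.shape)  # torch.Size([4, 128])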
python/sglang/srt/layers/rotary_embedding.py

@@ -8,9 +8,10 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -609,6 +610,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )

+        # Re-dispatch
+        if _is_hip:
+            self._forward_method = self.forward_native
+
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
             torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
@@ -650,17 +655,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

-    def forward_hip(self, *args, **kwargs):
-        return self.forward_native(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        if torch.compiler.is_compiling():
-            return self.forward_native(*args, **kwargs)
-        if _is_cuda:
-            return self.forward_cuda(*args, **kwargs)
-        else:
-            return self.forward_native(*args, **kwargs)
-
     def forward_native(
         self,
         positions: torch.Tensor,
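For DeepseekScalingRotaryEmbedding the commit replaces the forward()/forward_hip() overrides with a one-time "re-dispatch" in the constructor: on HIP the op simply points its dispatched method at forward_native. A small self-contained sketch of that pattern with hypothetical names (RotaryLike and _detect_hip are not sglang identifiers):

import torch
import torch.nn as nn


def _detect_hip() -> bool:
    # Rough stand-in for sglang.srt.utils.is_hip(): ROCm builds of PyTorch
    # expose a non-None torch.version.hip.
    return torch.version.hip is not None


_is_hip = _detect_hip()


class RotaryLike(nn.Module):
    """Illustrative op that should use the pure-PyTorch path on HIP."""

    def __init__(self):
        super().__init__()
        # Normal auto-dispatch: prefer the CUDA kernel when one is available.
        self._forward_method = (
            self.forward_cuda if torch.cuda.is_available() else self.forward_native
        )
        # Re-dispatch: on HIP, fall back to the native implementation in the
        # constructor instead of overriding forward()/forward_hip().
        if _is_hip:
            self._forward_method = self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_method(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x

Keeping the decision in __init__ means the base class's single dispatched forward stays the only call path, so the enter/leave torch.compile bookkeeping from custom_op.py applies uniformly to this op as well.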