sglang / Commits / 729b7edf

Commit 729b7edf (Unverified)
enable rmsnorm on XPU (#10248)
Authored Oct 16, 2025 by Huaiyu, Zheng; committed by GitHub on Oct 15, 2025
Parent: 4c03dbaa
Showing 2 changed files with 54 additions and 17 deletions (+54, -17):

  python/sglang/bench_one_batch.py        +13  -8
  python/sglang/srt/layers/layernorm.py   +41  -9
python/sglang/bench_one_batch.py
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
+    is_cuda_alike,
+    is_xpu,
     kill_process_tree,
     require_mlp_sync,
     require_mlp_tp_gather,
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
 )
 from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 
+profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
+    profiler_activity
+    for available, profiler_activity in [
+        (is_cuda_alike(), torch.profiler.ProfilerActivity.CUDA),
+        (is_xpu(), torch.profiler.ProfilerActivity.XPU),
+    ]
+    if available
+]
+
 
 @dataclasses.dataclass
 class BenchArgs:
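Note on the hunk above: profile_activities is computed once at import time, keeping ProfilerActivity.CPU unconditionally and appending the CUDA or XPU activity only when the corresponding backend is present. A minimal standalone sketch of the same pattern, with stand-in availability checks in place of sglang's is_cuda_alike()/is_xpu() helpers (assumptions, and ProfilerActivity.XPU requires a reasonably recent PyTorch):

import torch

# Stand-ins for sglang.srt.utils.is_cuda_alike() / is_xpu() (assumptions, not the real helpers).
has_cuda = torch.cuda.is_available()
has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()

# CPU profiling is always enabled; device activities are added only when available.
profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
    activity
    for available, activity in [
        (has_cuda, torch.profiler.ProfilerActivity.CUDA),
        (has_xpu, torch.profiler.ProfilerActivity.XPU),
    ]
    if available
]

print(profile_activities)  # e.g. [CPU, CUDA] on NVIDIA/AMD, [CPU, XPU] on Intel GPUs, [CPU] otherwise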
@@ -424,10 +435,7 @@ def latency_test_run_once(
     profiler = None
     if profile:
         profiler = torch.profiler.profile(
-            activities=[
-                torch.profiler.ProfilerActivity.CPU,
-                torch.profiler.ProfilerActivity.CUDA,
-            ],
+            activities=profile_activities,
             with_stack=True,
             record_shapes=profile_record_shapes,
         )
@@ -460,10 +468,7 @@ def latency_test_run_once(
         if profile and i == output_len / 2:
             profiler = None
             profiler = torch.profiler.profile(
-                activities=[
-                    torch.profiler.ProfilerActivity.CPU,
-                    torch.profiler.ProfilerActivity.CUDA,
-                ],
+                activities=profile_activities,
                 with_stack=True,
                 record_shapes=profile_record_shapes,
             )
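Both profiling call sites now pass the shared profile_activities list instead of a hard-coded CPU+CUDA pair, so the same benchmark profiles correctly on XPU. For reference, a hedged sketch of how such a list feeds torch.profiler (a toy workload stands in for the benchmark's prefill/decode loop):

import torch

def profile_toy_workload(profile_activities, record_shapes=True):
    # Any activity list built as above works here unchanged, whatever devices are present.
    x = torch.randn(256, 256)
    with torch.profiler.profile(
        activities=profile_activities,
        with_stack=True,
        record_shapes=record_shapes,
    ) as prof:
        for _ in range(4):
            x = x @ x
    # Summarize the hottest ops regardless of which backend produced them.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))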
python/sglang/srt/layers/layernorm.py
@@ -52,8 +52,13 @@ if _is_cuda:
         gemma_rmsnorm,
         rmsnorm,
     )
+elif _is_xpu:
+    from sgl_kernel import (
+        fused_add_rmsnorm,
+        gemma_fused_add_rmsnorm,
+        gemma_rmsnorm,
+        rmsnorm,
+    )
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
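The elif branch keeps the sgl_kernel import behind the _is_xpu flag, so CUDA and ROCm builds are unaffected and the XPU build only has to ship the same four rmsnorm entry points. The real flag comes from sglang.srt.utils.is_xpu(); a hedged sketch of what such a guard typically looks like (the actual helper may differ):

import torch

def is_xpu() -> bool:
    # Assumed shape of the platform check: true only when this PyTorch build
    # has the Intel XPU backend and at least one device is visible.
    return hasattr(torch, "xpu") and torch.xpu.is_available()

_is_xpu = is_xpu()

if _is_xpu:
    # The XPU build of sgl_kernel exposes the same rmsnorm entry points as the
    # CUDA build, which is what lets the forward_xpu methods below call them directly.
    from sgl_kernel import fused_add_rmsnorm, rmsnorm  # noqa: F401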
@@ -216,6 +221,19 @@ class RMSNorm(CustomOp):
         else:
             return self.forward_native(x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_with_allreduce_fusion(
         self,
         x: torch.Tensor,
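forward_xpu mirrors the CUDA path: the plain rmsnorm kernel handles the no-residual case, fused_add_rmsnorm folds the residual into x in place and returns both the normalized output and the updated residual, and variance_size_override (partial-dimension normalization) still falls back to forward_native. A pure-PyTorch reference of the contract the XPU kernels are expected to match, based on the native fallback in this file (a sketch, not the kernel itself):

from typing import Optional, Tuple, Union

import torch

def reference_rmsnorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        # Fused-add variant: add the residual before normalizing and hand the
        # pre-norm sum back as the new residual.
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = (x * torch.rsqrt(variance + eps) * weight.float()).to(orig_dtype)
    return out if residual is None else (out, residual)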
@@ -263,6 +281,19 @@ class GemmaRMSNorm(CustomOp):
         if _is_hip:
             self._forward_method = self.forward_native
 
+    def _forward_impl(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
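Factoring the kernel calls into _forward_impl lets forward_cuda (below) and the new forward_xpu share one code path. The dedicated Gemma kernels exist because GemmaRMSNorm stores its weight as an offset from one, so the scale applied after normalization is (1 + weight) rather than weight, matching the native path in this file. A short reference sketch of that variant:

import torch

def reference_gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma keeps weight as an offset from 1, so the applied scale is (1 + weight).
    return (x * (1.0 + weight.float())).to(orig_dtype)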
@@ -285,13 +316,7 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is not None:
-            gemma_fused_add_rmsnorm(
-                x, residual, self.weight.data, self.variance_epsilon
-            )
-            return x, residual
-        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-        return out
+        return self._forward_impl(x, residual)
 
     def forward_npu(
         self,
@@ -305,6 +330,13 @@ class GemmaRMSNorm(CustomOp):
         x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        return self._forward_impl(x, residual)
+
 
 class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
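With forward_xpu defined on both norm classes, the CustomOp dispatch can route layernorm calls on Intel GPUs to the sgl_kernel implementations instead of the eager fallback. A hedged smoke-test sketch (assumes an XPU-enabled PyTorch and the XPU build of sgl_kernel; the tolerances and exact dispatch mechanism are assumptions):

import torch
from sglang.srt.layers.layernorm import RMSNorm

def smoke_test_xpu_rmsnorm(hidden_size: int = 4096) -> None:
    device = "xpu"  # requires torch.xpu.is_available()
    norm = RMSNorm(hidden_size).to(device=device, dtype=torch.float16)
    x = torch.randn(8, hidden_size, dtype=torch.float16, device=device)
    residual = torch.randn_like(x)

    # Fused-add path; clone because the kernel updates x and residual in place.
    out, new_residual = norm(x.clone(), residual.clone())

    # Compare against the eager fallback to sanity-check the XPU kernels.
    ref_out, ref_residual = norm.forward_native(x.clone(), residual.clone())
    torch.testing.assert_close(out, ref_out, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(new_residual, ref_residual, rtol=2e-2, atol=2e-2)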