Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
729b7edf
"vscode:/vscode.git/clone" did not exist on "83d55ac51fbc4b29b666223c87f650b8ffd7b38c"
Unverified
Commit
729b7edf
authored
Oct 16, 2025
by
Huaiyu, Zheng
Committed by
GitHub
Oct 15, 2025
Browse files
enable rmsnorm on XPU (#10248)
parent
4c03dbaa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
17 deletions
+54
-17
python/sglang/bench_one_batch.py
python/sglang/bench_one_batch.py
+13
-8
python/sglang/srt/layers/layernorm.py
python/sglang/srt/layers/layernorm.py
+41
-9
No files found.
python/sglang/bench_one_batch.py
View file @
729b7edf
...
...
@@ -72,6 +72,8 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from
sglang.srt.utils
import
(
configure_logger
,
get_bool_env_var
,
is_cuda_alike
,
is_xpu
,
kill_process_tree
,
require_mlp_sync
,
require_mlp_tp_gather
,
...
...
@@ -80,6 +82,15 @@ from sglang.srt.utils import (
)
from
sglang.srt.utils.hf_transformers_utils
import
get_tokenizer
profile_activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
]
+
[
profiler_activity
for
available
,
profiler_activity
in
[
(
is_cuda_alike
(),
torch
.
profiler
.
ProfilerActivity
.
CUDA
),
(
is_xpu
(),
torch
.
profiler
.
ProfilerActivity
.
XPU
),
]
if
available
]
@
dataclasses
.
dataclass
class
BenchArgs
:
...
...
@@ -424,10 +435,7 @@ def latency_test_run_once(
profiler
=
None
if
profile
:
profiler
=
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
,
torch
.
profiler
.
ProfilerActivity
.
CUDA
,
],
activities
=
profile_activities
,
with_stack
=
True
,
record_shapes
=
profile_record_shapes
,
)
...
...
@@ -460,10 +468,7 @@ def latency_test_run_once(
if
profile
and
i
==
output_len
/
2
:
profiler
=
None
profiler
=
torch
.
profiler
.
profile
(
activities
=
[
torch
.
profiler
.
ProfilerActivity
.
CPU
,
torch
.
profiler
.
ProfilerActivity
.
CUDA
,
],
activities
=
profile_activities
,
with_stack
=
True
,
record_shapes
=
profile_record_shapes
,
)
...
...
python/sglang/srt/layers/layernorm.py
View file @
729b7edf
...
...
@@ -52,8 +52,13 @@ if _is_cuda:
gemma_rmsnorm
,
rmsnorm
,
)
elif
_is_xpu
:
from
sgl_kernel
import
(
fused_add_rmsnorm
,
gemma_fused_add_rmsnorm
,
gemma_rmsnorm
,
rmsnorm
,
)
if
_use_aiter
:
from
aiter
import
rmsnorm2d_fwd
as
rms_norm
from
aiter
import
rmsnorm2d_fwd_with_add
as
fused_add_rms_norm
...
...
@@ -216,6 +221,19 @@ class RMSNorm(CustomOp):
else
:
return
self
.
forward_native
(
x
,
residual
)
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
self
.
variance_size_override
is
not
None
:
return
self
.
forward_native
(
x
,
residual
)
if
residual
is
not
None
:
fused_add_rmsnorm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
x
,
residual
out
=
rmsnorm
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
out
def
forward_with_allreduce_fusion
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -263,6 +281,19 @@ class GemmaRMSNorm(CustomOp):
if
_is_hip
:
self
.
_forward_method
=
self
.
forward_native
def
_forward_impl
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
residual
is
not
None
:
gemma_fused_add_rmsnorm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
x
,
residual
out
=
gemma_rmsnorm
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
out
def
forward_native
(
self
,
x
:
torch
.
Tensor
,
...
...
@@ -285,13 +316,7 @@ class GemmaRMSNorm(CustomOp):
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
residual
is
not
None
:
gemma_fused_add_rmsnorm
(
x
,
residual
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
x
,
residual
out
=
gemma_rmsnorm
(
x
,
self
.
weight
.
data
,
self
.
variance_epsilon
)
return
out
return
self
.
_forward_impl
(
x
,
residual
)
def
forward_npu
(
self
,
...
...
@@ -305,6 +330,13 @@ class GemmaRMSNorm(CustomOp):
x
,
_
=
torch_npu
.
npu_gemma_rms_norm
(
x
,
self
.
weight
,
self
.
variance_epsilon
)
return
x
if
residual
is
None
else
(
x
,
residual
)
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
return
self
.
_forward_impl
(
x
,
residual
)
class
Gemma3RMSNorm
(
CustomOp
):
def
__init__
(
self
,
dim
:
int
,
eps
:
float
=
1e-6
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment