change / sglang / Commits / 506a2d59

Commit 506a2d59 (Unverified)
Authored Jun 25, 2025 by ll819214; committed by GitHub on Jun 25, 2025
Parent: a07f8ae4

npu fused op (#7386)

Co-authored-by: Li Junwen <lijunwen13@hisilicon.com>

Showing 4 changed files with 70 additions and 2 deletions (+70 / -2):

  python/sglang/srt/custom_op.py                 +7   -1
  python/sglang/srt/layers/activation.py         +7   -0
  python/sglang/srt/layers/layernorm.py          +15  -0
  python/sglang/srt/layers/rotary_embedding.py   +41  -1
python/sglang/srt/custom_op.py

 from torch import nn

-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_npu = is_npu()


 class CustomOp(nn.Module):
@@ -60,6 +61,9 @@ class CustomOp(nn.Module):
     def forward_cuda(self, *args, **kwargs):
         raise NotImplementedError

+    def forward_npu(self, *args, **kwargs):
+        raise NotImplementedError
+
     def forward_hip(self, *args, **kwargs):
         return self.forward_cuda(*args, **kwargs)
@@ -79,5 +83,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif _is_cpu and _is_cpu_amx_available:
             return self.forward_cpu
+        elif _is_npu:
+            return self.forward_npu
         else:
             return self.forward_native
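The net effect is that CustomOp now recognizes a fourth backend: on Ascend hardware the dispatcher resolves to forward_npu instead of falling through to forward_native. A minimal, standalone sketch of the same resolve-once dispatch pattern (illustrative only, not sglang's actual class; the MyFusedOp name and the relu body are placeholders):

import torch
from torch import nn


class MyFusedOp(nn.Module):
    def __init__(self, use_npu: bool):
        super().__init__()
        # Resolve the backend implementation once, at construction time,
        # instead of branching on every forward call.
        self._forward = self.forward_npu if use_npu else self.forward_native

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
        # On Ascend, this is where a torch_npu fused kernel would be called.
        return torch.relu(x)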
python/sglang/srt/layers/activation.py

@@ -48,6 +48,9 @@ if _is_cuda:
 logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
         else:
             return self.forward_native(x)

+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        out = torch_npu.npu_swiglu(x)
+        return out
+

 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
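For reference, torch_npu.npu_swiglu fuses the SiLU-and-multiply gating that SiluAndMul otherwise computes in separate steps. A sketch of the unfused computation, assuming it matches the native path (silu_and_mul_reference is a hypothetical helper for illustration, not part of sglang):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, apply SiLU to the gate half,
    # and multiply by the other half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 8)
print(silu_and_mul_reference(x).shape)  # torch.Size([4, 4])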
python/sglang/srt/layers/layernorm.py

@@ -52,6 +52,9 @@ elif _is_hip:
 logger = logging.getLogger(__name__)

+if is_npu():
+    import torch_npu
+

 class RMSNorm(CustomOp):
     def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            out, _, residual_out = torch_npu.npu_add_rms_norm(
+                residual, x, self.weight.data, self.variance_epsilon
+            )
+            return out, residual_out
+        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
+
     def forward_aiter(
         self,
         x: torch.Tensor,
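The two Ascend kernels used here fuse what the native path computes eagerly: npu_rms_norm is plain RMSNorm, and npu_add_rms_norm folds the residual add into the same kernel. A reference sketch of the assumed semantics (rms_norm_reference is illustrative, not part of sglang):

from typing import Optional, Tuple, Union

import torch


def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if residual is not None:
        x = x + residual          # this add is fused into npu_add_rms_norm on Ascend
        residual_out = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    if residual is not None:
        return out, residual_out
    return out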
python/sglang/srt/layers/rotary_embedding.py

@@ -8,7 +8,14 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
@@ -19,6 +26,9 @@ _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

+if is_npu():
+    import torch_npu
+

 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -152,6 +162,36 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-npu implementation of forward()."""
+        import os
+
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half"
+            if self.is_neox_style:
+                rotary_mode = "half"
+            else:
+                rotary_mode = "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
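The rotary_mode switch mirrors the usual distinction between NeoX-style ("half") and GPT-J-style ("interleave") rotary embeddings. A reference sketch of the two rotations, assumed to correspond to what npu_mrope selects via rotary_mode (the helper names are hypothetical, not torch_npu's API):

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # NeoX style ("half"): rotate the two halves of the head dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_interleave(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style ("interleave"): rotate even/odd element pairs in place.
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)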