sglang · Commit f453578b
Authored Nov 10, 2025 by renzhc

Add an environment variable SGLANG_USE_LIGHTOP to gate lightop's fused rotary_emb and moe_gated operators (disabled by default); fix incorrect logic in RMSNorm's forward_hip.

Parent: a5156371
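For context (not part of the commit): the flag is a plain boolean environment variable that defaults to off, so it has to be set before sglang reads it, for example:

import os

# SGLANG_USE_LIGHTOP defaults to disabled; opt in explicitly before importing sglang.
os.environ["SGLANG_USE_LIGHTOP"] = "1"

from sglang.srt.utils import get_bool_env_var  # the same helper the diffs below use

assert get_bool_env_var("SGLANG_USE_LIGHTOP")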
Showing 4 changed files with 146 additions and 3 deletions.
python/sglang/srt/environ.py                  +3   -0
python/sglang/srt/layers/layernorm.py         +2   -2
python/sglang/srt/layers/moe/topk.py          +17  -0
python/sglang/srt/layers/rotary_embedding.py  +124 -1
python/sglang/srt/environ.py

@@ -163,6 +163,9 @@ class Envs:
     SGLANG_USE_AITER = EnvBool(False)
     SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
     SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
+    # DCU Lightop
+    SGLANG_USE_LIGHTOP = EnvBool(False)
+
     # Quantization
     SGLANG_INT4_WEIGHT = EnvBool(False)
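For context, an EnvBool(False) entry amounts to ordinary default-off environment parsing. A minimal self-contained sketch, not sglang's actual EnvBool implementation:

import os

def env_bool(name: str, default: bool = False) -> bool:
    # Stand-in for an EnvBool(False)-style entry: unset means the default,
    # and common truthy spellings enable it.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")

USE_LIGHTOP = env_bool("SGLANG_USE_LIGHTOP")  # False unless explicitly enabled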
python/sglang/srt/layers/layernorm.py

@@ -167,8 +167,6 @@ class RMSNorm(CustomOp):
         if residual is not None:
             try:
-                output = torch.empty_like(x)
-                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     x,
                     residual,
@@ -177,6 +175,8 @@ class RMSNorm(CustomOp):
                 )
                 return x, residual
             except TypeError:
+                output = torch.empty_like(x)
+                residual_out = torch.empty_like(x)
                 fused_add_rms_norm(
                     output,
                     x,
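The in-place fused_add_rms_norm(x, residual, ...) call in the try block never uses output or residual_out, so allocating them before it was unnecessary; the buffers are only needed when that call raises TypeError and the out-of-place fallback signature is used. A self-contained sketch of the pattern with stand-in kernels (names illustrative, not sglang's):

import torch

def fused_add_rms_norm(x, residual, weight, eps):
    # Stand-in for the preferred in-place kernel: updates residual and x directly.
    residual.add_(x)
    var = residual.pow(2).mean(dim=-1, keepdim=True)
    x.copy_(residual * torch.rsqrt(var + eps) * weight)

def fused_add_rms_norm_out(output, x, residual_out, residual, weight, eps):
    # Stand-in for the fallback signature that writes into explicit output buffers.
    residual_out.copy_(residual + x)
    var = residual_out.pow(2).mean(dim=-1, keepdim=True)
    output.copy_(residual_out * torch.rsqrt(var + eps) * weight)

def add_rms_norm(x, residual, weight, eps=1e-6):
    try:
        fused_add_rms_norm(x, residual, weight, eps)
        return x, residual
    except TypeError:
        # Only the fallback path needs scratch buffers; this is where the commit
        # moves the two empty_like allocations.
        output = torch.empty_like(x)
        residual_out = torch.empty_like(x)
        fused_add_rms_norm_out(output, x, residual_out, residual, weight, eps)
        return output, residual_out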
python/sglang/srt/layers/moe/topk.py

@@ -28,6 +28,8 @@ from typing import (
     runtime_checkable,
 )
 
+from numpy import dtype
+
 import torch
 import torch.nn.functional as F
@@ -68,6 +70,7 @@ _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -79,6 +82,8 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+if _use_lightop:
+    from lightop import op as op
 if _is_npu:
     import torch_npu
@@ -725,6 +730,18 @@ def biased_grouped_topk_gpu(
             routed_scaling_factor,
         )
         return topk_weights, topk_ids
+    elif _use_lightop:
+        assert not apply_routed_scaling_factor_on_output, "Not implemented"
+        topk_weights, topk_ids = op.moe_fused_gate(
+            gating_output.to(dtype=torch.float32),  # or bfloat16
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+            0,  # 0 in vllm
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         return biased_grouped_topk_impl(
             hidden_states,
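For orientation, a rough unfused PyTorch reference (not part of the commit) of the biased grouped top-k routing that a fused moe_fused_gate-style kernel is assumed to implement; parameter names follow the diff above, and lightop's exact numerics may differ:

import torch

def biased_grouped_topk_reference(
    gating_output: torch.Tensor,    # [num_tokens, num_experts] router logits
    correction_bias: torch.Tensor,  # [num_experts], used only for expert selection
    num_expert_group: int,
    topk_group: int,
    topk: int,
    routed_scaling_factor: float = 1.0,
):
    num_tokens, num_experts = gating_output.shape
    scores = gating_output.float().sigmoid()
    scores_for_choice = scores + correction_bias
    # Score each expert group by the sum of its two best biased scores.
    grouped = scores_for_choice.view(num_tokens, num_expert_group, -1)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    # Keep only the best topk_group groups and mask out the remaining experts.
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = scores_for_choice.masked_fill(expert_mask == 0, float("-inf"))
    # Expert choice uses biased scores; the returned weights use unbiased scores.
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_ids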
python/sglang/srt/layers/rotary_embedding.py

@@ -22,6 +22,8 @@ from sglang.srt.utils import (
     is_xpu,
 )
+from sglang.srt.utils import direct_register_custom_op
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
@@ -29,6 +31,7 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_xpu = is_xpu()
+_use_lightop = get_bool_env_var("SGLANG_USE_LIGHTOP")
 
 if _is_cuda:
     from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
@@ -57,6 +60,34 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     x = torch.stack((-x2, x1), dim=-1)
     return x.flatten(-2)
 
+# for dcu
+@triton.jit
+def deepseek_scaling_rotary_emb_kernel_gptj(
+    cos_sin, q,
+    stride1: int, stride2: int, stride_cs: int,
+    dim1: int, dim2: int, dim3: int,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid0 = tl.program_id(0)
+    pid1 = tl.program_id(1)
+    pid2 = tl.program_id(2)
+    offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
+    offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2
+    offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
+    mask = offsets_cs < dim3
+    mask2 = offsets_q < dim3 * 2
+    v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
+    v_cos2 = tl.interleave(v_cos, v_cos)
+    v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
+    v_sin2 = tl.interleave(v_sin, v_sin)
+    x12 = tl.load(q + offsets, mask=mask2)
+    x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
+    # we are both reading and writing 'q'; make sure all warps are in sync
+    tl.debug_barrier()
+    x12_ = tl.ravel(tl.join(-x2, x1))
+    x12 = x12 * v_cos2 + x12_ * v_sin2
+    tl.store(q + offsets, x12, mask=mask2)
+
+
 def _apply_rotary_emb(
     x: torch.Tensor,
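To make the kernel easier to check, here is a plain PyTorch reference (not part of the commit) for what one program instance computes on the first rotary_dim elements of a (token, head) slice under the GPT-J interleaved layout:

import torch

def _rotate_gptj_ref(x: torch.Tensor) -> torch.Tensor:
    # Same as the _rotate_gptj helper above: pairwise (x1, x2) -> (-x2, x1).
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def rotary_gptj_reference(q_rot: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q_rot: [..., rotary_dim] with interleaved pairs; cos/sin: [..., rotary_dim // 2].
    cos2 = cos.repeat_interleave(2, dim=-1)  # mirrors tl.interleave(v_cos, v_cos)
    sin2 = sin.repeat_interleave(2, dim=-1)  # mirrors tl.interleave(v_sin, v_sin)
    return q_rot * cos2 + _rotate_gptj_ref(q_rot) * sin2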
@@ -736,7 +767,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         # Re-dispatch
         if _is_hip:
-            self._forward_method = self.forward_native
+            if _use_lightop:
+                self._forward_method = self.forward_dcu
+            else:
+                self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
@@ -778,6 +812,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         sin = freqs.sin() * self.mscale
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+    def rotary_embedding_deepseek_fuse(
+        positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
+        head_size: int, cos_sin_cache: torch.Tensor, is_neox_style: bool,
+    ) -> None:
+        from lightop import op
+        op.rotary_embedding_deepseek_fuse(
+            positions, query, key, head_size, cos_sin_cache, is_neox_style
+        )
+
+    def rotary_embedding_deepseek_fuse_fake(
+        positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
+        head_size: int, cos_sin_cache: torch.Tensor, is_neox_style: bool,
+    ) -> None:
+        pass
+
+    direct_register_custom_op(
+        op_name="rotary_embedding_deepseek_fuse",
+        op_func=rotary_embedding_deepseek_fuse,
+        mutates_args=["query", "key"],
+        fake_impl=rotary_embedding_deepseek_fuse_fake,
+    )
+
     def forward_native(
         self,
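The registration above wraps the lightop kernel, which mutates query and key in place and returns nothing, as a torch custom op with a no-op fake implementation, so traced/compiled graphs can still dispatch to it via torch.ops.sglang.rotary_embedding_deepseek_fuse. For orientation only, roughly the same pattern with the stock torch.library API (PyTorch 2.4+); this sketch is not part of the commit, and sglang's direct_register_custom_op is its own wrapper:

import torch

@torch.library.custom_op("demo::inplace_scale", mutates_args=["x"])
def inplace_scale(x: torch.Tensor, factor: float) -> None:
    # Real implementation: mutate the tensor in place, return nothing.
    x.mul_(factor)

@inplace_scale.register_fake
def _(x: torch.Tensor, factor: float) -> None:
    # Fake/meta implementation: the op returns nothing, so a no-op is enough;
    # this is what lets tracing and torch.compile reason about the call.
    pass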
@@ -819,6 +871,77 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             query = query_rot
             key = key_rot
         return query.to(dtype), key.to(dtype)
+
+    def forward_dcu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert key is not None
+        if self.cos_sin_cache.device != positions.device:
+            self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        if query.device.type == 'cuda' and not self.is_neox_style:
+            # not self.reference ?
+            assert len(query.shape) == 3
+
+            def call(q):
+                BLOCK_SIZE = 64
+                grid = (
+                    q.shape[-3],
+                    q.shape[-2],
+                    triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
+                )
+                deepseek_scaling_rotary_emb_kernel_gptj[grid](
+                    cos_sin,
+                    q,
+                    stride1=q.stride()[-3],
+                    stride2=q.stride()[-2],
+                    stride_cs=cos_sin.stride()[-2],
+                    dim1=q.shape[0],
+                    dim2=q.shape[1],
+                    dim3=self.rotary_dim // 2,
+                    BLOCK_SIZE=BLOCK_SIZE,
+                    num_warps=1,
+                )
+
+            if _use_lightop:
+                torch.ops.sglang.rotary_embedding_deepseek_fuse(
+                    positions,
+                    query,
+                    key,
+                    self.head_size,
+                    self.cos_sin_cache,
+                    self.is_neox_style,
+                )
+            else:
+                call(query)
+                call(key)
+            return query, key
+        else:
+            query_rot = query[..., : self.rotary_dim]
+            key_rot = key[..., : self.rotary_dim]
+            if self.rotary_dim < self.head_size:
+                query_pass = query[..., self.rotary_dim :]
+                key_pass = key[..., self.rotary_dim :]
+
+            cos, sin = cos_sin.chunk(2, dim=-1)
+            if self.is_neox_style:
+                # NOTE(woosuk): Here we assume that the positions tensor has the
+                # shape [batch_size, seq_len].
+                cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+                sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+            else:
+                cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+                sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+            query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+            key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+            if self.rotary_dim < self.head_size:
+                query = torch.cat((query_rot, query_pass), dim=-1)
+                key = torch.cat((key_rot, key_pass), dim=-1)
+            else:
+                query = query_rot
+                key = key_rot
+            return query, key
 
     def forward_npu(
         self,