change / sglang / Commits / 20bd2271

Unverified commit 20bd2271, authored Oct 25, 2025 by fzyzcjy; committed by GitHub on Oct 25, 2025.

Support true on-policy (#12058)
Parent: 64994980

Showing 16 changed files with 151 additions and 11 deletions (+151, -11).
Changed files:

  python/sglang/srt/layers/activation.py         +6  -0
  python/sglang/srt/layers/layernorm.py          +19 -4
  python/sglang/srt/layers/logits_processor.py   +5  -0
  python/sglang/srt/layers/rotary_embedding.py   +13 -1
  python/sglang/srt/layers/sampler.py            +12 -1
  python/sglang/srt/models/qwen2.py              +22 -1
  python/sglang/srt/models/qwen3.py              +34 -4
  python/sglang/srt/server_args.py               +16 -0
  test/srt/cpu/test_activation.py                +3  -0
  test/srt/cpu/test_shared_expert.py             +3  -0
  test/srt/quant/test_block_int8.py              +3  -0
  test/srt/quant/test_int8_kernel.py             +3  -0
  test/srt/test_fused_moe.py                     +3  -0
  test/srt/test_triton_fused_moe.py              +3  -0
  test/srt/test_triton_moe_channel_fp8_kernel.py +3  -0
  test/srt/test_triton_moe_wna16.py              +3  -0
python/sglang/srt/layers/activation.py

@@ -29,6 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     is_cpu,
@@ -59,6 +60,11 @@ logger = logging.getLogger(__name__)

 class SiluAndMul(CustomOp):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
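The override works because CustomOp routes forward() through self._forward_method, so reassigning it in __init__ pins the pure-PyTorch path instead of a fused kernel. A minimal sketch of that dispatch pattern, with hypothetical class names (not sglang's actual CustomOp):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyCustomOp(nn.Module):
        # Hypothetical stand-in for sglang's CustomOp dispatch pattern.
        def __init__(self):
            super().__init__()
            # Default to the (possibly fused) backend path; subclasses may
            # reassign _forward_method to pin a specific implementation.
            self._forward_method = self.forward_cuda

        def forward(self, *args, **kwargs):
            return self._forward_method(*args, **kwargs)

    class TinySiluAndMul(TinyCustomOp):
        def __init__(self, force_native: bool = False):
            super().__init__()
            if force_native:  # mirrors rl_on_policy_target == "fsdp"
                self._forward_method = self.forward_native

        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            d = x.shape[-1] // 2
            return F.silu(x[..., :d]) * x[..., d:]

        def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
            # A real op would call a fused kernel here.
            return self.forward_native(x)

    x = torch.randn(2, 8)
    assert torch.allclose(TinySiluAndMul(force_native=True)(x), TinySiluAndMul()(x))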
python/sglang/srt/layers/layernorm.py

@@ -73,9 +73,16 @@ class RMSNorm(CustomOp):
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        cast_x_before_out_mul: bool = False,
+        fp32_residual: bool = False,
+        weight_dtype: Optional = None,
+        override_orig_dtype: Optional = None,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.cast_x_before_out_mul = cast_x_before_out_mul
+        self.fp32_residual = fp32_residual
+        self.override_orig_dtype = override_orig_dtype
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
         self.variance_epsilon = eps
         self.hidden_size = hidden_size
         self.variance_size_override = (
@@ -165,10 +172,13 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        orig_dtype = x.dtype
+        orig_dtype = self.override_orig_dtype or x.dtype
         x = x.to(torch.float32)
         if residual is not None:
             x = x + residual.to(torch.float32)
-            residual = x.to(orig_dtype)
+            if self.fp32_residual:
+                residual = x.clone()
+            else:
+                residual = x.to(orig_dtype)
         hidden_size = x.shape[-1]
@@ -191,7 +201,12 @@ class RMSNorm(CustomOp):
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = (x * self.weight).to(orig_dtype)
+        if self.cast_x_before_out_mul:
+            x = self.weight * x.to(orig_dtype)
+        else:
+            x = (x * self.weight).to(orig_dtype)
         if residual is None:
             return x
         else:
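Taken together, the new flags let RMSNorm keep its weight, residual, and output in float32 so the math matches an fp32 FSDP trainer. A sketch of the resulting numerics under the fsdp settings, assuming the flag semantics read off the diff:

    import torch

    def rms_norm_fsdp_style(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        # Assumed flag semantics from the diff: weight_dtype=float32,
        # cast_x_before_out_mul=True, override_orig_dtype=float32.
        orig_dtype = torch.float32            # override_orig_dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return weight * x.to(orig_dtype)      # cast x first, then multiply by fp32 weight

    x = torch.randn(4, 16, dtype=torch.bfloat16)
    w = torch.ones(16, dtype=torch.float32)
    assert rms_norm_fsdp_style(x, w).dtype == torch.float32  # output stays fp32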
python/sglang/srt/layers/logits_processor.py

@@ -593,6 +593,11 @@ class LogitsProcessor(nn.Module):
                 None,  # bias
                 True,  # is_vnni
             )
+        elif get_global_server_args().rl_on_policy_target == "fsdp":
+            # Due to tie-weight, we may not be able to change lm_head's weight dtype
+            logits = torch.matmul(
+                hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
+            )
         else:
             logits = torch.matmul(
                 hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
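Because lm_head may share its weight with the (now fp32) embedding table, its stored dtype cannot simply be changed; instead both operands are cast to bfloat16 at matmul time. A standalone sketch with hypothetical shapes:

    import torch

    # Hypothetical shapes; with rl_on_policy_target="fsdp" the embedding (and
    # any tied lm_head) is kept in fp32, so the cast happens per matmul instead.
    hidden_states = torch.randn(3, 64, dtype=torch.float32)
    lm_head_weight = torch.randn(1000, 64, dtype=torch.float32)  # tied, stays fp32

    logits = torch.matmul(hidden_states.bfloat16(), lm_head_weight.T.bfloat16())
    assert logits.dtype == torch.bfloat16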
python/sglang/srt/layers/rotary_embedding.py

@@ -11,6 +11,7 @@ import triton
 import triton.language as tl

 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -124,18 +125,29 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native

     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
+        init_device = (
+            "cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
+        )
         inv_freq = 1.0 / (
             base
             ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=init_device)
                 / self.rotary_dim
             )
         )
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            inv_freq = inv_freq.cuda()
         return inv_freq

     def _compute_cos_sin_cache(self) -> torch.Tensor:
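Per the NOTE above, computing inv_freq on CPU reproduces the HF initialization exactly before the result is moved to the GPU. A self-contained sketch of the same idea (hypothetical free function, not the class method):

    import torch

    def compute_inv_freq(rotary_dim: int, base: float, match_hf: bool) -> torch.Tensor:
        # Hypothetical helper mirroring _compute_inv_freq in the diff.
        init_device = "cpu" if match_hf else None
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, rotary_dim, 2, dtype=torch.float, device=init_device)
                / rotary_dim
            )
        )
        if match_hf and torch.cuda.is_available():
            inv_freq = inv_freq.cuda()  # computed on CPU, then moved to the GPU
        return inv_freq

    print(compute_inv_freq(8, 10000.0, match_hf=True))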
python/sglang/srt/layers/sampler.py

@@ -102,6 +102,14 @@ class Sampler(nn.Module):
         if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
             probs_without_temp_scaling = torch.softmax(logits, dim=-1)

+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            logits_div_temperature = (
+                logits.bfloat16().div(sampling_info.temperatures).bfloat16()
+            )
+            logprobs_via_logsoftmax_kernel = torch.log_softmax(
+                logits_div_temperature, dim=-1
+            )
+
         # Post process logits
         logits.div_(sampling_info.temperatures)
         logits[:] = torch.softmax(logits, dim=-1)
@@ -148,8 +156,11 @@ class Sampler(nn.Module):
             )

         if return_logprob:
+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logprobs = logprobs_via_logsoftmax_kernel
+                del logprobs_via_logsoftmax_kernel
             # clamp to avoid -inf
-            if SGLANG_RETURN_ORIGINAL_LOGPROB:
+            elif SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.log(probs_without_temp_scaling).clamp(
                     min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
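The extra branch exists because torch.log_softmax over bf16 logits is the same computation an FSDP trainer typically performs, whereas log(softmax(x)) rounds differently and can underflow to -inf (hence the clamp in the generic path). A small sketch of the discrepancy, under that assumed rationale:

    import torch

    logits = torch.randn(2, 32, dtype=torch.bfloat16) * 4
    temperatures = torch.ones(2, 1, dtype=torch.bfloat16)

    # Branch the diff adds for fsdp: one log_softmax over bf16 logits.
    lp_kernel = torch.log_softmax(logits.div(temperatures).bfloat16(), dim=-1)

    # Generic path: softmax followed by log, which rounds differently and can
    # hit -inf once a probability underflows to zero.
    lp_generic = torch.log(torch.softmax(logits.div(temperatures), dim=-1))

    print((lp_kernel - lp_generic).abs().max())  # small but generally nonzero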
python/sglang/srt/models/qwen2.py

@@ -49,6 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers

 Qwen2Config = None
@@ -89,6 +90,9 @@ class Qwen2MLP(nn.Module):
         self.act_fn = SiluAndMul()

     def forward(self, x):
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            x = x.bfloat16()
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -275,6 +279,11 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
+                params_dtype=(
+                    torch.float32
+                    if get_global_server_args().rl_on_policy_target == "fsdp"
+                    else None
+                ),
             )
         else:
             self.embed_tokens = PPMissingLayer()
@@ -295,7 +304,19 @@ class Qwen2Model(nn.Module):
             prefix=add_prefix("layers", prefix),
         )
         if self.pp_group.is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            norm_kwargs = (
+                dict(
+                    weight_dtype=torch.float32,
+                    cast_x_before_out_mul=True,
+                    override_orig_dtype=torch.float32,
+                    fp32_residual=True,
+                )
+                if get_global_server_args().rl_on_policy_target == "fsdp"
+                else {}
+            )
+            self.norm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+            )
         else:
             self.norm = PPMissingLayer(return_tuple=True)
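The conditional-kwargs pattern keeps the default path byte-identical: extra RMSNorm arguments are injected only when matching the external trainer. A sketch with a hypothetical helper name:

    import torch
    from typing import Optional

    def make_norm_kwargs(target: Optional[str]) -> dict:
        # Hypothetical helper mirroring the inline expression in the diff:
        # extra RMSNorm kwargs only when matching the fsdp trainer.
        if target == "fsdp":
            return dict(
                weight_dtype=torch.float32,
                cast_x_before_out_mul=True,
                override_orig_dtype=torch.float32,
                fp32_residual=True,
            )
        return {}

    # RMSNorm(hidden_size, eps=..., **make_norm_kwargs(None)) expands to
    # exactly the pre-change call, so default behavior is unchanged.
    assert make_norm_kwargs(None) == {}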
python/sglang/srt/models/qwen3.py

@@ -29,6 +29,7 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     add_prefix,
     get_cmo_stream,
@@ -88,8 +89,16 @@ class Qwen3Attention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.tp_rank = get_tensor_model_parallel_rank()
-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -158,10 +167,18 @@ class Qwen3Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            hidden_states = hidden_states.bfloat16()
+
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -204,9 +221,22 @@ class Qwen3DecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+                override_orig_dtype=torch.float32,
+                fp32_residual=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+        )
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
         )
         self.layer_scatter_modes = LayerScatterModes.init_new(
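Under the fsdp target the residual stream stays in float32 while projections and attention run in bfloat16, so tensors are cast at the boundaries. A toy sketch of that dtype flow, with assumed shapes and a stand-in for the fp32 qk-norm/RoPE step:

    import torch

    # Toy dtype flow for Qwen3Attention under the fsdp target (assumed shapes):
    # fp32 residual stream -> bf16 projections -> fp32 qk-norm/RoPE -> bf16 attention.
    hidden = torch.randn(4, 128, dtype=torch.float32)
    w_qkv = torch.randn(3 * 128, 128, dtype=torch.bfloat16)

    qkv = hidden.bfloat16() @ w_qkv.T             # hidden_states.bfloat16() in the diff
    q, k, v = qkv.split([128, 128, 128], dim=-1)
    q, k = q.float(), k.float()                   # stand-in for qk-norm + rotary math
    q, k = q.to(torch.bfloat16), k.to(torch.bfloat16)  # the casts the diff adds
    assert q.dtype == k.dtype == v.dtype == torch.bfloat16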
python/sglang/srt/server_args.py

@@ -472,6 +472,7 @@ class ServerArgs:
     enable_return_hidden_states: bool = False
     scheduler_recv_interval: int = 1
     numa_node: Optional[List[int]] = None
+    rl_on_policy_target: Optional[str] = None
     enable_deterministic_inference: bool = False

     # Dynamic batch tokenizer
@@ -1526,6 +1527,14 @@ class ServerArgs:
         )

     def _handle_deterministic_inference(self):
+        if self.rl_on_policy_target is not None:
+            logger.warning(
+                "Enable deterministic inference because of rl_on_policy_target."
+            )
+            self.enable_deterministic_inference = True
+            # TODO remove this environment variable as a whole
+            os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
+
         if self.enable_deterministic_inference:
             # Check sampling backend
             self.sampling_backend = "pytorch"
@@ -3300,6 +3309,13 @@ class ServerArgs:
         )

         # For deterministic inference
+        parser.add_argument(
+            "--rl-on-policy-target",
+            type=str,
+            default=ServerArgs.rl_on_policy_target,
+            choices=["fsdp"],
+            help="The training system that SGLang needs to match for true on-policy.",
+        )
         parser.add_argument(
             "--enable-deterministic-inference",
             action="store_true",
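Passing --rl-on-policy-target fsdp at launch therefore implies deterministic inference. A pure-function sketch of the handler logic the diff adds (the real method mutates self instead of taking and returning the flag):

    import logging
    import os

    logger = logging.getLogger(__name__)

    def handle_deterministic_inference(rl_on_policy_target, enable_deterministic_inference):
        # Sketch of ServerArgs._handle_deterministic_inference from the diff.
        if rl_on_policy_target is not None:
            logger.warning("Enable deterministic inference because of rl_on_policy_target.")
            enable_deterministic_inference = True
            # The diff leaves a TODO to remove this environment variable entirely.
            os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1"
        return enable_deterministic_inference

    assert handle_deterministic_inference("fsdp", False) is True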
test/srt/cpu/test_activation.py

@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F
 from utils import GeluAndMul, SiluAndMul, precision

+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase

 torch.manual_seed(1234)
@@ -17,6 +18,8 @@ class TestActivation(CustomTestCase):
     dtype = [torch.float16, torch.bfloat16]

     def _silu_and_mul_test(self, m, n, dtype):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         x = torch.randn([m, n], dtype=dtype)
         out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
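Since ops like SiluAndMul now consult get_global_server_args() in __init__, standalone kernel tests must seed the global server args before constructing anything. Each of the touched tests below does so the same way:

    # Seed the global server args with a dummy model path, exactly as the
    # updated tests do, before instantiating ops that read them.
    from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler

    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))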
test/srt/cpu/test_shared_expert.py

@@ -20,6 +20,7 @@ from utils import (
     torch_w8a8_per_column_moe,
 )
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase

 torch.manual_seed(1234)
@@ -149,6 +150,8 @@ class TestSharedExpert(CustomTestCase):
             self._int8_shared_expert(*params)

     def _fp8_shared_expert(self, M, N, K, routed_scaling_factor):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         dtype = torch.bfloat16
         prepack = True
test/srt/quant/test_block_int8.py

@@ -6,6 +6,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -96,6 +97,8 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.f
 def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
     """This function performs fused moe with block-wise quantization using native torch."""
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
test/srt/quant/test_int8_kernel.py

@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -35,6 +36,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
 def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
     """This function performs fused moe with per-column int8 quantization using native torch."""
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     # Perform per-token quantization
     a_q, a_s = per_token_quant_int8(a)
test/srt/test_fused_moe.py

@@ -9,6 +9,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.srt.utils import is_hip
 from sglang.test.test_utils import CustomTestCase
@@ -63,6 +64,8 @@ class TestFusedMOE(CustomTestCase):
         a1_scale=None,
         a2_scale=None,
     ):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
         out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
test/srt/test_triton_fused_moe.py

@@ -9,6 +9,7 @@ from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsQuantInfo
 from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -56,6 +57,8 @@ class TestFusedMOE(CustomTestCase):
         topk,
         return_per_expert: bool = False,
     ):
+        set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
         out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
test/srt/test_triton_moe_channel_fp8_kernel.py

@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
 from sglang.test.test_utils import CustomTestCase
@@ -40,6 +41,8 @@ def fp8_mask(a, mask):
 def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
     """This function performs fused moe with per-column int8 quantization using native torch."""
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     # Perform per-token quantization
     a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
test/srt/test_triton_moe_wna16.py

@@ -6,6 +6,7 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
+from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler

 NUM_EXPERTS = [8, 64]
 TOP_KS = [2, 6]
@@ -116,6 +117,8 @@ def quantize_weights(
 def torch_moe(a, w1, w2, score, topk):
+    set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
+
     B, D = a.shape
     a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)