sglang · Commits · 769353e6

Commit 769353e6, authored Nov 19, 2025 by lizhigong

Merge branch 'v0.5.4_rzc' into 'v0.5.4_dev'

Support w8a8 compile; register custom operators to resolve some graph-break issues.

See merge request OpenDAS/sglang!31

Parents: 0dc51b09, 8da47f19
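For context on the approach: torch.compile breaks the graph whenever it reaches a Python call it cannot trace, such as the vendor lightop kernels used here. The pattern this merge request applies is to register those kernels as PyTorch custom ops with a fake (meta) implementation so the compiler can trace shapes through them. The sketch below is illustrative only; it uses plain torch.library rather than sglang's direct_register_custom_op helper, and the op name demo::concat_last_dim and its body are invented for the example.

import torch

# Hypothetical op standing in for an opaque device kernel (not part of this MR).
@torch.library.custom_op("demo::concat_last_dim", mutates_args=())
def concat_last_dim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # real implementation: plain torch.cat stands in for the vendor kernel
    return torch.cat([a, b], dim=-1)

@concat_last_dim.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # fake impl: only shapes/dtypes/devices matter, no real computation
    out_shape = list(a.shape)
    out_shape[-1] = a.shape[-1] + b.shape[-1]
    return a.new_empty(out_shape)

@torch.compile(fullgraph=True)
def f(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # the custom op is traced via its fake impl, so no graph break occurs
    return torch.ops.demo.concat_last_dim(a, b) * 2.0

print(f(torch.randn(2, 3, 4), torch.randn(2, 3, 5)).shape)  # torch.Size([2, 3, 9])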
Showing 7 changed files with 165 additions and 27 deletions.
python/sglang/srt/layers/attention/dcu_mla_backend.py (+2, -2)
python/sglang/srt/layers/attention/lightop_concat.py (+43, -8)
python/sglang/srt/layers/moe/ep_moe/layer.py (+114, -15)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py (+1, -0)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py (+2, -0)
python/sglang/srt/layers/quantization/slimquant_w4a8.py (+1, -1)
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py (+2, -1)
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -420,7 +420,7 @@ class DCUMLABackend(AttentionBackend):
         )
         return o

-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # NOTE: decode with FP8 KV cache does not support compile
     def forward_decode(
         self,
         q: torch.Tensor,
         ...

@@ -475,7 +475,7 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    @torch._dynamo.disable()  # NOTE: untested
+    @torch._dynamo.disable()
     def forward_extend(
         self,
         q: torch.Tensor,
         ...
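The dcu_mla_backend hunks only adjust comments on existing @torch._dynamo.disable() decorators. As a reminder of what the decorator does under torch.compile, here is a minimal self-contained sketch (the function names are invented for illustration): the decorated function always runs eagerly while its caller stays compiled, which avoids compile failures on unsupported paths such as the FP8-cache decode noted above.

import torch

@torch._dynamo.disable()
def eager_only(x: torch.Tensor) -> torch.Tensor:
    # stands in for a path torch.compile cannot handle; always runs eagerly
    return torch.sin(x)

@torch.compile
def caller(x: torch.Tensor) -> torch.Tensor:
    # compiled region; the call into eager_only falls back to eager execution
    return eager_only(x) + 1.0

print(caller(torch.randn(4)))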
python/sglang/srt/layers/attention/lightop_concat.py

@@ -4,7 +4,7 @@ import warnings

 import torch

-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.utils import get_bool_env_var, direct_register_custom_op

 _USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")
 ...

@@ -18,15 +18,50 @@ if _USE_OPT_CAT:
     )
 else:
     ds_cat = None

-def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
-    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
-    output_shape = list(A.shape)
-    output_shape[dim] = A.shape[dim] + B.shape[dim]
-    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
-    mode = 0
-    if dim != 0:
-        ds_cat(A, B, C, mode)
-        return C
-    assert False, "not support"
\ No newline at end of file
+# TODO: registering this as a standalone op still has some issues
+def ds_cat_wrapper(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
+    output_shape = list(A.shape)
+    output_shape[dim] = A.shape[dim] + B.shape[dim]
+    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
+    ds_cat(A, B, C, mode)
+    return C
+
+
+def ds_cat_fake(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
+    # use the standard cat as the fake implementation
+    return torch.cat([A, B], dim=dim)
+
+
+direct_register_custom_op(
+    op_name="ds_cat",
+    op_func=ds_cat_wrapper,
+    mutates_args=[],  # no arguments are mutated, only a return value
+    fake_impl=ds_cat_fake,
+)
+
+
+def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
+    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
+    mode = 0
+    if dim != 0:
+        return torch.ops.sglang.ds_cat(A, B, dim, mode)
+    assert False, "not support"
+
+
+# def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
+#     assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
+#     output_shape = list(A.shape)
+#     output_shape[dim] = A.shape[dim] + B.shape[dim]
+#     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
+#     mode=0
+#     if dim!=0 :
+#         ds_cat(A, B, C, mode)
+#         return C
+#     assert False, "not support"
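A hedged usage sketch for the new concat path follows. It assumes SGLANG_USE_OPT_CAT is set, the lightop ds_cat kernel is installed, and a supported GPU/DCU device is available; the tensor shapes are made up for illustration and none of this is asserted by the diff itself.

import torch
from sglang.srt.layers.attention.lightop_concat import concat_decode_opt

@torch.compile
def concat_kv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # dispatches to torch.ops.sglang.ds_cat, which torch.compile can now trace
    # through via its fake implementation instead of breaking the graph
    return concat_decode_opt(a, b, dim=2)

a = torch.randn(1, 16, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1, 16, 64, device="cuda", dtype=torch.bfloat16)
print(concat_kv(a, b).shape)  # torch.Size([1, 16, 576])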
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu, direct_register_custom_op
 from sglang.srt.utils.offloader import get_offloader

 if TYPE_CHECKING:
 ...

@@ -57,6 +57,105 @@ if _use_aiter:

 logger = logging.getLogger(__name__)

+# ------ custom ops for lightop
+def m_grouped_w4a8_gemm_nt_masked_wrapper(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return m_grouped_w4a8_gemm_nt_masked(
+        (a0, a1), (b0, b1), d, masked_m, expected_m_per_group, config={"MODE": 1000}
+    )
+
+
+def m_grouped_w4a8_gemm_nt_masked_fake(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return d
+
+
+def m_grouped_w8a8_gemm_nt_masked_wrapper(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return m_grouped_w8a8_gemm_nt_masked(
+        (a0, a1), (b0, b1), d, masked_m, expected_m_per_group, config={"MODE": 1000}
+    )
+
+
+def m_grouped_w8a8_gemm_nt_masked_fake(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return d
+
+
+def fuse_silu_mul_quant_ep_wrapper(
+    input: torch.Tensor,
+    tokens_per_expert: Optional[torch.Tensor] = None,
+    num_local_tokens_tensor: Optional[torch.Tensor] = None,
+    topk: int = 1,
+    expect_m: int = -1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return fuse_silu_mul_quant_ep(
+        input, tokens_per_expert, num_local_tokens_tensor, topk, expect_m
+    )
+
+
+def fuse_silu_mul_quant_ep_fake(
+    input: torch.Tensor,
+    tokens_per_expert: Optional[torch.Tensor] = None,
+    num_local_tokens_tensor: Optional[torch.Tensor] = None,
+    topk: int = 1,
+    expect_m: int = -1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    E, T, H = input.shape
+    d = H // 2
+    output = torch.empty(E, T, d, dtype=torch.int8, device=input.device)
+    scales = torch.empty((E, T, 1), device=input.device, dtype=torch.float32)
+    return output, scales
+
+
+direct_register_custom_op(
+    op_name="m_grouped_w4a8_gemm_nt_masked",
+    op_func=m_grouped_w4a8_gemm_nt_masked_wrapper,
+    mutates_args=[],
+    fake_impl=m_grouped_w4a8_gemm_nt_masked_fake,
+)
+direct_register_custom_op(
+    op_name="m_grouped_w8a8_gemm_nt_masked",
+    op_func=m_grouped_w8a8_gemm_nt_masked_wrapper,
+    mutates_args=[],
+    fake_impl=m_grouped_w8a8_gemm_nt_masked_fake,
+)
+direct_register_custom_op(
+    op_name="fuse_silu_mul_quant_ep",
+    op_func=fuse_silu_mul_quant_ep_wrapper,
+    mutates_args=[],
+    fake_impl=fuse_silu_mul_quant_ep_fake,
+)
+# ------
+
 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
 ...

@@ -815,23 +914,23 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
         )
         # ---- first GEMM ----
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a1_all,
+            q_a1_scale,
+            w13_weight,
+            w13_scales,
             gateup_output,
             masked_m,
             expected_m,
         )
-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(
+            gateup_output, masked_m
+        )
         # ---- second GEMM ----
         n2 = w2_scales.size(1)
         down_output = torch.empty(
             (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
         )
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a2_all,
+            q_a2_scale,
+            w2_weight,
+            w2_scales,
             down_output,
             masked_m,
             expected_m,
             ...

@@ -865,23 +964,23 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
         )
         # ---- first GEMM ----
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a1_all,
+            q_a1_scale,
+            w13_weight,
+            w13_scales,
             gateup_output,
             masked_m,
             expected_m,
         )
-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(
+            gateup_output, masked_m
+        )
         # ---- second GEMM ----
         n2 = w2_scales.size(1)
         down_output = torch.empty(
             (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
         )
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a2_all,
+            q_a2_scale,
+            w2_weight,
+            w2_scales,
             down_output,
             masked_m,
             expected_m,
             ...
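The fake implementations registered above never execute the kernels; during tracing they only have to produce outputs with the right shapes, dtypes, and devices. A simplified sketch of that idea using meta tensors follows (the one-argument signature is trimmed down for illustration and is not the real fuse_silu_mul_quant_ep interface):

import torch

def fuse_silu_mul_quant_ep_fake_sketch(x: torch.Tensor):
    # mirrors the registered fake impl: int8 activations plus per-token scales,
    # with the hidden dimension halved by the gated SiLU-mul
    E, T, H = x.shape
    out = torch.empty((E, T, H // 2), dtype=torch.int8, device=x.device)
    scales = torch.empty((E, T, 1), dtype=torch.float32, device=x.device)
    return out, scales

# meta tensors carry shapes/dtypes only, so this "runs" without any kernel
x = torch.empty((8, 128, 4096), dtype=torch.bfloat16, device="meta")
out, scales = fuse_silu_mul_quant_ep_fake_sketch(x)
print(out.shape, out.dtype, scales.shape)
# torch.Size([8, 128, 2048]) torch.int8 torch.Size([8, 128, 1])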
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py

 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from __future__ import annotations
 ...
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

@@ -163,12 +163,14 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         )
         layer.register_parameter("input_zero_point", input_zero_point)

+    @torch._dynamo.disable()
     def apply_weights(
         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
     ) -> torch.Tensor:
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)
+        # TODO: fix with lmslim/lightop
         return quant_ops.triton_scaled_mm(
             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
         )
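For reference, the computation that apply_weights delegates to quant_ops.triton_scaled_mm is the usual int8 dequantized matmul with per-token activation scales and per-channel weight scales. The sketch below is a plain-PyTorch reference under assumed layouts (activations [M, K] with scales [M, 1], weight [K, N] with scales [N]); it is not the kernel this fork ships.

import torch

def scaled_mm_reference(x_q: torch.Tensor, w_q: torch.Tensor,
                        x_scale: torch.Tensor, w_scale: torch.Tensor,
                        out_dtype: torch.dtype, bias=None) -> torch.Tensor:
    # dequantize-then-matmul reference: y = (x_q * x_scale) @ (w_q * w_scale)
    y = (x_q.float() * x_scale) @ (w_q.float() * w_scale)
    if bias is not None:
        y = y + bias.float()
    return y.to(out_dtype)

x_q = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
w_q = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
x_scale = torch.rand(4, 1)
w_scale = torch.rand(16)
print(scaled_mm_reference(x_q, w_q, x_scale, w_scale, torch.bfloat16).shape)  # torch.Size([4, 16])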
python/sglang/srt/layers/quantization/slimquant_w4a8.py

@@ -157,7 +157,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight_scale", weight_scale)

-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
     def apply(
         self,
         layer: torch.nn.Module,
         ...
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py

@@ -214,7 +214,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         self.moe_runner_config = moe_runner_config
         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
     def apply(
         self,
         layer: torch.nn.Module,
         ...

@@ -253,6 +253,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         )
         return StandardCombineInput(hidden_states=output)

+    @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
     def apply_with_shared_output(
         self,
         layer: torch.nn.Module,
         ...