Commit 8da47f19 (sglang)
Authored Nov 19, 2025 by renzhc

Support w8a8 compile; register custom ops to resolve some graph-break issues
Parent: bd63af06
Showing 7 changed files with 167 additions and 28 deletions
python/sglang/srt/layers/attention/dcu_mla_backend.py  +2 -2
python/sglang/srt/layers/attention/lightop_concat.py  +43 -8
python/sglang/srt/layers/moe/ep_moe/layer.py  +114 -15
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py  +3 -1
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py  +2 -0
python/sglang/srt/layers/quantization/slimquant_w4a8.py  +1 -1
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py  +2 -1
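The theme across these files is the same: kernels that Dynamo cannot trace are either wrapped as registered custom ops (so torch.compile sees one opaque node with a fake/meta implementation for shape inference) or explicitly kept out of the graph with @torch._dynamo.disable(). As a standalone illustration of the first pattern, here is a minimal sketch using plain torch.library APIs (assumes PyTorch >= 2.4) rather than sglang's direct_register_custom_op helper; the "demo::fused_cat" op and the shapes are hypothetical, not part of this commit.

import torch

# Register a custom op so torch.compile treats the external kernel as a single opaque node.
@torch.library.custom_op("demo::fused_cat", mutates_args=())
def fused_cat(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    # Eager implementation: stand-in for an external kernel such as lightop's ds_cat.
    return torch.cat([a, b], dim=dim)

@fused_cat.register_fake
def _(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    # Metadata-only implementation used while Dynamo/AOTAutograd traces the graph.
    out_shape = list(a.shape)
    out_shape[dim] = a.shape[dim] + b.shape[dim]
    return a.new_empty(out_shape)

@torch.compile(fullgraph=True)  # would raise if the op still caused a graph break
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.ops.demo.fused_cat(x, y, 1)

print(f(torch.randn(2, 3), torch.randn(2, 4)).shape)  # torch.Size([2, 7])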
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -362,7 +362,7 @@ class DCUMLABackend(AttentionBackend):
             )
         return o
 
-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # NOTE: FP8 cache decode does not support compile
     def forward_decode(
         self,
         q: torch.Tensor,
@@ -417,7 +417,7 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    @torch._dynamo.disable()  # NOTE: untested
+    @torch._dynamo.disable()
     def forward_extend(
         self,
         q: torch.Tensor,
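Where a path still cannot be traced, the commit leans on torch._dynamo.disable(), as in forward_decode above. Below is a minimal sketch of that fallback; the Attn module is made up and stands in for the real attention backend.

import torch

class Attn(torch.nn.Module):
    # The decorated method runs eagerly; the rest of the module stays compiled.
    @torch._dynamo.disable()
    def forward_decode(self, q: torch.Tensor) -> torch.Tensor:
        # Stand-in for a call (e.g. FP8 cache decode) that Dynamo cannot trace.
        return q * 2.0

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        return self.forward_decode(q) + 1.0

compiled = torch.compile(Attn())
print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 8])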
python/sglang/srt/layers/attention/lightop_concat.py

@@ -4,7 +4,7 @@ import warnings
 import torch
 
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.utils import get_bool_env_var, direct_register_custom_op
 
 _USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")
@@ -20,13 +20,48 @@ else:
     ds_cat = None
 
-def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
-    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
+
+# TODO: registering this on its own still has some issues
+def ds_cat_wrapper(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
     output_shape = list(A.shape)
     output_shape[dim] = A.shape[dim] + B.shape[dim]
     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
-    mode = 0
-    if dim != 0:
-        ds_cat(A, B, C, mode)
-        return C
+    ds_cat(A, B, C, mode)
+    return C
+
+
+def ds_cat_fake(A: torch.Tensor, B: torch.Tensor, dim: int, mode: int) -> torch.Tensor:
+    # use the standard cat as the fake implementation
+    return torch.cat([A, B], dim=dim)
+
+
+direct_register_custom_op(
+    op_name="ds_cat",
+    op_func=ds_cat_wrapper,
+    mutates_args=[],  # no arguments are mutated, only a return value
+    fake_impl=ds_cat_fake,
+)
+
+
+def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
+    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
+    mode = 0
+    if dim != 0:
+        return torch.ops.sglang.ds_cat(A, B, dim, mode)
     assert False, "not support"
+
+# def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
+#     assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
+#     output_shape = list(A.shape)
+#     output_shape[dim] = A.shape[dim] + B.shape[dim]
+#     C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
+#     mode=0
+#     if dim!=0 :
+#         ds_cat(A, B, C, mode)
+#         return C
+#     assert False, "not support"
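A hypothetical smoke test for the registration above (not part of the commit): on a build where the lightop ds_cat kernel is actually available, a fullgraph compile over the registered op should no longer trip a graph break, because Dynamo traces through the fake implementation. Shapes and the test helper name are illustrative assumptions.

import torch

def check_ds_cat_compiles():
    # concat_decode_opt expects 3-D tensors concatenated along dim == 2.
    a = torch.randn(2, 4, 8, device="cuda")
    b = torch.randn(2, 4, 8, device="cuda")

    @torch.compile(fullgraph=True)  # fullgraph=True raises if Dynamo must break the graph
    def run(x, y):
        return torch.ops.sglang.ds_cat(x, y, 2, 0)  # dim=2, mode=0, as in concat_decode_opt

    assert run(a, b).shape == (2, 4, 16)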
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu, direct_register_custom_op
 from sglang.srt.utils.offloader import get_offloader
 
 if TYPE_CHECKING:
@@ -57,6 +57,105 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
+
+# ------ custom op for lightop
+def m_grouped_w4a8_gemm_nt_masked_wrapper(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return m_grouped_w4a8_gemm_nt_masked(
+        (a0, a1),
+        (b0, b1),
+        d,
+        masked_m,
+        expected_m_per_group,
+        config={"MODE": 1000,},
+    )
+
+
+def m_grouped_w4a8_gemm_nt_masked_fake(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return d
+
+
+def m_grouped_w8a8_gemm_nt_masked_wrapper(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return m_grouped_w8a8_gemm_nt_masked(
+        (a0, a1),
+        (b0, b1),
+        d,
+        masked_m,
+        expected_m_per_group,
+        config={"MODE": 1000,},
+    )
+
+
+def m_grouped_w8a8_gemm_nt_masked_fake(
+    a0: torch.Tensor,
+    a1: torch.Tensor,
+    b0: torch.Tensor,
+    b1: torch.Tensor,
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m_per_group: int,
+) -> torch.Tensor:
+    return d
+
+
+def fuse_silu_mul_quant_ep_wrapper(
+    input: torch.Tensor,
+    tokens_per_expert: Optional[torch.Tensor] = None,
+    num_local_tokens_tensor: Optional[torch.Tensor] = None,
+    topk: int = 1,
+    expect_m: int = -1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return fuse_silu_mul_quant_ep(
+        input, tokens_per_expert, num_local_tokens_tensor, topk, expect_m
+    )
+
+
+def fuse_silu_mul_quant_ep_fake(
+    input: torch.Tensor,
+    tokens_per_expert: Optional[torch.Tensor] = None,
+    num_local_tokens_tensor: Optional[torch.Tensor] = None,
+    topk: int = 1,
+    expect_m: int = -1,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    E, T, H = input.shape
+    d = H // 2
+    output = torch.empty(E, T, d, dtype=torch.int8, device=input.device)
+    scales = torch.empty((E, T, 1), device=input.device, dtype=torch.float32)
+    return output, scales
+
+
+direct_register_custom_op(
+    op_name="m_grouped_w4a8_gemm_nt_masked",
+    op_func=m_grouped_w4a8_gemm_nt_masked_wrapper,
+    mutates_args=[],
+    fake_impl=m_grouped_w4a8_gemm_nt_masked_fake,
+)
+direct_register_custom_op(
+    op_name="m_grouped_w8a8_gemm_nt_masked",
+    op_func=m_grouped_w8a8_gemm_nt_masked_wrapper,
+    mutates_args=[],
+    fake_impl=m_grouped_w8a8_gemm_nt_masked_fake,
+)
+direct_register_custom_op(
+    op_name="fuse_silu_mul_quant_ep",
+    op_func=fuse_silu_mul_quant_ep_wrapper,
+    mutates_args=[],
+    fake_impl=fuse_silu_mul_quant_ep_fake,
+)
+# ------
 
 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@@ -815,23 +914,23 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
         )
         # ---- first GEMM ----
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale,
+            w13_weight, w13_scales,
             gateup_output,
             masked_m,
             expected_m,
         )
-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
         # ---- second GEMM ----
         n2 = w2_scales.size(1)
         down_output = torch.empty(
             (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
         )
-        m_grouped_w4a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale,
+            w2_weight, w2_scales,
             down_output,
             masked_m,
             expected_m,
@@ -865,23 +964,23 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16
         )
         # ---- first GEMM ----
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a1_all, q_a1_scale),
-            (w13_weight, w13_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a1_all, q_a1_scale,
+            w13_weight, w13_scales,
             gateup_output,
             masked_m,
             expected_m,
         )
-        q_a2_all, q_a2_scale = fuse_silu_mul_quant_ep(gateup_output, masked_m)
+        q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
        # ---- second GEMM ----
         n2 = w2_scales.size(1)
         down_output = torch.empty(
             (num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16
         )
-        m_grouped_w8a8_gemm_nt_masked(
-            (q_a2_all, q_a2_scale),
-            (w2_weight, w2_scales),
+        torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
+            q_a2_all, q_a2_scale,
+            w2_weight, w2_scales,
             down_output,
             masked_m,
             expected_m,
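For reference, here is a hypothetical mirror of the registration pattern above, written against the same direct_register_custom_op helper: a wrapper with a flat, annotation-friendly signature plus a fake implementation that only reproduces output metadata (int8 activations and float32 per-token scales). The op name and the quantization math are illustrative assumptions, not the actual lightop kernel.

import torch
from sglang.srt.utils import direct_register_custom_op

def demo_silu_mul_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative eager path: SiLU(gate) * up, then per-token int8 quantization.
    e, t, h = x.shape
    gate, up = x.split(h // 2, dim=-1)
    y = torch.nn.functional.silu(gate) * up
    scale = y.abs().amax(dim=-1, keepdim=True).float().clamp_min(1e-6) / 127.0
    q = (y.float() / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def demo_silu_mul_quant_fake(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Fake impl: only shapes, dtypes, and device matter during tracing.
    e, t, h = x.shape
    return (
        torch.empty(e, t, h // 2, dtype=torch.int8, device=x.device),
        torch.empty((e, t, 1), dtype=torch.float32, device=x.device),
    )

direct_register_custom_op(
    op_name="demo_silu_mul_quant",
    op_func=demo_silu_mul_quant,
    mutates_args=[],
    fake_impl=demo_silu_mul_quant_fake,
)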
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py

 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from __future__ import annotations
@@ -15,6 +16,7 @@ from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.utils import set_weight_attrs
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
 try:
     from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
 except Exception:
@@ -77,7 +79,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
             "weights"
         )
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations"
         )
-        self.use_deepep = True
+        self.use_deepep = get_moe_a2a_backend().is_deepep()
         per_channel = (
             self.weight_quant.strategy == QuantizationStrategy.CHANNEL
             and self.input_quant.strategy == QuantizationStrategy.TOKEN
         )
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

@@ -163,12 +163,14 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         )
         layer.register_parameter("input_zero_point", input_zero_point)
 
+    @torch._dynamo.disable()
     def apply_weights(
         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
     ) -> torch.Tensor:
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)
 
+        # TODO: fix with lmslim/lightop
         return quant_ops.triton_scaled_mm(
             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
         )
python/sglang/srt/layers/quantization/slimquant_w4a8.py

@@ -157,7 +157,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight_scale", weight_scale)
 
-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs support from lmslim/lightop
     def apply(
         self,
         layer: torch.nn.Module,
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py

@@ -214,7 +214,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         self.moe_runner_config = moe_runner_config
         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
 
-    @torch._dynamo.disable()
+    @torch._dynamo.disable()  # TODO: performance optimization needs support from lmslim/lightop
     def apply(
         self,
         layer: torch.nn.Module,
@@ -253,6 +253,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         )
         return StandardCombineInput(hidden_states=output)
 
+    @torch._dynamo.disable()  # TODO: performance optimization needs support from lmslim/lightop
    def apply_with_shared_output(
         self,
         layer: torch.nn.Module,