change / sglang / Commits / c03b2816

Commit c03b2816
Authored Nov 13, 2025 by renzhc

Support torch.compile with the smallest possible changes; fill in some missing interfaces.

Parent: eb4b015f
Showing 5 changed files with 99 additions and 4 deletions (+99 -4)
python/sglang/srt/_custom_ops.py                                 +58 -0
python/sglang/srt/layers/attention/dcu_mla_backend.py             +2 -0
python/sglang/srt/layers/moe/topk.py                             +35 -1
python/sglang/srt/layers/quantization/slimquant_w4a8.py           +1 -0
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py    +3 -3
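Two escape hatches recur across these diffs: @torch._dynamo.disable() keeps code Dynamo cannot trace (the DCU MLA decode/extend paths and the quantized apply() methods) out of the compiled graph, while a custom op registered with a fake implementation (moe_fused_gate_dcu) lets compiled code still reach the vendor kernel. A minimal sketch of the first mechanism, using only public PyTorch APIs; eager_only_kernel and forward are illustrative names, not from the commit:

# Minimal sketch (not from this commit): a function excluded from tracing via
# torch._dynamo.disable() still runs eagerly when called from compiled code;
# Dynamo simply inserts a graph break around the call.
import torch

@torch._dynamo.disable()
def eager_only_kernel(x: torch.Tensor) -> torch.Tensor:
    # stand-in for a vendor op (e.g. a DCU/rocBLAS kernel) Dynamo cannot trace
    return x * 2

@torch.compile
def forward(x: torch.Tensor) -> torch.Tensor:
    return eager_only_kernel(x) + 1

print(forward(torch.ones(4)))  # tensor([3., 3., 3., 3.])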
python/sglang/srt/_custom_ops.py

@@ -274,6 +274,64 @@ def triton_scaled_mm(a: torch.Tensor,
     return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)

+def cutlass_scaled_mm(a: torch.Tensor,
+                      b: torch.Tensor,
+                      scale_a: torch.Tensor,
+                      scale_b: torch.Tensor,
+                      out_dtype: torch.dtype,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    """
+    `cutlass_scaled_mm` implements a fused version of
+    `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
+    where scale_a * a and scale_b * b are implemented using numpy-style
+    broadcasting.
+
+    In order to support blockwise scaling like that found in DeepSeek V3, we
+    also support extended "group" broadcast rules. We extend the numpy-style
+    broadcasting rules with the following rule:
+        "if the extent of a dimension in the source shape is between 1 and the
+        corresponding extent in the target shape, we repeat each element along
+        that dimension target_shape[dim] // src_shape[dim] times consecutively"
+
+    For example, if we have:
+        a = [[1, 2],    and target_shape = (2, 4)
+             [3, 4]]
+    then we would expand a to:
+        a = [[1, 1, 2, 2],
+             [3, 3, 4, 4]]
+
+    Currently we only support the case:
+        scale_a.shape * [1, 128] == a.shape
+        scale_b.shape * [128, 128] == b.shape
+    """
+    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
+    assert bias is None or (bias.shape[0] == b.shape[1] and bias.dtype == out_dtype)
+    # m = a.shape[0]
+    # n = b.shape[1]
+    # cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    # if current_platform.is_rocm() or not cutlass_compatible_b:
+    #     from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
+    #         triton_scaled_mm)
+    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    # out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+    # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+    # return out
+    # return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
+
+
+def rocblas_scaled_mm(a: torch.Tensor,
+                      b: torch.Tensor,
+                      scale_a: torch.Tensor,
+                      scale_b: torch.Tensor,
+                      out_dtype: torch.dtype,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
+
 def triton_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
...
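The "group" broadcast rule in the new cutlass_scaled_mm docstring can be reproduced in a few lines of plain PyTorch. The sketch below is illustration only (expand_group_scale is a hypothetical helper, not part of the commit); it materializes the expansion with repeat_interleave, whereas the fused kernel never builds this intermediate.

# Reference expansion of the docstring's "group" broadcast rule (sketch only).
import torch

def expand_group_scale(scale: torch.Tensor, target_shape: tuple) -> torch.Tensor:
    # Repeat each element target_shape[dim] // scale.shape[dim] times along
    # every dimension so the scale matches the tensor it multiplies.
    out = scale
    for dim, target in enumerate(target_shape):
        out = out.repeat_interleave(target // out.shape[dim], dim=dim)
    return out

a = torch.tensor([[1, 2],
                  [3, 4]])
print(expand_group_scale(a, (2, 4)))
# tensor([[1, 1, 2, 2],
#         [3, 3, 4, 4]])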
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -362,6 +362,7 @@ class DCUMLABackend(AttentionBackend):
         )
         return o

+    @torch._dynamo.disable()
     def forward_decode(
         self,
         q: torch.Tensor,
...

@@ -416,6 +417,7 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

+    @torch._dynamo.disable()  # NOTE: untested
     def forward_extend(
         self,
         q: torch.Tensor,
...
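One way to check that the decorated forward_decode/forward_extend really stay outside the compiled region is to count graph breaks with torch._dynamo.explain. This is a hedged sketch against a stand-in function, not the actual DCUMLABackend:

# Sketch: a torch._dynamo.disable()-decorated call should show up as a graph
# break when its caller is traced (decode_step stands in for forward_decode).
import torch

@torch._dynamo.disable()
def decode_step(x: torch.Tensor) -> torch.Tensor:
    return x.softmax(-1)

def model(x: torch.Tensor) -> torch.Tensor:
    return decode_step(x * 2).sum()

report = torch._dynamo.explain(model)(torch.randn(8))
print(report.graph_break_count)  # expected to be >= 1 at the disabled call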
python/sglang/srt/layers/moe/topk.py

@@ -45,6 +45,7 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.utils import (
+    direct_register_custom_op,
     cpu_has_amx_support,
     get_bool_env_var,
     get_compiler_backend,
...

@@ -87,6 +88,39 @@ if _use_lightop:
 if _is_npu:
     import torch_npu

+# ------- custom op for moe_fused_gate
+def moe_fused_gate_dcu(gating_output: torch.Tensor, correction_bias: torch.Tensor,
+                       num_expert_group: int, topk_group: int, topk: int,
+                       num_fused_shared_experts: int,
+                       routed_scaling_factor: float) -> tuple[torch.Tensor, torch.Tensor]:
+    topk_weights, topk_ids = op.moe_fused_gate(
+        gating_output.to(dtype=torch.float32),  # or bfloat16
+        correction_bias,
+        num_expert_group,
+        topk_group,
+        topk,
+        num_fused_shared_experts,  # 0 in vllm
+        routed_scaling_factor,
+    )
+    return topk_weights, topk_ids
+
+
+def moe_fused_gate_fake(gating_output: torch.Tensor, correction_bias: torch.Tensor,
+                        num_expert_group: int, topk_group: int, topk: int,
+                        num_fused_shared_experts: int,
+                        routed_scaling_factor: float) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty((gating_output.size(0), topk), dtype=gating_output.dtype, device=gating_output.device), \
+           torch.empty((gating_output.size(0), topk), dtype=gating_output.dtype, device=gating_output.device)
+
+
+direct_register_custom_op(
+    op_name="moe_fused_gate_dcu",
+    op_func=moe_fused_gate_dcu,
+    mutates_args=[],
+    fake_impl=moe_fused_gate_fake,
+)
+# -------

 # -------------------------------- TopKConfig ---------------------------------------
...

@@ -732,7 +766,7 @@ def biased_grouped_topk_gpu(
         return topk_weights, topk_ids
     elif _use_lightop:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"
-        topk_weights, topk_ids = op.moe_fused_gate(
+        topk_weights, topk_ids = torch.ops.sglang.moe_fused_gate_dcu(
             gating_output.to(dtype=torch.float32),  # or bfloat16
             correction_bias,
             num_expert_group,
...
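The reason biased_grouped_topk_gpu now goes through torch.ops.sglang.moe_fused_gate_dcu instead of calling op.moe_fused_gate directly is that a registered custom op carries a fake (meta) implementation, so torch.compile can trace output shapes without executing the DCU kernel. A standalone analogue using the public torch.library API (PyTorch 2.4+ is assumed, and the demo:: names are hypothetical, not sglang's):

# Sketch of the registration pattern with public PyTorch APIs.
import torch

@torch.library.custom_op("demo::fused_gate", mutates_args=())
def fused_gate(gating_output: torch.Tensor, topk: int) -> torch.Tensor:
    # real implementation: runs eagerly, free to call any vendor kernel
    return gating_output.topk(topk, dim=-1).values

@fused_gate.register_fake
def _(gating_output: torch.Tensor, topk: int) -> torch.Tensor:
    # fake/meta implementation: only describes the output shape and dtype,
    # which is all torch.compile needs while tracing
    return gating_output.new_empty((gating_output.size(0), topk))

@torch.compile
def route(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.demo.fused_gate(x, 2)

print(route(torch.randn(4, 8)).shape)  # torch.Size([4, 2])

sglang's direct_register_custom_op helper appears to wrap the same torch.library machinery; mutates_args=[] and fake_impl=moe_fused_gate_fake in the diff play roughly the same roles as mutates_args=() and register_fake above.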
python/sglang/srt/layers/quantization/slimquant_w4a8.py

@@ -154,6 +154,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight_scale", weight_scale)

+    @torch._dynamo.disable()
     def apply(
         self,
         layer: torch.nn.Module,
...
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py

@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Optional
 import torch
 from sglang.srt import _custom_ops as ops
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import set_weight_attrs, get_bool_env_var
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from torch.nn.parameter import Parameter
 from sglang.srt.layers.linear import LinearBase
...

@@ -213,8 +213,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
     ):
         self.moe_runner_config = moe_runner_config
         self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

+    @torch._dynamo.disable()
     def apply(
         self,
         layer: torch.nn.Module,
...