Project: change/sglang
Commit 143ec5f3 (parent 67510e01)
Authored Oct 21, 2025 by lizhigong; committed Oct 25, 2025 by maxiao1

    Adaptation for W4A8 quantization

    (cherry picked from commit 848c5b82)

Showing 3 changed files with 131 additions and 47 deletions:
    python/sglang/srt/_custom_ops.py                                  +31   -0
    python/sglang/srt/layers/quantization/slimquant_w4a8.py           +7    -0
    python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py    +93   -47

--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -5,6 +5,15 @@ from typing import List, Optional, Tuple
 import torch
 from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
+try:
+    from lmslim import quant_ops
+    from lmslim import quant_tools
+except Exception:
+    print("INFO: Please install lmslim if you want to infer gptq, awq, or w8a8 models.\n")
+try:
+    import lightop
+except Exception:
+    print("INFO: Please install lightop if you want to infer awq with marlin.\n")

 logger = logging.getLogger(__name__)

 use_vllm_custom_allreduce = get_bool_env_var(
@@ -175,3 +184,25 @@ def mscclpp_allreduce(
     context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
 ) -> None:
     return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
+
+def triton_scaled_mm(a: torch.Tensor,
+                     b: torch.Tensor,
+                     scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype,
+                     bias: Optional[torch.Tensor] = None,
+                     best_config: Optional[list] = None) -> torch.Tensor:
+    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)
+
+
+def triton_int8_gemm_helper(m: int,
+                            n: int,
+                            k: int,
+                            per_token_act_quant: bool,
+                            per_out_channel_weight_quant: bool,
+                            use_bias: bool,
+                            out_dtype: type[torch.dtype] = torch.float16,
+                            device: str = "cuda:0",
+                            best_config: Optional[list] = None,
+                            repeat: Optional[int] = 2):
+    return quant_tools.triton_int8_gemm_helper(m, n, k, per_token_act_quant, per_out_channel_weight_quant, use_bias, out_dtype, device, best_config, repeat)
\ No newline at end of file
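
For orientation (not part of the diff above): the hunk only adds thin wrappers that forward to lmslim. A minimal usage sketch of triton_scaled_mm follows; the tensor shapes, the per-token/per-channel scale layout, and the presence of a CUDA device plus an installed lmslim are assumptions.

import torch
from sglang.srt import _custom_ops as ops

# Hypothetical int8 GEMM: a is [m, k] quantized activations, b is [k, n] quantized weights.
m, k, n = 16, 4096, 11008
a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (k, n), dtype=torch.int8, device="cuda")

# Assumed scale layout: one scale per activation row, one per weight output channel.
scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")
scale_b = torch.rand(1, n, dtype=torch.float32, device="cuda")

# The new wrapper forwards straight to lmslim's quant_ops.triton_scaled_mm; if lmslim
# is missing, the try/except guard above only prints an install hint.
out = ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
print(out.shape, out.dtype)  # expected torch.Size([16, 11008]) torch.float16, assuming lmslim honors out_dtype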

--- a/python/sglang/srt/layers/quantization/slimquant_w4a8.py
+++ b/python/sglang/srt/layers/quantization/slimquant_w4a8.py
@@ -16,6 +16,7 @@ from lmslim.layers.gemm.int8_utils import (
     per_token_quant_int8
 )
 from sglang.srt import _custom_ops as ops
 from vllm.utils import W8a8GetCacheJSON
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 import os

@@ -343,6 +344,12 @@ class SlimQuantW4A8Int8MoEMethod:
             layer.w2_weight_scale.data, requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
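
For orientation (not part of the diff above): the create_moe_runner hook turns per-call routing keyword arguments into state configured once at layer setup. A self-contained stand-in sketch of that two-phase pattern follows; the Fake* classes are illustrative placeholders, while the real types are MoeRunnerConfig, MoeRunner, and SlimQuantW4A8Int8MoEMethod from this file's imports.

from dataclasses import dataclass

@dataclass
class FakeMoeRunnerConfig:
    # Stand-in for MoeRunnerConfig; the real field set is an assumption here.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False

class FakeMoEMethod:
    # Stand-in for SlimQuantW4A8Int8MoEMethod, showing only the new hook's contract.
    def create_moe_runner(self, layer, moe_runner_config):
        # Mirrors the added hook: stash the config and build the runner once.
        self.moe_runner_config = moe_runner_config
        self.runner = ("TRITON", moe_runner_config)  # stands in for MoeRunner(MoeRunnerBackend.TRITON, ...)

    def apply(self, layer, dispatch_output):
        # Later calls read self.moe_runner_config instead of taking routing kwargs.
        return dispatch_output, self.moe_runner_config.activation

method = FakeMoEMethod()
method.create_moe_runner(layer=None, moe_runner_config=FakeMoeRunnerConfig())
print(method.apply(layer=None, dispatch_output="hidden_states"))  # ('hidden_states', 'silu')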

--- a/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
+++ b/python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py
 from typing import Any, Callable, Dict, List, Optional
+from sglang.srt.layers.moe.token_dispatcher.base import CombineInput
+from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput, StandardDispatchOutput
 import torch
 from sglang.srt import _custom_ops as ops
 from sglang.srt.utils import set_weight_attrs

@@ -9,6 +11,7 @@ from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.quantization.w4a8_utils import w4a8_weight_repack_impl
 from sglang.srt.layers.quantization.base_config import (FusedMoEMethodBase, QuantizeMethodBase)
 from sglang.srt.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig

 try:
     from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
@@ -146,13 +149,13 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
         from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
         tp_size = get_tensor_model_parallel_world_size()
+        intermediate_size = intermediate_size_per_partition
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -205,51 +208,28 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
         layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
         layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        use_nn_moe: Optional[bool] = False,
-        routed_scaling_factor: Optional[float] = None,
-        use_fused_gate: Optional[bool] = False,
-        **_
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
-        # Expert selection
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-            use_fused_gate=use_fused_gate
-        )
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+        topk_weights, topk_ids, _ = topk_output
+        x, topk_weights = apply_topk_weights_cpu(
+            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
+        )
         workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
-        return fused_experts_impl_w4a8_marlin(
+        output = fused_experts_impl_w4a8_marlin(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -260,13 +240,79 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
             inplace=True,
             use_int4_w4a8=True,
             per_channel_quant=True,
-            activation=activation,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            global_num_experts=global_num_experts,
+            activation=layer.moe_runner_config.activation,
+            expert_map=layer.expert_map_gpu,
+            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
+            global_num_experts=layer.moe_runner_config.num_experts,
             w1_scale=(layer.w13_weight_scale),
             w2_scale=(layer.w2_weight_scale),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            use_nn_moe=use_nn_moe,
+            use_nn_moe=False,
         )
+        return StandardCombineInput(hidden_states=output)
+
+    # def _apply(
+    #     self,
+    #     layer: torch.nn.Module,
+    #     x: torch.Tensor,
+    #     router_logits: torch.Tensor,
+    #     top_k: int,
+    #     #renormalize: bool,
+    #     #use_grouped_topk: bool = False,
+    #     topk_group: Optional[int] = None,
+    #     num_expert_group: Optional[int] = None,
+    #     global_num_experts: int = -1,
+    #     expert_map: Optional[torch.Tensor] = None,
+    #     custom_routing_function: Optional[Callable] = None,
+    #     scoring_func: str = "softmax",
+    #     e_score_correction_bias: Optional[torch.Tensor] = None,
+    #     apply_router_weight_on_input: bool = False,
+    #     activation: str = "silu",
+    #     enable_eplb: bool = False,
+    #     use_nn_moe: Optional[bool] = False,
+    #     routed_scaling_factor: Optional[float] = None,
+    #     use_fused_gate: Optional[bool] = False,
+    #     **_
+    # ) -> torch.Tensor:
+    #     from sglang.srt.layers.moe.fused_moe_triton import (FusedMoE, FusedMoeWeightScaleSupported)
+    #     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+    #     if enable_eplb:
+    #         raise NotImplementedError(
+    #             "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
+    #     # Expert selection
+    #     topk_weights, topk_ids = FusedMoE.select_experts(
+    #         hidden_states=x,
+    #         router_logits=router_logits,
+    #         #use_grouped_topk=use_grouped_topk,
+    #         top_k=top_k,
+    #         #renormalize=renormalize,
+    #         topk_group=topk_group,
+    #         num_expert_group=num_expert_group,
+    #         custom_routing_function=custom_routing_function,
+    #         scoring_func=scoring_func,
+    #         e_score_correction_bias=e_score_correction_bias,
+    #         routed_scaling_factor=routed_scaling_factor,
+    #         use_fused_gate=use_fused_gate
+    #     )
+    #     workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
+    #     return fused_experts_impl_w4a8_marlin(
+    #         x,
+    #         layer.w13_weight,
+    #         layer.w2_weight,
+    #         topk_weights=topk_weights,
+    #         topk_ids=topk_ids,
+    #         workspace=workspace,
+    #         global_reduce_buffer=global_reduce_buffer,
+    #         inplace=True,
+    #         use_int4_w4a8=True,
+    #         per_channel_quant=True,
+    #         activation=activation,
+    #         expert_map=expert_map,
+    #         apply_router_weight_on_input=apply_router_weight_on_input,
+    #         global_num_experts=global_num_experts,
+    #         w1_scale=(layer.w13_weight_scale),
+    #         w2_scale=(layer.w2_weight_scale),
+    #         a1_scale=layer.w13_input_scale,
+    #         a2_scale=layer.w2_input_scale,
+    #         use_nn_moe=use_nn_moe,
+    #     )
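
For orientation (not part of the diff above): taken together, the changes in this file replace the long apply() argument list with a single StandardDispatchOutput and wrap the result in a StandardCombineInput. A hedged sketch of the new call path follows; only the two imports are taken from the diff, while the keyword construction of StandardDispatchOutput and the run_w4a8_moe helper are assumptions.

from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
    StandardDispatchOutput,
)

def run_w4a8_moe(method, layer, hidden_states, topk_output):
    # Hypothetical caller illustrating the new apply() contract.
    # Assumption: StandardDispatchOutput accepts hidden_states and topk_output as
    # keywords, and topk_output unpacks as (topk_weights, topk_ids, ...) the way
    # apply() expects.
    dispatch_output = StandardDispatchOutput(
        hidden_states=hidden_states, topk_output=topk_output
    )
    combine_input = method.apply(layer, dispatch_output)  # returns a CombineInput
    assert isinstance(combine_input, StandardCombineInput)
    return combine_input.hidden_states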