sglang · Commit 25796d05
Authored Nov 14, 2025 by maxiao1 · parent ac7dcc2d

Adapt w8a8_marlin

Showing 8 changed files with 407 additions and 5 deletions (+407 −5)
python/sglang/srt/configs/model_config.py  (+2 −0)
python/sglang/srt/layers/quantization/__init__.py  (+3 −1)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  (+45 −0)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_marlin.py  (+92 −0)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py  (+259 −0)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py  (+4 −3)
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py  (+1 −1)
python/sglang/srt/server_args.py  (+1 −0)
python/sglang/srt/configs/model_config.py

@@ -616,6 +616,7 @@ class ModelConfig:
             "mxfp4",
             "slimquant_w4a8_marlin",
             "w8a8_int8",
+            "slimquant_marlin",
         ]
         optimized_quantization_methods = [
             "fp8",
 ...
@@ -636,6 +637,7 @@ class ModelConfig:
             "w4afp8",
             "petit_nvfp4",
             "slimquant_w4a8_marlin",
+            "slimquant_marlin",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
 ...
python/sglang/srt/layers/quantization/__init__.py

@@ -58,6 +58,7 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
 from sglang.srt.utils import is_cuda, is_hip, mxfp_supported

 _is_mxfp_supported = mxfp_supported()
 ...
@@ -84,7 +85,8 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "slimquant_w4a8_marlin": SlimQuantW4A8Int8MarlinConfig,
+    "slimquant_marlin": SlimQuantCompressedTensorsMarlinConfig,
 }
...
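
For reference, a minimal sketch of what the new registry entry enables: resolving the quantization name to its config class. Only BASE_QUANTIZATION_METHODS and get_name() come from this commit; the snippet itself is illustrative and not part of the diff.

# Illustrative only: look up the config class registered above.
from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

config_cls = BASE_QUANTIZATION_METHODS["slimquant_marlin"]
assert config_cls.get_name() == "slimquant_marlin"  # defined in compressed_tensors_marlin.py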
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

 logger = logging.getLogger(__name__)
 ...
@@ -636,3 +637,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
         return scheme.apply_weights(layer, x, bias=bias)
+
+
+class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from compressed-tensors
+    checkpoints.
+    """
+
+    def __init__(self, quant_config: CompressedTensorsConfig):
+        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
+        super().__init__(quant_config)
+
+    @staticmethod
+    def validate_kv_cache_scheme(kv_cache_scheme: Optional[dict[str, Any]]):
+        """
+        Validator for the kv cache scheme. Useful for controlling the
+        kv cache quantization schemes, that are being supported in vLLM
+        :param kv_cache_scheme: the compressed-tensors kv cache scheme
+        """
+        if kv_cache_scheme is None:
+            return
+
+        type_ = kv_cache_scheme.get("type")
+        num_bits = kv_cache_scheme.get("num_bits")
+        if type_ != "float" and num_bits != 8:
+            raise NotImplementedError(
+                "Currently supported kv cache quantization is "
+                "num_bits=8, type=float, however "
+                f"received num_bits={num_bits}, type={type_}"
+            )
+
+        strategy = kv_cache_scheme.get("strategy")
+        if strategy != "tensor":
+            raise NotImplementedError(
+                "Only support per-tensor scaling factor "
+                "for compressed-tensors KV cache. "
+                f"Expected strategy: tensor, found strategy: {strategy}"
+            )
+
+        is_symmetric = kv_cache_scheme.get("symmetric")
+        if not is_symmetric:
+            raise NotImplementedError(
+                "Only support symmetric scaling factor "
+                "for compressed-tensors KV cache. "
+                f"However found symmetric: {is_symmetric}"
+            )
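
As a quick illustration of what the new validator accepts, here is a hedged sketch; the dict keys come straight from the validator above, and the snippet is not part of the commit.

# Illustrative only: exercise the validator added above.
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsKVCacheMethod,
)

# Accepted: FP8, per-tensor, symmetric kv-cache scales.
CompressedTensorsKVCacheMethod.validate_kv_cache_scheme(
    {"type": "float", "num_bits": 8, "strategy": "tensor", "symmetric": True}
)

# Rejected: any non per-tensor strategy raises NotImplementedError.
try:
    CompressedTensorsKVCacheMethod.validate_kv_cache_scheme(
        {"type": "float", "num_bits": 8, "strategy": "channel", "symmetric": True}
    )
except NotImplementedError as e:
    print(e)  # "Only support per-tensor scaling factor ..."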
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_marlin.py  (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

import torch
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import QuantizationArgs

from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
from sglang.srt.layers.quantization.base_config import (
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
    CompressedTensorsLinearMethod,
    CompressedTensorsKVCacheMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin import (
    CompressedTensorsMarlinMoEMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

# if TYPE_CHECKING:
#     from vllm.model_executor.models.utils import WeightsMapper

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]]


class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):

    def __init__(
        self,
        target_scheme_map: dict[str, Any],
        ignore: list[str],
        quant_format: str,
        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: list[str],
        kv_cache_scheme: Optional[dict[str, Any]] = None,
        config: Optional[dict[str, Any]] = None,
        packed_modules_mapping: Optional[dict[str, list[str]]] = None,
    ):
        super().__init__(
            target_scheme_map,
            ignore,
            quant_format,
            sparsity_scheme_map,
            sparsity_ignore_list,
            kv_cache_scheme,
            config,
            packed_modules_mapping,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        if hf_quant_cfg.get("quant_method") == "compressed-tensors" \
                and user_quant == "slimquant_marlin":
            return cls.get_name()
        return None

    @classmethod
    def get_name(cls) -> str:
        return "slimquant_marlin"

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

        # Avoid circular import
        # from sglang.srt.layers.radix_attention import RadixAttention

        # Check if the layer is skipped for quantization.
        if should_ignore_layer(
            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedEmbeddingMethod()  # UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
        # if isinstance(layer, RadixAttention):
        #     return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMarlinMoEMethod.get_moe_method(self, layer)
        return None
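
A brief, illustrative sketch of the selection logic above (not part of the commit): the override only kicks in when the checkpoint advertises compressed-tensors and the user explicitly asks for slimquant_marlin.

# Illustrative only: how override_quantization_method resolves the method name.
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import (
    SlimQuantCompressedTensorsMarlinConfig,
)

hf_quant_cfg = {"quant_method": "compressed-tensors"}  # as read from the checkpoint

# User passed --quantization slimquant_marlin: the marlin config takes over.
assert SlimQuantCompressedTensorsMarlinConfig.override_quantization_method(
    hf_quant_cfg, "slimquant_marlin"
) == "slimquant_marlin"

# Any other user choice leaves the default compressed-tensors path untouched.
assert SlimQuantCompressedTensorsMarlinConfig.override_quantization_method(
    hf_quant_cfg, "compressed-tensors"
) is None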
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py  (new file, mode 100644)

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import enum
import logging
from enum import Enum
from typing import Callable, Optional

import torch
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn.parameter import Parameter

from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.utils import set_weight_attrs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig

try:
    from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
    print("INFO: Please install lmslim if you want to run quantized MoE models.\n")

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsW8A8Int8MarlinMoEMethod"]


def get_w8a8_int8_marlin_weights(weight, k_tile=64):
    # 7168, 512
    weight = weight.T
    size_k, size_n = weight.shape
    assert size_k % k_tile == 0, "size_k must be divisible by k_tile"
    weight = weight.reshape(size_k // k_tile, k_tile, size_n)
    weight = weight.transpose(1, 2)
    weight = weight.reshape(size_k // k_tile, size_n * k_tile)
    return weight


class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):

    @staticmethod
    def get_moe_method(
        quant_config: "SlimQuantCompressedTensorsMarlinConfig",  # type: ignore # noqa E501
        layer: torch.nn.Module,
    ) -> "CompressedTensorsMarlinMoEMethod":
        # are supported + check if the layer is being ignored.
        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
        input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")

        if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
        else:
            raise RuntimeError(
                f"Slimquant_marlin does not support the FusedMoe scheme: "
                f"{weight_quant}, {input_quant}"
            )


class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):

    def __init__(
        self,
        quant_config: "CompressedTensorsMarlinConfig",  # type: ignore # noqa E501
    ):
        self.quant_config = quant_config
        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations"
        )

        per_channel = (
            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
            and self.input_quant.strategy == QuantizationStrategy.TOKEN
        )
        if not per_channel:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}"
            )

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales."
            )

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        params_dtype = torch.int8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                1,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        assert not self.static_input_scales
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w1_marlin_list = []
        for ii in range(layer.w13_weight.shape[0]):
            w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
            w1_marlin_list.append(w1_marlin_in)
        w1_marlin = torch.stack(w1_marlin_list, dim=0)

        w2_marlin_list = []
        for ii in range(layer.w2_weight.shape[0]):
            w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
            w2_marlin_list.append(w2_marlin_in)
        w2_marlin = torch.stack(w2_marlin_list, dim=0)

        layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
        layer.w2_weight = Parameter(w2_marlin, requires_grad=False)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config

    # def apply(
    #     self,
    #     layer: torch.nn.Module,
    #     x: torch.Tensor,
    #     router_logits: torch.Tensor,
    #     top_k: int,
    #     renormalize: bool,
    #     use_grouped_topk: bool = False,
    #     topk_group: Optional[int] = None,
    #     num_expert_group: Optional[int] = None,
    #     global_num_experts: int = -1,
    #     expert_map: Optional[torch.Tensor] = None,
    #     custom_routing_function: Optional[Callable] = None,
    #     scoring_func: str = "softmax",
    #     e_score_correction_bias: Optional[torch.Tensor] = None,
    #     apply_router_weight_on_input: bool = False,
    #     activation: str = "silu",
    #     enable_eplb: bool = False,
    #     use_nn_moe: Optional[bool] = False,
    #     routed_scaling_factor: Optional[float] = None,
    #     use_fused_gate: Optional[bool] = False,
    #     expert_load_view: Optional[torch.Tensor] = None,
    #     logical_to_physical_map: Optional[torch.Tensor] = None,
    #     logical_replica_count: Optional[torch.Tensor] = None,
    #     shared_output: Optional[torch.Tensor] = None,
    # ) -> torch.Tensor:
    #     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
    #     if enable_eplb:
    #         raise NotImplementedError(
    #             "EPLB not supported for "
    #             "`CompressedTensorsW8A8Int8MoEMethod` yet.")
    #     topk_weights, topk_ids = FusedMoE.select_experts(
    #         hidden_states=x,
    #         router_logits=router_logits,
    #         use_grouped_topk=use_grouped_topk,
    #         top_k=top_k,
    #         renormalize=renormalize,
    #         topk_group=topk_group,
    #         num_expert_group=num_expert_group,
    #         custom_routing_function=custom_routing_function,
    #         scoring_func=scoring_func,
    #         routed_scaling_factor=routed_scaling_factor,
    #         use_fused_gate=use_fused_gate,
    #         e_score_correction_bias=e_score_correction_bias)
    #     return fused_experts_impl_int8_marlin(
    #         hidden_states=x,
    #         w1=layer.w13_weight,
    #         w2=layer.w2_weight,
    #         topk_weights=topk_weights,
    #         topk_ids=topk_ids,
    #         inplace=True,
    #         activation=activation,
    #         apply_router_weight_on_input=apply_router_weight_on_input,
    #         use_int8_w8a8=True,
    #         per_channel_quant=True,
    #         global_num_experts=global_num_experts,
    #         expert_map=expert_map,
    #         w1_scale=layer.w13_weight_scale,
    #         w2_scale=layer.w2_weight_scale,
    #         a1_scale=layer.w13_input_scale,
    #         a2_scale=layer.w2_input_scale,
    #         use_nn_moe=False,
    #         shared_output=shared_output,
    #         routed_scaling_factor=routed_scaling_factor)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output,
    ):
        from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
        from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output
        x, topk_weights = apply_topk_weights_cpu(
            self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
        )

        output = fused_experts_impl_int8_marlin(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=layer.moe_runner_config.activation,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            global_num_experts=layer.moe_runner_config.num_experts,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            use_nn_moe=False,
        )
        return StandardCombineInput(hidden_states=output)
\ No newline at end of file
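
For intuition, a quick shape trace of get_w8a8_int8_marlin_weights above, using the (7168, 512) per-expert weight hinted at in its comment. This is purely illustrative and only assumes the tensor ops in the helper itself; the interpretation of the two dimensions is an assumption, not something the commit states.

# Illustrative shape trace of the marlin repack helper (not part of the commit).
import torch

from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe_marlin import (
    get_w8a8_int8_marlin_weights,
)

w = torch.zeros(7168, 512, dtype=torch.int8)  # assumed: one w2 expert, (out_features, in_features)
packed = get_w8a8_int8_marlin_weights(w)
# .T              -> (512, 7168)              size_k=512, size_n=7168
# .reshape        -> (512 // 64, 64, 7168)  = (8, 64, 7168)
# .transpose(1,2) -> (8, 7168, 64)
# .reshape        -> (8, 7168 * 64)         = (8, 458752)
assert packed.shape == (8, 458752)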
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

@@ -15,10 +15,11 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
 from sglang.srt.utils import is_cuda
+from lmslim import quant_ops

 _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 ...
@@ -168,6 +169,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)

-        return int8_scaled_mm(
+        return quant_ops.triton_scaled_mm(
             x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
         )
python/sglang/srt/layers/quantization/slimquant_w4a8_marlin.py

@@ -95,7 +95,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant) -> Optional[str]:
         if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
-                and user_quant == "slimquant_w4a8_marlin":
+                and user_quant in ("slimquant_w4a8_marlin", "slimquant_marlin"):
             return cls.get_name()
         return None

     def get_quant_method(
...
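
A short, illustrative check of the relaxed condition above (not part of the commit): a checkpoint exported as slimquant_w4a8 is now also picked up when the user asks for the new slimquant_marlin name, in addition to slimquant_w4a8_marlin.

# Illustrative only: both user-facing names now route a slimquant_w4a8 checkpoint
# to the W4A8 marlin config.
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig

hf_quant_cfg = {"quant_method": "slimquant_w4a8"}
for user_quant in ("slimquant_w4a8_marlin", "slimquant_marlin"):
    assert SlimQuantW4A8Int8MarlinConfig.override_quantization_method(
        hf_quant_cfg, user_quant
    ) is not None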
python/sglang/srt/server_args.py

@@ -94,6 +94,7 @@ QUANTIZATION_CHOICES = [
     "mxfp4",
     "compressed-tensors",  # for Ktransformers
     "slimquant_w4a8_marlin",
+    "slimquant_marlin",
 ]

 ATTENTION_BACKEND_CHOICES = [
...
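
With the new choice registered in QUANTIZATION_CHOICES, the method can be requested at launch time. A minimal usage sketch follows; the model path is a placeholder, and constructing ServerArgs assumes a working sglang install with a supported device, so treat it as illustrative rather than a verified recipe from this commit.

# Illustrative only: request the new method via server args.
# CLI equivalent (placeholder model path):
#   python -m sglang.launch_server --model-path /path/to/w8a8-compressed-tensors-model \
#       --quantization slimquant_marlin
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="/path/to/w8a8-compressed-tensors-model",  # placeholder
    quantization="slimquant_marlin",
)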