change / sglang / Commits / d7e834d6
"docs/vscode:/vscode.git/clone" did not exist on "690f37bbe96a301cec8709a55a9d5e716f515683"
Unverified commit d7e834d6, authored Oct 23, 2025 by Hongbo Xu, committed by GitHub on Oct 23, 2025
[6/n]decouple quantization implementation from vLLM dependency (#10750)
Parent: 200a3c0b
Showing 7 changed files with 389 additions and 192 deletions.
Files changed:
- python/sglang/srt/layers/quantization/__init__.py (+0, -52)
- python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py (+8, -53)
- python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (+23, -65)
- python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py (+3, -0)
- python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py (+4, -22)
- python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py (+339, -0)
- python/sglang/srt/layers/quantization/marlin_utils.py (+12, -0)
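
Most of the hunks below delete variants of the same guard: probe for vLLM at import time, cache the result in a module-level flag, and raise ImportError from any code path that still needed the vLLM implementation. A minimal, self-contained reproduction of that pattern (illustrative only, not sglang code):

```python
# Illustrative reproduction of the optional-dependency guard that this commit removes
# from the WNA16 paths. The helper name get_scheme is made up for this sketch.
import importlib.util

# Probe for the package without importing it wholesale.
VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None


def get_scheme(name: str):
    if name == "wNa16":
        if not VLLM_AVAILABLE:
            # pre-commit behaviour: the WNA16 scheme was vLLM's implementation
            raise ImportError(
                "vllm is not installed, to use CompressedTensorsWNA16, please install vllm"
            )
        return "vllm-backed CompressedTensorsWNA16"
    raise ValueError(f"unknown scheme {name}")


print("vllm importable:", VLLM_AVAILABLE)
```

After this commit the WNA16 linear and MoE paths import sglang's in-tree scheme directly, while the remaining vLLM-only paths (CompressedTensors24, CompressedTensorsW4A16Sparse24) raise unconditionally instead of probing for vLLM.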
python/sglang/srt/layers/quantization/__init__.py
@@ -10,10 +10,6 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@@ -175,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         return original_isinstance(obj, classinfo)

     builtins.isinstance = patched_isinstance
-
-
-def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
-    """
-    Monkey patch the apply function of vllm's FusedMoEMethodBase.
-    Convert sglang arguments to vllm arguments.
-    """
-    original_apply = class_obj.apply
-    sig = inspect.signature(original_apply)
-    param_names = list(sig.parameters.keys())
-    has_correction_bias = "e_score_correction_bias" in param_names
-
-    def new_apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        assert activation == "silu"
-        assert inplace and not no_combine
-
-        kwargs = {
-            "self": self,
-            "layer": layer,
-            "x": x,
-            "topk_output": topk_output,
-        }
-        return original_apply(**kwargs)
-
-    setattr(class_obj, "apply", new_apply)
-
-
-def monkey_patch_quant_configs():
-    """Apply all monkey patches in one place."""
-    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-
-
-# Only apply monkey patches if vllm is available
-if VLLM_AVAILABLE:
-    monkey_patch_quant_configs()
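
The removed monkey patch adapted sglang's MoE apply(...) call signature to whatever vLLM's FusedMoEMethodBase.apply accepted, by inspecting its parameters and forwarding only the matching arguments. A standalone sketch of that adapter pattern (toy class and illustrative argument names, not sglang or vLLM code):

```python
# Sketch: wrap a third-party method, inspect its signature, and forward only the
# arguments it understands, validating/dropping the sglang-specific ones.
import inspect


class ThirdPartyMoEMethod:
    def apply(self, layer, x, topk_output):
        return f"apply({layer}, {x}, {topk_output})"


def monkey_patch_apply(cls):
    original_apply = cls.apply
    param_names = list(inspect.signature(original_apply).parameters.keys())

    def new_apply(self, layer, x, topk_output, *, activation="silu", inplace=True):
        # sglang-side keyword arguments the wrapped method does not accept are
        # validated here and not forwarded.
        assert activation == "silu" and inplace
        candidates = dict(layer=layer, x=x, topk_output=topk_output)
        kwargs = {k: v for k, v in candidates.items() if k in param_names}
        return original_apply(self, **kwargs)

    cls.apply = new_apply


monkey_patch_apply(ThirdPartyMoEMethod)
print(ThirdPartyMoEMethod().apply("layer", "x", "topk", activation="silu"))
```

With the WNA16 and W8A8-FP8 MoE methods now implemented inside sglang (see compressed_tensors_moe.py below), there is no signature mismatch to paper over and the patch can be dropped.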
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -30,10 +30,12 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
     CompressedTensorsMoEMethod,
 )
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS,
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
     CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8,
+    CompressedTensorsWNA16,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
@@ -43,23 +45,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
 from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

-try:
-    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
-        CompressedTensors24,
-    )
-    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
-        W4A16SPARSE24_SUPPORTED_BITS,
-        CompressedTensorsW4A16Sparse24,
-    )
-    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
-        WNA16_SUPPORTED_BITS,
-        CompressedTensorsWNA16,
-    )
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
-
 logger = logging.getLogger(__name__)

 __all__ = ["CompressedTensorsLinearMethod"]
@@ -380,19 +365,6 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
-                )
-            if (
-                self.quant_format == CompressionFormat.marlin_24.value
-                and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
-            ):
-                return CompressedTensorsW4A16Sparse24(
-                    strategy=weight_quant.strategy,
-                    num_bits=weight_quant.num_bits,
-                    group_size=weight_quant.group_size,
-                )
             if (
                 self.quant_format == CompressionFormat.pack_quantized.value
                 and weight_quant.num_bits in WNA16_SUPPORTED_BITS
@@ -403,6 +375,10 @@ class CompressedTensorsConfig(QuantizationConfig):
                     group_size=weight_quant.group_size,
                     actorder=weight_quant.actorder,
                 )
+            else:
+                raise ImportError(
+                    "Other method (CompressedTensorsW4A16Sparse24) is not supported now"
+                )

         if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -426,10 +402,6 @@ class CompressedTensorsConfig(QuantizationConfig):
             # note: input_quant can be None
             if self._is_fp8_w8a16(weight_quant, input_quant):
-                if not VLLM_AVAILABLE:
-                    raise ImportError(
-                        "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
-                    )
                 is_static_input_scheme = input_quant and not input_quant.dynamic
                 return CompressedTensorsW8A16Fp8(
                     strategy=weight_quant.strategy,
@@ -470,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Find the "target" in the compressed-tensors config
         # that our layer conforms to.
-        # TODO (@robertgshaw): add compressed-tensors as dep
+        # TODO: add compressed-tensors as dep
         # so we do not have to re-write these functions
         # need to make accelerate optional in ct to do this
@@ -508,24 +480,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             input_quant=input_quant,
             sparsity_scheme=sparsity_scheme,
         ):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensors24, please install vllm"
-                )
-            # Have a valid sparsity scheme
-            # Validate layer is supported by Cutlass 2:4 Kernel
-            model_compression_config = (
-                None
-                if sparsity_scheme is None or sparsity_scheme.format == "dense"
-                else self.config
-            )
-
-            scheme = CompressedTensors24(
-                quantized=weight_quant is not None or input_quant is not None,
-                weight_quant=weight_quant,
-                input_quant=input_quant,
-                model_compression_config=model_compression_config,
-            )
+            raise ImportError("CompressedTensors24 is not supported now")
         elif weight_quant is None:
             logger.warning_once(
                 "Acceleration for non-quantized schemes is "
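
To see what the CompressedTensorsConfig dispatch above boils down to after the change, here is a stripped-down sketch with dummy stand-in types (none of these names except WNA16_SUPPORTED_BITS exist in sglang, and the real code returns scheme objects rather than tuples):

```python
# Dummy reproduction of the post-commit WNA16 branch: packed 4/8-bit weights resolve
# to the in-tree CompressedTensorsWNA16, while the sparse 2:4 marlin_24 path fails fast.
from dataclasses import dataclass
from typing import Optional


@dataclass
class WeightQuantArgs:           # stand-in for the compressed-tensors weight args
    num_bits: int
    strategy: str
    group_size: int
    actorder: Optional[str] = None


WNA16_SUPPORTED_BITS = [4, 8]


def pick_wna16_scheme(quant_format: str, weight_quant: WeightQuantArgs):
    # "pack-quantized" stands in for CompressionFormat.pack_quantized.value
    if quant_format == "pack-quantized" and weight_quant.num_bits in WNA16_SUPPORTED_BITS:
        # the real code returns CompressedTensorsWNA16(strategy=..., num_bits=...,
        # group_size=..., actorder=...)
        return ("CompressedTensorsWNA16", weight_quant.num_bits, weight_quant.group_size)
    # the marlin_24 / CompressedTensorsW4A16Sparse24 branch was dropped with vLLM
    raise ImportError("Other method (CompressedTensorsW4A16Sparse24) is not supported now")


print(pick_wna16_scheme("pack-quantized", WeightQuantArgs(4, "group", 128)))
```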
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -6,7 +6,7 @@ import enum
 import logging
 import re
 from enum import Enum
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING

 try:
     from sgl_kernel import fused_marlin_moe
@@ -31,9 +31,13 @@ from sglang.srt.environ import envs
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    WNA16_SUPPORTED_BITS,
+)
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
+from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
     per_tensor_dequantize,
@@ -42,6 +46,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     get_compiler_backend,
+    is_cuda,
     is_hip,
     set_weight_attrs,
 )
@@ -57,6 +62,8 @@ if TYPE_CHECKING:
     )

 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
@@ -64,12 +71,9 @@ if _use_aiter:
     from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1

-try:
-    import vllm  # noqa: F401
-
-    VLLM_AVAILABLE = True
-except ImportError:
-    VLLM_AVAILABLE = False
+if _is_cuda:
+    from sgl_kernel import fused_marlin_moe

 logger = logging.getLogger(__name__)
@@ -127,10 +131,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
         input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
-            if not VLLM_AVAILABLE:
-                raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
-                )
+            logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -432,9 +434,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     ):
         if self.num_gpu_experts != -1:
             num_experts = self.num_gpu_experts
-        # assert (
-        #     params_dtype == torch.float16
-        # ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501

         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
@@ -573,44 +572,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             getattr(layer, name).copy_(new_t)
             del new_t

-        def get_scale_perms(num_bits: int):
-            scale_perm: List[int] = []
-            for i in range(8):
-                scale_perm.extend([i + 8 * j for j in range(8)])
-            scale_perm_single: List[int] = []
-            for i in range(4):
-                scale_perm_single.extend(
-                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
-                )
-            return scale_perm, scale_perm_single
-
-        def marlin_permute_scales(
-            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
-        ):
-            scale_perm, scale_perm_single = get_scale_perms(num_bits)
-            if group_size < size_k and group_size != -1:
-                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
-            else:
-                s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
-            s = s.reshape((-1, size_n)).contiguous()
-            return s
-
-        def marlin_moe_permute_scales(
-            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
-        ):
-            num_experts = s.shape[0]
-            output = torch.empty(
-                (num_experts, s.shape[1], s.shape[2]),
-                device=s.device,
-                dtype=s.dtype,
-            )
-            for e in range(num_experts):
-                output[e] = marlin_permute_scales(
-                    s[e], size_k, size_n, group_size, num_bits
-                )
-            return output
-
-        size_k2 = layer.w2_weight_packed.shape[2]
-        size_k13 = layer.w13_weight_packed.shape[2]
-
         num_experts = layer.w13_weight_g_idx.shape[0]
         device = layer.w13_weight_g_idx.device
@@ -657,42 +618,39 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )

-        from vllm import _custom_ops as vllm_ops
-
-        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
             layer.w13_weight_packed.shape[1] * self.packed_factor,
             layer.w13_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
+        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
+        marlin_w2_qweight = gptq_marlin_moe_repack(
             layer.w2_weight_packed,
             layer.w2_g_idx_sort_indices,
             layer.w2_weight_packed.shape[1] * self.packed_factor,
             layer.w2_weight_packed.shape[2],
             self.num_bits,
         )
-        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
             layer.w13_weight_scale,
-            size_k13,
+            layer.w13_weight_packed.shape[2],
             layer.w13_weight_scale.shape[2],
             self.group_size,
-            self.num_bits,
         )
-        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
         marlin_w2_scales = marlin_moe_permute_scales(
             layer.w2_weight_scale,
             layer.w2_weight_scale.shape[1]
             * (self.group_size if self.group_size != -1 else self.packed_factor),
-            size_k2,
+            layer.w2_weight_scale.shape[2],
             self.group_size,
-            self.num_bits,
         )
-        replace_tensor("w2_weight_scale", marlin_w2_scales)
+        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)

     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
@@ -716,7 +674,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         topk_weights, topk_ids, router_logits = topk_output

-        output = torch.ops.vllm.fused_marlin_moe(
+        output = fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
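
The deleted get_scale_perms / marlin_permute_scales / marlin_moe_permute_scales helpers are not gone; the method now calls marlin_moe_permute_scales from sglang.srt.layers.quantization.marlin_utils. The standalone reproduction below mirrors the deleted code (with the unused num_bits argument dropped) so the permutation can be sanity-checked without importing sglang; the shapes are toy values:

```python
# Reproduction of the removed per-expert Marlin scale permutation, for reference only.
import torch


def get_scale_perms():
    # Same index tables as the deleted helpers (num_bits was accepted but unused).
    scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
    scale_perm_single = [2 * i + j for i in range(4) for j in [0, 1, 8, 9, 16, 17, 24, 25]]
    return scale_perm, scale_perm_single


def marlin_permute_scales(s, size_k, size_n, group_size):
    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    return s.reshape((-1, size_n)).contiguous()


def marlin_moe_permute_scales(s, size_k, size_n, group_size):
    # Apply the 2D permutation independently to every expert's scale tensor.
    out = torch.empty_like(s)
    for e in range(s.shape[0]):
        out[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
    return out


# 4 experts, size_k=256, size_n=64, group_size=128 -> scales are [E, k/g, n]
scales = torch.randn(4, 256 // 128, 64)
print(marlin_moe_permute_scales(scales, size_k=256, size_n=64, group_size=128).shape)
```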
python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -4,10 +4,13 @@ from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
+from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16

 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
     "CompressedTensorsW8A16Fp8",
     "CompressedTensorsW8A8Int8",
+    "CompressedTensorsWNA16",
+    "WNA16_SUPPORTED_BITS",
 ]
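
With these re-exports, downstream code (such as compressed_tensors.py and compressed_tensors_moe.py above) can pull everything from the schemes package. A quick check, assuming an environment with sglang and its dependencies installed:

```python
# Hedged usage check of the new public names exported by the schemes package.
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    WNA16_SUPPORTED_BITS,
    CompressedTensorsWNA16,
)

# Both the linear scheme and the MoE method key off this list (packed 4- and 8-bit weights).
print(WNA16_SUPPORTED_BITS)                      # [4, 8]
print(CompressedTensorsWNA16.get_min_capability())  # 80 (Ampere and up)
```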
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@@ -14,25 +14,12 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
 from sglang.srt.layers.quantization.utils import convert_to_channelwise

-try:
-    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-        apply_fp8_marlin_linear,
-        prepare_fp8_layer_for_marlin,
-    )
-
-    MARLIN_FP8_AVAILABLE = True
-except ImportError:
-    MARLIN_FP8_AVAILABLE = False
-
-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
-
-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
-
 __all__ = ["CompressedTensorsW8A16Fp8"]

 SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
@@ -43,11 +30,6 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme

-        if not MARLIN_FP8_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
-            )
-
     @classmethod
     def get_min_capability(cls) -> int:
         # ampere and up
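
For reference, the deleted block is an instance of the fallback-stub pattern: when the optional backend is missing, bind placeholder functions that fail only if they are actually called. The snippet below reproduces it as a standalone script (safe to run with or without vLLM installed):

```python
# Reproduction of the removed fallback stubs; module and function names are the ones
# from the deleted code, the snippet itself is generic.
try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )

    MARLIN_FP8_AVAILABLE = True
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def apply_fp8_marlin_linear(*args, **kwargs):
        raise ImportError("vllm is not installed")

    def prepare_fp8_layer_for_marlin(*args, **kwargs):
        raise ImportError("vllm is not installed")


print("marlin fp8 via vllm available:", MARLIN_FP8_AVAILABLE)
```

The replacement imports the same two helpers unconditionally from sglang.srt.layers.quantization.marlin_utils_fp8, so the scheme no longer needs the MARLIN_FP8_AVAILABLE flag or the constructor-time ImportError removed in the second hunk.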
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
new file mode 100644
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Callable, Optional

import torch
from compressed_tensors.quantization import ActivationOrdering

# yapf conflicts with isort for this block
# yapf: disable
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
    permute_param_layout_,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.marlin_utils import (
    MarlinLinearLayerConfig,
    apply_gptq_marlin_linear,
    check_marlin_supports_shape,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_permute_scales,
    marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx,
    marlin_zero_points,
)
from sglang.srt.layers.quantization.utils import (
    get_scalar_types,
    replace_parameter,
    unpack_cols,
)
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import gptq_marlin_repack

ScalarType, scalar_types = get_scalar_types()

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128,
}
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsWNA16(CompressedTensorsScheme):
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        strategy: str,
        num_bits: int,
        group_size: Optional[int] = None,
        symmetric: Optional[bool] = True,
        actorder: Optional[ActivationOrdering] = None,
    ):
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError(
                "Marlin kernels require group quantization or "
                "channelwise quantization, but found no group "
                "size and strategy is not channelwise."
            )

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}"
            )

        self.quant_type = (
            WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
            if not self.symmetric
            else WNA16_SUPPORTED_TYPES_MAP[num_bits]
        )

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_size: int,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)

        self.kernel_config = MarlinLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx,
        )

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = input_size != input_size_per_partition
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel
        )

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        weight = PackedvLLMParameter(
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
            packed_factor=self.pack_factor,
            packed_dim=1,
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
        )

        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
        }

        zeros_args = {
            "weight_loader": weight_loader,
            "data": torch.zeros(
                output_size_per_partition // self.pack_factor,
                scales_and_zp_size,
                dtype=torch.int32,
            ),
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)

            if not self.symmetric:
                qzeros = PackedColumnParameter(
                    output_dim=0,
                    packed_dim=0,
                    packed_factor=self.pack_factor,
                    **zeros_args,
                )
        else:
            weight_scale = GroupQuantScaleParameter(
                output_dim=0, input_dim=1, **weight_scale_args
            )
            if not self.symmetric:
                qzeros = PackedvLLMParameter(
                    input_dim=1,
                    output_dim=0,
                    packed_dim=0,
                    packed_factor=self.pack_factor,
                    **zeros_args,
                )

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        if not self.symmetric:
            layer.register_parameter("weight_zero_point", qzeros)

        # group index (for activation reordering)
        if self.has_g_idx:
            weight_g_idx = RowvLLMParameter(
                data=torch.empty(
                    input_size_per_partition,
                    dtype=torch.int32,
                ),
                input_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_g_idx", weight_g_idx)

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        self.w_q_name = "weight_packed"
        self.w_s_name = "weight_scale"
        self.w_zp_name = "weight_zero_point"
        self.w_gidx_name = "weight_g_idx"

        device = getattr(layer, self.w_q_name).device
        c = self.kernel_config

        check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(device)

        def _transform_param(
            layer: torch.nn.Module, name: Optional[str], fn: Callable
        ) -> None:
            if name is not None and getattr(layer, name, None) is not None:
                old_param = getattr(layer, name)
                new_param = fn(old_param)
                # replace the parameter with torch.nn.Parameter for TorchDynamo
                # compatibility
                replace_parameter(
                    layer,
                    name,
                    torch.nn.Parameter(new_param.data, requires_grad=False),
                )

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
            )
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            _transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        _transform_param(layer, self.w_q_name, transform_w_q)
        _transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ) -> torch.Tensor:
        c = self.kernel_config

        def _get_weight_params(
            layer: torch.nn.Module,
        ) -> tuple[
            torch.Tensor,  # w_q
            torch.Tensor,  # w_s
            Optional[torch.Tensor],  # w_zp,
            Optional[torch.Tensor],  # w_gidx
        ]:
            return (
                getattr(layer, self.w_q_name),
                getattr(layer, self.w_s_name),
                getattr(layer, self.w_zp_name or "", None),
                getattr(layer, self.w_gidx_name or "", None),
            )

        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias,
        )
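
A hedged usage sketch of the new scheme's weight-creation path (CPU-only, assuming an sglang install; the bare torch.nn.Module and the no-op weight_loader are stand-ins for the real linear-layer plumbing, and actual inference additionally needs process_weights_after_loading plus a CUDA build of sgl_kernel for gptq_marlin_repack):

```python
# Sketch only: exercise CompressedTensorsWNA16.create_weights on a dummy layer and
# inspect the parameters it registers. Sizes are illustrative.
import torch

from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsWNA16,
)

scheme = CompressedTensorsWNA16(strategy="group", num_bits=4, group_size=128)

layer = torch.nn.Module()
scheme.create_weights(
    layer=layer,
    output_size=1024,
    input_size=4096,
    output_partition_sizes=[1024],
    input_size_per_partition=4096,
    params_dtype=torch.float16,
    weight_loader=lambda param, loaded_weight, **kwargs: None,  # no-op stand-in
)

# int32-packed weights: 8 four-bit values per int32 along the input dimension
print(layer.weight_packed.shape)         # torch.Size([1024, 512])
print(layer.weight_scale.shape)          # torch.Size([1024, 32])  (4096 / 128 groups)
print(scheme.kernel_config.weight_type)  # uint4b8 for symmetric 4-bit
```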
python/sglang/srt/layers/quantization/marlin_utils.py
@@ -4,6 +4,7 @@
 from __future__ import annotations

 import logging
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

 import numpy
@@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 USE_FP32_REDUCE_DEFAULT = True


+@dataclass
+class MarlinLinearLayerConfig:
+    full_weight_shape: tuple[int, int]  # [in, out]
+    partition_weight_shape: tuple[int, int]
+    weight_type: ScalarType
+    act_type: torch.dtype
+    group_size: int
+    zero_points: bool
+    has_g_idx: bool
+
+
 # For binary size and compile time, we don't support the same types for with and
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
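
A small sketch constructing the new dataclass, which CompressedTensorsWNA16 stores as self.kernel_config. get_scalar_types and the uint4b8 type are the same ones used elsewhere in this commit; the shapes are illustrative:

```python
# Hedged example: build a MarlinLinearLayerConfig by hand, assuming an sglang install.
import torch

from sglang.srt.layers.quantization.marlin_utils import MarlinLinearLayerConfig
from sglang.srt.layers.quantization.utils import get_scalar_types

ScalarType, scalar_types = get_scalar_types()

cfg = MarlinLinearLayerConfig(
    full_weight_shape=(4096, 1024),       # [in, out] before tensor-parallel split
    partition_weight_shape=(2048, 1024),  # this rank's shard
    weight_type=scalar_types.uint4b8,     # 4-bit type used for symmetric WNA16
    act_type=torch.float16,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)
print(cfg)
```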