change / sglang · Commit 1f76fc87 (unverified)

[3/n] chore: decouple AWQ implementation from vLLM dependency (#8113)

Authored by Hongbo Xu on Jul 19, 2025; committed via GitHub on Jul 18, 2025.
Co-authored-by: AniZpZ <zhuangsen.zp@antgroup.com>
Parent commit: 6737671c

Showing 8 changed files with 1143 additions and 20 deletions.
Changed files:

- benchmark/deepseek_v3/README.md (+9, -0)
- python/sglang/srt/layers/quantization/__init__.py (+7, -15)
- python/sglang/srt/layers/quantization/awq.py (+582, -2)
- python/sglang/srt/layers/quantization/utils.py (+84, -1)
- python/sglang/srt/models/deepseek_v2.py (+3, -1)
- python/sglang/test/test_marlin_moe.py (+286, -0)
- python/sglang/test/test_marlin_utils.py (+171, -0)
- test/srt/test_gptqmodel_dynamic.py (+1, -1)
benchmark/deepseek_v3/README.md

````diff
@@ -178,6 +178,8 @@ python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1
 ### Example: Serving with 8 A100/A800 with AWQ Quantization
 
+**Recommended Usage**
+
 Add `--quantization moe_wna16` flag to enable moe wna16 kernel for better performance.
 One example is as follows:
@@ -185,6 +187,13 @@ One example is as follows:
 python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization moe_wna16
 ```
+
+Alternatively, you can use `--quantization awq_marlin` as follows:
+```bash
+python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --quantization awq_marlin --dtype float16
+```
+Note that `awq_marlin` only supports `float16` now, which may lead to some precision loss.
 
 ### Example: Serving with 16 A100/A800 with int8 Quantization
````
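As a quick sanity check of a server launched with either command above, a short client request can be sent from Python. The snippet below is a minimal sketch and not part of this commit: it assumes the server is reachable on the default local port 30000 and that the native `/generate` endpoint is enabled, so adjust the URL to your deployment.

```python
# Hypothetical smoke test against a locally running sglang server
# (launched with one of the README commands above; port/endpoint are assumptions).
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",  # assumed default sglang port and native endpoint
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0.0},
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["text"])
```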
python/sglang/srt/layers/quantization/__init__.py

```diff
@@ -7,10 +7,6 @@ import torch
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import (
-        AWQMarlinConfig,
-        AWQMoEMethod,
-    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
         CompressedTensorsW8A8Fp8MoEMethod,
@@ -36,14 +32,14 @@ except ImportError:
     def override_quantization_method(self, *args, **kwargs):
         return None
 
-    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
-        DeepSpeedFPConfig
-    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
-        MarlinConfig
-    ) = QQQConfig = Int8TpuConfig = DummyConfig
+    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
+        ExpertsInt8Config
+    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
+        Int8TpuConfig
+    ) = DummyConfig
 
-from sglang.srt.layers.quantization.awq import AWQConfig
+from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
@@ -63,10 +59,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import (
-    get_dynamic_override,
-    get_linear_quant_method,
-)
+from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -237,7 +230,6 @@ def monkey_patch_quant_configs():
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(AWQMoEMethod)
     monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
```
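The registry change follows the optional-dependency pattern already used in this module: vLLM config classes are imported when available, and otherwise replaced by a stub whose `override_quantization_method` reports "not applicable". The snippet below is an illustrative sketch of that pattern, not sglang's actual code; the `DummyConfig` name mirrors the fallback seen in the diff.

```python
# Sketch of the optional-import fallback pattern used in __init__.py above.
try:
    # Real config class, only present when vLLM is installed.
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
except ImportError:

    class DummyConfig:
        """Placeholder so registry lookups keep working when vLLM is absent."""

        def override_quantization_method(self, *args, **kwargs):
            # Returning None means "do not auto-switch to this quantization method".
            return None

    AQLMConfig = DummyConfig
```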
python/sglang/srt/layers/quantization/awq.py

```diff
@@ -2,21 +2,52 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Optional
+import warnings
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
 from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     LinearMethodBase,
     QuantizationConfig,
+    QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.marlin_utils import (
+    apply_awq_marlin_linear,
+    awq_to_marlin_zero_points,
+    check_marlin_supported,
+    check_marlin_supports_layer,
+    check_moe_marlin_supports_layer,
+    marlin_make_empty_g_idx,
+    marlin_make_workspace,
+    marlin_moe_permute_scales,
+    marlin_permute_scales,
+    moe_awq_to_marlin_zero_points,
+    verify_marlin_supported,
+    verify_marlin_supports_shape,
+)
+from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.utils import replace_parameter
+
+try:
+    from vllm import _custom_ops as ops
+
+    warnings.warn(
+        f"Using kernels directly from vllm. This might lead to performance degradation or "
+        f"missing functionalities as certain kernels may not be optimized. "
+    )
+except ImportError:
+    ops = None
+
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import awq_dequantize
+    from sgl_kernel import awq_dequantize, fused_marlin_moe
 
 logger = logging.getLogger(__name__)
```

The remaining two hunks are pure additions. `@@ -103,6 +134,176 @@ class AWQConfig(QuantizationConfig):` inserts a new `AWQMarlinConfig` class right after the existing `AWQConfig.get_quant_method` (the `return None` context line), and `@@ -204,3 +405,382 @@ class AWQLinearMethod(LinearMethodBase):` appends `AWQMarlinLinearMethod` and `AWQMoEMethod` at the end of the file:

```python
class AWQMarlinConfig(QuantizationConfig):
    """Config class for AWQ Marlin"""

    # num_bits -> type
    TYPE_MAP = {
        4: scalar_types.uint4,
        8: scalar_types.uint8,
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: Optional[list[str]],
        full_config: dict[str, Any],
    ) -> None:
        super().__init__()
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits
        self.modules_to_not_convert = modules_to_not_convert or []
        self.full_config = full_config

        if self.weight_bits not in self.TYPE_MAP:
            raise ValueError(
                f"Unsupported num_bits = {self.weight_bits}. "
                f"Supported num_bits = {self.TYPE_MAP.keys()}"
            )

        self.quant_type = self.TYPE_MAP[self.weight_bits]

        verify_marlin_supported(
            self.quant_type, group_size=self.group_size, has_zp=self.zero_point
        )

    def __repr__(self) -> str:
        return (
            f"AWQMarlinConfig(quant_type={self.quant_type}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_name(cls) -> str:
        return "awq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig:
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            weight_bits,
            group_size,
            zero_point,
            lm_head_quantized,
            modules_to_not_convert,
            config,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (
            user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
        )

        if can_convert and is_valid_user_quant:
            msg = (
                "The model is convertible to {} during runtime."
                " Using {} kernel.".format(cls.get_name(), cls.get_name())
            )
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "awq":
            logger.info(
                "Detected that the model can run with awq_marlin"
                ", however you specified quantization=awq explicitly,"
                " so forcing awq. Use quantization=awq_marlin for"
                " faster inference"
            )
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            # Check if the layer is supported by AWQMarlin.
            if not check_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.",  # noqa: E501
                    prefix,
                )
                return AWQConfig.from_config(self.full_config).get_quant_method(
                    layer, prefix
                )
            return AWQMarlinLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning_once(
                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels."
                )
                return MoeWNA16Config.from_config(self.full_config).get_quant_method(
                    layer, prefix
                )
            return AWQMoEMethod(self)
        return None

    @classmethod
    def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        zero_point = quant_config.get("zero_point")

        if not _is_cuda:
            return False

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if num_bits is None or group_size is None or zero_point is None:
            return False

        if num_bits not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
        )


class AWQMarlinLinearMethod(LinearMethodBase):
    """Linear method for AWQ Marlin.

    Args:
        quant_config: The AWQ Marlin quantization config.
    """

    def __init__(self, quant_config: AWQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size,
        )

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

    # TODO: Update this docs
    # Checkpoints are serialized in AutoAWQ format, which is different from the
    # marlin format. This function is called after the weights are loaded.
    # Here, we handle the repacking
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(device)

        # Repack weights from AWQ format to marlin format.
        marlin_qweight = ops.awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "scales", marlin_scales)

        # Permute zero-points from AWQ format to marlin format.
        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qzeros", marlin_zp)

        # Not-used
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_awq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.qzeros,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )


class AWQMoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: AWQMarlinConfig):
        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQMoEMethod only supports 4bit now.")
        self.quant_type = scalar_types.uint4

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        extra_weight_attrs.update(
            {
                "is_transposed": True,
                "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
            }
        )

        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                hidden_size // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        num_groups_w13 = hidden_size // self.quant_config.group_size
        num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_scales = torch.nn.Parameter(
            torch.empty(
                num_experts,
                num_groups_w13,
                intermediate_size_per_partition * 2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        # WEIGHT_ZERO_POINT
        # Allocate 2 zero points for w1 and w3 respectively.
        w13_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                num_groups_w13,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                num_groups_w2,
                hidden_size // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

        device = layer.w13_qweight.device
        layer.workspace = marlin_make_workspace(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = ops.awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = ops.awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # hidden_size->intermediate_size
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.topk import select_experts

        assert activation == "silu", "Only SiLU activation is supported."
        assert (
            scoring_func == "softmax"
        ), "Only softmax score func is supported for now."

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
        )

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            num_bits=self.quant_config.weight_bits,
        ).to(orig_dtype)
```
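For reference, `AWQMarlinConfig` is driven by the AutoAWQ-style dict that the model loader reads from `quantize_config.json`. The snippet below is an illustrative sketch only: the field values are made up, and constructing the config calls `verify_marlin_supported`, which expects a CUDA-capable sglang build.

```python
# Hypothetical usage sketch of the new config class (not part of this commit).
from sglang.srt.layers.quantization.awq import AWQMarlinConfig

# Example AutoAWQ-style quantization config as it might appear in quantize_config.json.
hf_quant_cfg = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    "modules_to_not_convert": None,
}

cfg = AWQMarlinConfig.from_config(hf_quant_cfg)
print(cfg)              # AWQMarlinConfig(quant_type=..., group_size=128, ...)
print(cfg.pack_factor)  # 8: eight 4-bit weights are packed into each int32
```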
python/sglang/srt/layers/quantization/utils.py

```diff
@@ -11,7 +11,7 @@ import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.layers.quantization.scalar_type import ScalarType
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
 
 if TYPE_CHECKING:
```

The hunk at `@@ -247,6 +247,36 @@ def get_pack_factor(num_bits):` adds `permute_rows` just before `pack_cols`, and the hunk at `@@ -399,3 +429,56 @@ def quantize_weights(` appends `gptq_quantize_weights` and `sort_weights` at the end of the file:

```python
def permute_rows(
    q_w: torch.Tensor,
    w_ref: torch.Tensor,
    group_size: int,
    test_perm: Optional[torch.Tensor] = None,
):
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    g_idx = torch.zeros((k_size,), dtype=torch.int32)

    for i in range(k_size):
        g_idx[i] = i // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


def gptq_quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
    ), f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k
        )

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)

    return w_ref, w_q, w_s, g_idx, rand_perm


def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )
```
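These helpers all lean on the same packing arithmetic: `get_pack_factor(num_bits)` returns `32 // num_bits`, so eight 4-bit values share one int32, and the test utilities pack column-interleaved values with shifted ORs. The following self-contained numpy sketch (not sglang code) just illustrates that round trip.

```python
import numpy as np

num_bits = 4
pack_factor = 32 // num_bits  # = 8, mirrors get_pack_factor(4)

# Sixteen 4-bit values in two rows; pack each group of 8 values into one uint32.
q_w = (np.arange(16, dtype=np.uint32) % (1 << num_bits)).reshape(2, 8)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    # Same shifted-OR loop used by marlin_weights()/pack_cols in this commit.
    q_packed |= q_w[:, i::pack_factor] << num_bits * i

# Unpack again to confirm the packing is lossless.
unpacked = np.stack(
    [(q_packed >> (num_bits * i)) & ((1 << num_bits) - 1) for i in range(pack_factor)],
    axis=-1,
).reshape(2, 8)
assert np.array_equal(unpacked, q_w)
```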
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -355,6 +355,7 @@ class DeepseekV2MoE(nn.Module):
                 self.shared_experts.gate_up_proj.quant_method, "quant_config"
             )
             and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name()
             in {
                 "awq",
+                "awq_marlin",
                 "moe_wna16",
             }
         self.shared_experts_is_int8 = (
@@ -929,7 +930,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             has_fused_proj
             and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
             and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
-            in {"awq", "moe_wna16"}
+            in {"awq", "awq_marlin", "moe_wna16"}
         )
         self.use_min_latency_fused_a_gemm = (
             has_fused_proj
@@ -2551,6 +2552,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 cat_dim = 0
                 if self.quant_config is not None and (
                     self.quant_config.get_name() == "awq"
+                    or self.quant_config.get_name() == "awq_marlin"
                     or self.quant_config.get_name() == "moe_wna16"
                 ):
                     cat_dim = 1
```
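All three DeepSeek-V2 hunks are the same gating pattern: a code path is taken only when the attached quantization config reports one of the AWQ-family names, which now includes `awq_marlin`. A condensed, hypothetical helper expressing that check (not code from this commit) would look like:

```python
# Hypothetical condensation of the membership checks above; `quant_method` stands
# for a layer's quant_method attribute as used in deepseek_v2.py.
AWQ_FAMILY = {"awq", "awq_marlin", "moe_wna16"}


def uses_awq_family_quant(quant_method) -> bool:
    quant_config = getattr(quant_method, "quant_config", None)
    return quant_config is not None and quant_config.get_name() in AWQ_FAMILY
```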
python/sglang/test/test_marlin_moe.py (new file, mode 100644)

```python
import types
from typing import Optional

import pytest
import torch
from sgl_kernel import fused_marlin_moe

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize


def stack_and_dev(tensors: list[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


def torch_experts(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    apply_router_weights_on_input: bool = False,
) -> torch.Tensor:
    assert (
        global_num_experts == -1
        or (global_num_experts == w1.shape[0] and expert_map is None)
        or (expert_map is not None and global_num_experts == expert_map.shape[0])
    )

    M, K = a.shape
    topk = topk_ids.shape[1]

    print("quant_dtype", quant_dtype)
    # exit(0)
    if apply_router_weights_on_input:
        assert topk == 1
        a = a * topk_weight.to(a.dtype)

    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)

    out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)

    num_experts = w1.shape[0]

    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]

    f32 = torch.float32

    for i in range(num_experts):
        mask = topk_ids == i
        if mask.sum():
            if quant_dtype is None:
                tmp1 = a[mask] @ w1[i].transpose(0, 1)
                tmp2 = SiluAndMul()(tmp1)
                out[mask] = tmp2 @ w2[i].transpose(0, 1)

    if apply_router_weights_on_input:
        return out
    else:
        return (
            (out.view(M, -1, w2.shape[1]).to(f32) * topk_weight.view(M, -1, 1))
            .sum(dim=1)
            .to(out.dtype)
        )


def torch_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    score: torch.Tensor,
    topk: int,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    return torch_experts(
        a, w1, w2, topk_weight, topk_ids, global_num_experts, expert_map
    )


def marlin_moe_generate_valid_test_cases():
    import itertools

    m_list = [1, 123, 666]
    n_list = [128, 1024]
    k_list = [256, 2048]
    e_list = [4, 12]
    topk_list = [2, 3]
    dtype_list = [torch.half, torch.bfloat16]
    group_size_list = [128]
    act_order_list = [True, False]
    quant_type_list = [
        scalar_types.uint4,
        scalar_types.uint4b8,
    ]
    is_k_full_list = [True, False]

    all_combinations = itertools.product(
        m_list,
        n_list,
        k_list,
        e_list,
        topk_list,
        dtype_list,
        group_size_list,
        act_order_list,
        quant_type_list,
        is_k_full_list,
    )

    def is_invalid(
        m, n, k, e, topk, dtype, group_size, act_order, quant_type, is_k_full
    ):
        # Filter act_order
        if act_order:
            if group_size in (-1, k, n):
                return False
            if quant_type not in [scalar_types.uint4b8]:
                return False
        elif not is_k_full:
            return False

        return True

    cases = []
    for case in all_combinations:
        if is_invalid(*case):
            cases.append(case)
    return cases


@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
    ("m, n, k, e, topk, dtype, group_size," "act_order, quant_type, is_k_full"),
    marlin_moe_generate_valid_test_cases(),
)
def test_fused_marlin_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
    group_size: int,
    act_order: bool,
    quant_type: ScalarType,
    is_k_full: bool,
):
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    torch.manual_seed(0)

    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size in (k, n):
            return
        if has_zp:
            return
    else:
        if not is_k_full:
            return

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    e_map = None

    w_ref1_l = []
    qweight1_l = []
    scales1_l = []
    zeros1_l = []
    g_idx1_l = []
    sort_indices1_l = []

    for i in range(w1.shape[0]):
        if has_zp:
            w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
                w1[i].transpose(1, 0), quant_type, group_size
            )

            w_ref1_l.append(w_ref1.T)
            qweight1_l.append(qweight1)
            scales1_l.append(scales1)
            zeros1_l.append(zeros1)
        else:
            test_perm = torch.randperm(k)
            w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
                w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
            )

            w_ref1_l.append(w_ref1.T)
            qweight1_l.append(qweight1)
            scales1_l.append(scales1)
            g_idx1_l.append(g_idx1)
            sort_indices1_l.append(sort_indices1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweight1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
    zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
    sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None

    w_ref2_l = []
    qweight2_l = []
    scales2_l = []
    zeros2_l = []
    g_idx2_l = []
    sort_indices2_l = []

    for i in range(w2.shape[0]):
        if has_zp:
            w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
                w2[i].transpose(1, 0), quant_type, group_size
            )

            w_ref2_l.append(w_ref2.T)
            qweight2_l.append(qweight2)
            scales2_l.append(scales2)
            zeros2_l.append(zeros2)
        else:
            test_perm = torch.randperm(n)
            w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
                w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
            )

            w_ref2_l.append(w_ref2.T)
            qweight2_l.append(qweight2)
            scales2_l.append(scales2)
            g_idx2_l.append(g_idx2)
            sort_indices2_l.append(sort_indices2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweight2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
    zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
    sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    from sglang.srt.layers.moe.topk import fused_topk_torch_native

    topk_weights, topk_ids = fused_topk_torch_native(a, score, topk, False)

    torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

    marlin_output = fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=zeros1,
        w2_zeros=zeros2,
        num_bits=4,
        is_k_full=is_k_full,
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


if __name__ == "__main__":
    # Run the specific test function directly
    pytest.main([__file__])
```
python/sglang/test/test_marlin_utils.py (new file, mode 100644)

```python
"""
Adapted from
https://github.com/vllm-project/vllm/blob/020f58abcdea65302225663130d08fd8f4dd755a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
"""

# SPDX-License-Identifier: Apache-2.0
"""Utility functions used for tests and benchmarks"""

from typing import Optional

import numpy as np
import torch

from sglang.srt.layers.quantization.marlin_utils import (
    GPTQ_MARLIN_TILE,
    marlin_permute_scales,
    marlin_zero_points,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.layers.quantization.utils import (
    get_pack_factor,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)


class MarlinWorkspace:

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (
            out_features % min_thread_n == 0
        ), "out_features = {} is undivisible by min_thread_n = {}".format(
            out_features, min_thread_n
        )

        max_workspace_size = (out_features // min_thread_n) * max_parallel

        self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")


def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w


def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)

    return q_packed


def get_weight_perm(num_bits: int):
    perm_list: list[int] = []
    for i in range(32):
        perm1: list[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    size_k, size_n = w.shape
    num_bits = quant_type.size_bits

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, group_size, act_order, test_perm
    )

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Detect num groups
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zp
    w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)

    # Reformat to marlin
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
```
test/srt/test_gptqmodel_dynamic.py

```diff
@@ -24,7 +24,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
         set_custom_all_reduce,
     )
     from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-    from sglang.srt.layers.quantization import get_dynamic_override
+    from sglang.srt.layers.quantization.utils import get_dynamic_override
     from sglang.srt.model_loader import get_model
     from sglang.srt.server_args import PortArgs, ServerArgs
```