change / sglang · Commits · 83f2d9d4

Commit 83f2d9d4 (Unverified), authored May 20, 2025 by PGFLMG, committed by GitHub May 19, 2025

[QuickFix] fix gptq model initialize (#6429)
parent 6317c5c6

Showing 2 changed files with 303 additions and 8 deletions (+303, -8):

    python/sglang/srt/layers/quantization/__init__.py  (+5, -2)
    python/sglang/srt/layers/quantization/gptq.py  (+298, -6)
python/sglang/srt/layers/quantization/__init__.py

@@ -25,7 +25,6 @@ try:
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
     )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,

@@ -58,7 +57,11 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.layers.quantization.gptq import (
+    GPTQConfig,
+    GPTQMarlinConfig,
+    GPTQMarlinMoEMethod,
+)
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
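Aside (not part of the diff): with this change, __init__.py pulls GPTQMarlinMoEMethod from sglang.srt.layers.quantization.gptq, where this commit defines it, instead of from vllm.model_executor.layers.quantization.gptq_marlin. A quick, illustrative sanity check, assuming an environment with this commit installed:

    # Hypothetical check in a Python shell; not part of the commit.
    from sglang.srt.layers.quantization.gptq import GPTQMarlinMoEMethod

    # The MoE method is now defined in sglang's own gptq module.
    print(GPTQMarlinMoEMethod.__module__)
    # expected: sglang.srt.layers.quantization.gptq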
python/sglang/srt/layers/quantization/gptq.py

 import logging
 from fractions import Fraction
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch

-from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.utils import replace_parameter
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()

 try:
-    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm import _custom_ops as ops
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
+        marlin_moe_permute_scales,
     )
     from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
     from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -27,7 +34,9 @@ try:
 except ImportError:
     VLLM_AVAILABLE = False

-    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+    GPTQLinearMethod = MarlinLinearMethod = Any
+    FusedMoEMethodBase = QuantizeMethodBase

     class scalar_types:
         uint4b8 = "uint4b8"
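Aside (not part of the diff): the reworked except-ImportError branch keeps gptq.py importable without vllm. Names used only in type annotations fall back to Any, while FusedMoEMethodBase falls back to sglang's own QuantizeMethodBase so that the GPTQMarlinMoEMethod class added below can still be declared as a subclass. A minimal, generic sketch of the same pattern; the module and class names here are illustrative assumptions, not code from this file:

    from typing import Any

    class LocalBase:  # stand-in for sglang's QuantizeMethodBase
        pass

    try:
        # Optional accelerated backend; the names below are hypothetical.
        from some_optional_backend import FastLinearMethod, FastMoEBase
        BACKEND_AVAILABLE = True
    except ImportError:
        BACKEND_AVAILABLE = False
        FastLinearMethod = Any   # only referenced in annotations
        FastMoEBase = LocalBase  # subclassed below, so it must be a real class

    class MyMoEMethod(FastMoEBase):
        """Still importable when the optional backend is missing."""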
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
         ):
             return MarlinLinearMethod(self)
         return None
+
+
+class GPTQMarlinMoEMethod(FusedMoEMethodBase):
+    """MoE Marlin method with quantization."""
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        intermediate_size = extra_weight_attrs.pop("intermediate_size")
+        self.is_k_full = (not self.quant_config.desc_act) or (
+            intermediate_size_per_partition == intermediate_size
+        )
+        if self.quant_config.group_size != -1:
+            scales_size13 = hidden_size // self.quant_config.group_size
+            w2_scales_size = (
+                intermediate_size
+                if self.quant_config.desc_act
+                else intermediate_size_per_partition
+            )
+            scales_size2 = w2_scales_size // self.quant_config.group_size
+            strategy = FusedMoeWeightScaleSupported.GROUP.value
+        else:
+            scales_size13 = 1
+            scales_size2 = 1
+            strategy = FusedMoeWeightScaleSupported.CHANNEL.value
+        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.quant_config.pack_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.quant_config.pack_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        # up_proj scales
+        w13_scales = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition,
+                dtype=torch.half,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+        # down_proj scales
+        w2_scales = torch.nn.Parameter(
+            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
+        # up_proj scales
+        w13_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size13,
+                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        set_weight_attrs(w13_qzeros, extra_weight_attrs)
+        # down_proj scales
+        w2_qzeros = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                scales_size2,
+                hidden_size // self.quant_config.pack_factor,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        set_weight_attrs(w2_qzeros, extra_weight_attrs)
+        # dont shard the w2 scales when running act order
+        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
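Aside (not part of the diff): the qweight shapes above follow the usual GPTQ packing convention, where pack_factor quantized values are packed into each int32 word (for 4-bit weights, pack_factor = 32 // 4 = 8). A small illustrative shape check under that assumption, with made-up sizes:

    import torch

    # Assumed example sizes; 4-bit GPTQ, so 8 nibbles per int32 word.
    num_experts, hidden_size, intermediate_size_per_partition = 8, 4096, 11008
    pack_factor = 32 // 4

    # Mirrors the w13_qweight allocation above: packed along the input dim.
    w13_qweight = torch.empty(
        num_experts,
        hidden_size // pack_factor,
        2 * intermediate_size_per_partition,
        dtype=torch.int32,
    )
    print(w13_qweight.shape)  # torch.Size([8, 512, 22016])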
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Process act_order
+        if self.quant_config.desc_act:
+            # Get sorting based on g_idx
+            num_experts = layer.w13_g_idx.shape[0]
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
+            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        else:
+            # Reset g_idx related tensors
+            num_experts = layer.w13_g_idx.shape[0]
+            device = layer.w13_g_idx.device
+            layer.w13_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+        # Repack weights
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_qweight,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w13_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_qweight,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+            layer.w2_qweight.shape[2],
+            self.quant_config.quant_type.size_bits,
+        )
+        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            s=layer.w13_scales,
+            size_k=layer.intermediate_size_per_partition,
+            size_n=layer.w13_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w13_scales", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            s=layer.w2_scales,
+            size_k=layer.w2_scales.shape[1]
+            * (
+                self.quant_config.group_size
+                if self.quant_config.group_size != -1
+                else self.quant_config.pack_factor
+            ),
+            size_n=layer.w2_scales.shape[2],
+            group_size=self.quant_config.group_size,
+        )
+        replace_parameter(layer, "w2_scales", marlin_w2_scales)
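Aside (not part of the diff): with act-order (desc_act) enabled, the loop above sorts each expert's g_idx and keeps the permutation that the Marlin kernels later use to reorder rows. A toy per-expert example of that step, using illustrative values:

    import torch

    # One expert's group indices (illustrative values).
    g_idx = torch.tensor([2, 0, 1, 0, 2, 1], dtype=torch.int32)

    sort_indices = torch.argsort(g_idx)  # permutation that sorts g_idx
    sorted_g_idx = g_idx[sort_indices]   # g_idx reordered by that permutation

    print(sorted_g_idx.tolist())   # [0, 0, 1, 1, 2, 2]
    print(sort_indices.tolist())   # a valid sorting permutation, e.g. [1, 3, 2, 5, 0, 4]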
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+        # The input must currently be float16
+        orig_dtype = x.dtype
+        x = x.half()
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            layer.w13_scales,
+            layer.w2_scales,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_g_idx,
+            g_idx2=layer.w2_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.quant_config.quant_type.size_bits,
+            is_k_full=self.is_k_full,
+        ).to(orig_dtype)
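Aside (not part of the diff): apply() first routes tokens to experts via FusedMoE.select_experts and only then calls the fused Marlin MoE kernel, casting activations to float16 for the kernel and back to the caller's dtype on return. Conceptually, plain softmax top-k routing looks roughly like the sketch below; this is an illustration of the idea, not vllm's select_experts implementation:

    import torch

    def naive_select_experts(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
        # Softmax over the expert dimension, then keep the top-k experts per token.
        scores = torch.softmax(router_logits.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids

    logits = torch.randn(4, 8)         # 4 tokens, 8 experts
    weights, ids = naive_select_experts(logits, top_k=2)
    print(weights.shape, ids.shape)    # torch.Size([4, 2]) torch.Size([4, 2])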