Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7ca9934f
Unverified
Commit
7ca9934f
authored
Feb 06, 2025
by
Dipika Sikka
Committed by
GitHub
Feb 06, 2025
Browse files
[Misc] Update w2 scale loading for GPTQMarlinMoE (#12757)
parent
0408efc6
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
8 deletions
+21
-8
tests/weight_loading/models-large.txt
tests/weight_loading/models-large.txt
+2
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+17
-6
No files found.
tests/weight_loading/models-large.txt
View file @
7ca9934f
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
\ No newline at end of file
vllm/model_executor/layers/fused_moe/layer.py
View file @
7ca9934f
...
@@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module):
...
@@ -302,8 +302,8 @@ class FusedMoE(torch.nn.Module):
"weight_loader"
:
self
.
weight_loader
,
"weight_loader"
:
self
.
weight_loader
,
}
}
# need full intermediate size pre-sharding for WNA16 act order
# need full intermediate size pre-sharding for WNA16 act order
if
(
self
.
quant_method
.
__class__
.
__name__
==
if
(
self
.
quant_method
.
__class__
.
__name__
"CompressedTensorsWNA16MoEMethod"
):
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)
)
:
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
7ca9934f
...
@@ -323,13 +323,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -323,13 +323,18 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
):
):
# Currently assuming is_k_full is always True
intermediate_size_full
=
extra_weight_attrs
.
pop
(
# (input size per partition is the same as full input size)
"intermediate_size_full"
)
# Supports only sym for now (no zp)
self
.
is_k_full
=
(
not
self
.
quant_config
.
desc_act
)
or
(
intermediate_size_per_partition
==
intermediate_size_full
)
if
self
.
quant_config
.
group_size
!=
-
1
:
if
self
.
quant_config
.
group_size
!=
-
1
:
scales_size13
=
hidden_size
//
self
.
quant_config
.
group_size
scales_size13
=
hidden_size
//
self
.
quant_config
.
group_size
scales_size2
=
(
intermediate_size_per_partition
//
w2_scales_size
=
(
intermediate_size_full
self
.
quant_config
.
group_size
)
if
self
.
quant_config
.
desc_act
else
intermediate_size_per_partition
)
scales_size2
=
(
w2_scales_size
//
self
.
quant_config
.
group_size
)
strategy
=
FusedMoeWeightScaleSupported
.
GROUP
.
value
strategy
=
FusedMoeWeightScaleSupported
.
GROUP
.
value
else
:
else
:
scales_size13
=
1
scales_size13
=
1
...
@@ -385,6 +390,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -385,6 +390,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
)
)
layer
.
register_parameter
(
"w2_scales"
,
w2_scales
)
layer
.
register_parameter
(
"w2_scales"
,
w2_scales
)
set_weight_attrs
(
w2_scales
,
extra_weight_attrs
)
set_weight_attrs
(
w2_scales
,
extra_weight_attrs
)
# dont shard the w2 scales when running act order
set_weight_attrs
(
w2_scales
,
{
"load_full_w2"
:
self
.
quant_config
.
desc_act
})
# up_proj scales
# up_proj scales
w13_qzeros
=
torch
.
nn
.
Parameter
(
w13_qzeros
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
torch
.
empty
(
num_experts
,
...
@@ -406,6 +414,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -406,6 +414,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
)
)
layer
.
register_parameter
(
"w2_qzeros"
,
w2_qzeros
)
layer
.
register_parameter
(
"w2_qzeros"
,
w2_qzeros
)
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
# dont shard the w2 scales when running act order
set_weight_attrs
(
w2_qzeros
,
{
"load_full_w2"
:
self
.
quant_config
.
desc_act
})
w13_g_idx
=
torch
.
nn
.
Parameter
(
w13_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
torch
.
empty
(
num_experts
,
num_experts
,
...
@@ -575,4 +586,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...
@@ -575,4 +586,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices1
=
layer
.
w13_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
sort_indices2
=
layer
.
w2_g_idx_sort_indices
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
).
to
(
orig_dtype
)
is_k_full
=
self
.
is_k_full
).
to
(
orig_dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment