Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b1e5afc3
Unverified
Commit
b1e5afc3
authored
Aug 13, 2024
by
Dipika Sikka
Committed by
GitHub
Aug 13, 2024
Browse files
[Misc] Update `awq` and `awq_marlin` to use `vLLMParameters` (#7422)
parent
d3bdfd3a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
83 deletions
+74
-83
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+3
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+2
-1
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+35
-39
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+34
-42
No files found.
tests/weight_loading/models.txt
View file @
b1e5afc3
...
@@ -13,3 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
...
@@ -13,3 +13,5 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
vllm/model_executor/layers/linear.py
View file @
b1e5afc3
...
@@ -21,7 +21,8 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -21,7 +21,8 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"GPTQMarlinLinearMethod"
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
]
]
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
b1e5afc3
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
class
AWQConfig
(
QuantizationConfig
):
class
AWQConfig
(
QuantizationConfig
):
...
@@ -101,55 +101,51 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -101,55 +101,51 @@ class AWQLinearMethod(LinearMethodBase):
"weight shape. This can be caused by too large "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
"tensor parallel size."
)
qweight
=
Parameter
(
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
torch
.
empty
(
qweight
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
packed_dim
=
1
,
qweight
,
{
packed_factor
=
self
.
quant_config
.
pack_factor
,
"input_dim"
:
0
,
weight_loader
=
weight_loader
)
"output_dim"
:
1
,
"packed_dim"
:
1
,
qzeros
=
PackedvLLMParameter
(
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
data
=
torch
.
empty
(
})
qzeros
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
packed_dim
=
1
,
qzeros
,
{
packed_factor
=
self
.
quant_config
.
pack_factor
,
"input_dim"
:
0
,
weight_loader
=
weight_loader
)
"output_dim"
:
1
,
"packed_dim"
:
1
,
scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
output_size_per_partition
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
scales
,
{
weight_loader
=
weight_loader
)
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
qweight
=
torch
.
nn
.
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
torch
.
nn
.
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
scales
=
torch
.
nn
.
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
View file @
b1e5afc3
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
@@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -151,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
...
@@ -151,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
)
->
None
:
)
->
None
:
del
output_size
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# Normalize group_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
if
self
.
quant_config
.
group_size
!=
-
1
:
...
@@ -164,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
...
@@ -164,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
input_size
=
input_size
,
input_size
=
input_size
,
group_size
=
group_size
)
group_size
=
group_size
)
qweight
=
Parameter
(
qweight
=
PackedvLLM
Parameter
(
torch
.
empty
(
data
=
torch
.
empty
(
input_size_per_partition
,
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
packed_dim
=
1
,
qweight
,
{
packed_factor
=
self
.
quant_config
.
pack_factor
,
"input_dim"
:
0
,
weight_loader
=
weight_loader
)
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
num_groups
=
input_size_per_partition
//
group_size
num_groups
=
input_size_per_partition
//
group_size
qzeros
=
Parameter
(
qzeros
=
PackedvLLM
Parameter
(
torch
.
empty
(
data
=
torch
.
empty
(
num_groups
,
num_groups
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
packed_dim
=
1
,
qzeros
,
{
packed_factor
=
self
.
quant_config
.
pack_factor
,
"input_dim"
:
0
,
weight_loader
=
weight_loader
)
"output_dim"
:
1
,
"packed_dim"
:
1
,
scales
=
GroupQuantScaleParameter
(
data
=
torch
.
empty
(
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
torch
.
empty
(
num_groups
,
num_groups
,
output_size_per_partition
,
output_size_per_partition
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
scales
,
{
weight_loader
=
weight_loader
)
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
...
@@ -228,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
...
@@ -228,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
# Here, we handle the repacking
# Here, we handle the repacking
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
device
=
layer
.
qweight
.
device
layer
.
qweight
=
torch
.
nn
.
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
torch
.
nn
.
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
scales
=
torch
.
nn
.
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# Allocate marlin workspace
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
workspace
=
marlin_make_workspace
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment