Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c7cb5c33
Unverified
Commit
c7cb5c33
authored
Sep 09, 2024
by
Kyle Sayers
Committed by
GitHub
Sep 09, 2024
Browse files
[Misc] GPTQ Activation Ordering (#8135)
parent
f9b4a2d4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
64 additions
and
15 deletions
+64
-15
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+1
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+2
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+33
-12
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+28
-2
No files found.
tests/weight_loading/models.txt
View file @
c7cb5c33
...
...
@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
c7cb5c33
...
...
@@ -232,7 +232,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
CompressedTensorsWNA16
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
c7cb5c33
...
...
@@ -5,14 +5,18 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
ActivationOrdering
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
__all__
=
[
"CompressedTensorsWNA16"
]
...
...
@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def
__init__
(
self
,
strategy
:
str
,
num_bits
:
int
,
group_size
:
Optional
[
int
]
=
None
):
group_size
:
Optional
[
int
]
=
None
,
actorder
:
Optional
[
ActivationOrdering
]
=
None
):
self
.
pack_factor
=
32
//
num_bits
self
.
strategy
=
strategy
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
has_g_idx
=
actorder
==
ActivationOrdering
.
GROUP
if
self
.
group_size
==
-
1
and
self
.
strategy
!=
"channel"
:
raise
ValueError
(
"Marlin kernels require group quantization or "
...
...
@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
output_size_per_partition
=
sum
(
output_partition_sizes
)
# If group_size is -1, we are in channelwise case.
channelwise
=
(
self
.
group_size
==
-
1
)
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input_size
row_parallel
=
(
input_size
!=
input_size_per_partition
)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
partition_scales
=
(
row_parallel
and
not
channelwise
)
partition_scales
=
not
marlin_repeat_scales_on_all_ranks
(
self
.
has_g_idx
,
self
.
group_size
,
row_parallel
)
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
...
...
@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
# group index (for activation reordering)
if
self
.
has_g_idx
:
weight_g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_g_idx"
,
weight_g_idx
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
...
...
@@ -137,8 +151,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Act-order not supported in compressed-tensors yet, so set to empty.
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
# Handle sorting for activation reordering if needed.
if
self
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
weight_g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"weight_g_idx"
,
g_idx
)
else
:
layer
.
weight_g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
...
...
@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales
=
marlin_permute_scales
(
layer
.
weight_scale
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
(
layer
.
input_size
if
self
.
has_g_idx
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
layer
.
group_size
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
...
...
@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
weight_zp
=
layer
.
weight_zp
,
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
weight_
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
wtype
=
self
.
quant_type
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
c7cb5c33
import
re
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Union
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
,
field_validator
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
...
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
TOKEN
=
"token"
class
ActivationOrdering
(
str
,
Enum
):
"""
Enum storing strategies for activation ordering
Group: reorder groups and weight
\n
Weight: only reorder weight, not groups. Slightly lower latency and
accuracy compared to group actorder
\n
"""
GROUP
=
"group"
WEIGHT
=
"weight"
class
QuantizationArgs
(
BaseModel
):
"""
User facing arguments used to define a quantization config
...
...
@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one
:param actorder: whether to apply group quantization in decreasing order of
activation. Defaults to None for arbitrary ordering
"""
num_bits
:
int
=
8
...
...
@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
strategy
:
Optional
[
QuantizationStrategy
]
=
None
block_structure
:
Optional
[
str
]
=
None
dynamic
:
bool
=
False
actorder
:
Union
[
ActivationOrdering
,
bool
,
None
]
=
None
observer
:
str
=
Field
(
default
=
"minmax"
,
description
=
(
"The class to use to compute the quantization param - "
...
...
@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
"Observers constructor excluding quantization range or symmetry"
),
)
@
field_validator
(
"actorder"
,
mode
=
"before"
)
def
validate_actorder
(
cls
,
value
)
->
Optional
[
ActivationOrdering
]:
if
isinstance
(
value
,
bool
):
return
ActivationOrdering
.
GROUP
if
value
else
None
if
isinstance
(
value
,
str
):
return
ActivationOrdering
(
value
.
lower
())
return
value
def
is_activation_quantization_format
(
format
:
str
)
->
bool
:
_ACTIVATION_QUANTIZATION_FORMATS
=
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment