Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
890d8d96
Unverified
Commit
890d8d96
authored
Jun 17, 2024
by
Dipika Sikka
Committed by
GitHub
Jun 17, 2024
Browse files
[Kernel] `compressed-tensors` marlin 24 support (#5435)
parent
9e74d9d0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
196 additions
and
19 deletions
+196
-19
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+20
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+32
-16
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+134
-0
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+8
-0
No files found.
tests/quantization/test_compressed_tensors.py
View file @
890d8d96
...
...
@@ -9,7 +9,8 @@ import torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
,
CompressedTensorsW4A16
,
CompressedTensorsW8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
)
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
)
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
):
...
...
@@ -51,8 +52,7 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
def
test_compressed_tensors_w8a8_dynanmic_per_token
(
vllm_runner
):
model_path
=
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
with
vllm_runner
(
model_path
,
enforce_eager
=
True
,
dtype
=
torch
.
float16
)
as
llm
:
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
...
...
@@ -83,3 +83,20 @@ def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
assert
qkv_proj
.
weight_packed
.
dtype
is
torch
.
int32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float16
assert
qkv_proj
.
weight_packed
.
pack_factor
==
8
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Load a marlin-24 checkpoint and check scheme selection and generation.

    Verifies that the first decoder layer's ``qkv_proj`` is driven by
    ``CompressedTensorsLinearMethod`` with a
    ``CompressedTensorsW4A16Sparse24`` scheme, that its packed weight is
    stored as ``torch.int32``, and that a generation call returns a truthy
    result.
    """
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:
        # Walk down to the underlying torch model step by step.
        engine = llm.model.llm_engine
        runner = engine.model_executor.driver_worker.model_runner
        first_layer = runner.model.model.layers[0]
        qkv_proj = first_layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method,
                          CompressedTensorsLinearMethod)
        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
        assert qkv_proj.weight_packed.dtype is torch.int32

        # Smoke-test generation with default sampling parameters.
        output = llm.generate("Hello world!",
                              sampling_params=SamplingParams())
        assert output
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
890d8d96
...
...
@@ -8,16 +8,20 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsW4A16
,
CompressedTensorsW8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
)
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationArgs
,
QuantizationStrategy
,
find_first_name_or_class_match
)
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
find_first_name_or_class_match
)
class
CompressedTensorsConfig
(
QuantizationConfig
):
def
__init__
(
self
,
layer_quant_details
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
]):
def
__init__
(
self
,
layer_quant_details
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
quant_format
:
str
):
self
.
ignore
=
ignore
self
.
layer_quant_details
=
layer_quant_details
self
.
quant_format
=
quant_format
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
...
...
@@ -46,6 +50,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
layer_quant_details
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
...
...
@@ -69,7 +74,9 @@ class CompressedTensorsConfig(QuantizationConfig):
except
Exception
:
layer_quant_details
[
target
][
"input_activations"
]
=
None
return
cls
(
layer_quant_details
=
layer_quant_details
,
ignore
=
ignore
)
return
cls
(
layer_quant_details
=
layer_quant_details
,
ignore
=
ignore
,
quant_format
=
quant_format
)
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
...
@@ -110,17 +117,26 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
if
self
.
_is_w4a16
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A16
(
num_bits
=
weight_quant
.
num_bits
,
if
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
:
return
CompressedTensorsW4A16Sparse24
(
strategy
=
weight_quant
.
strategy
,
num_bits
=
weight_quant
.
num_bits
,
group_size
=
weight_quant
.
group_size
)
if
self
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
:
return
CompressedTensorsW4A16
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
if
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
:
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8StaticTensor
()
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8DynamicToken
()
raise
NotImplementedError
(
"Scheme not supported."
)
raise
NotImplementedError
(
"No compressed-tensors compatible scheme was found."
)
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
)
->
"CompressedTensorsScheme"
:
...
...
@@ -165,9 +181,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
=
layer
)
scheme
.
create_weights
(
layer
=
layer
,
input_size
=
input_size
,
input_size_per_partition
=
input_size_per_partition
,
output_partition_sizes
=
output_partition_sizes
,
input_size
=
input_size
,
output_size
=
output_size
,
params_dtype
=
params_dtype
,
weight_loader
=
weight_loader
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
890d8d96
...
...
@@ -2,6 +2,8 @@ from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from
.compressed_tensors_unquantized
import
(
# noqa: F401
CompressedTensorsUnquantized
)
from
.compressed_tensors_w4a16
import
CompressedTensorsW4A16
# noqa: F401
from
.compressed_tensors_w4a16_24
import
(
# noqa: F401
CompressedTensorsW4A16Sparse24
)
from
.compressed_tensors_w8a8_dynamictoken
import
(
# noqa: F401, E501
CompressedTensorsW8A8DynamicToken
)
from
.compressed_tensors_w8a8_statictensor
import
(
# noqa: F401, E501
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
0 → 100644
View file @
890d8d96
from
typing
import
Callable
,
List
,
Optional
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW4A16Sparse24"
]
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
    """Compressed-tensors scheme backed by ``ops.gptq_marlin_24_gemm``.

    ``create_weights`` registers the packed weight, scale, shape and
    metadata parameters on a layer; ``apply_weights`` feeds them to the
    ``gptq_marlin_24_gemm`` custom op.
    """

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        """Store the quantization settings.

        Args:
            strategy: Quantization strategy name; ``"group"`` requires
                ``group_size``.
            num_bits: Bit width of a quantized weight; determines how many
                values are packed into one int32.
            group_size: Number of input channels sharing one scale, or
                ``None`` for a single scale group.

        Raises:
            ValueError: If ``strategy`` is ``"group"`` and ``group_size``
                is ``None``.
        """
        self.strategy = strategy
        self.group_size = group_size
        self.num_bits = num_bits
        self.tile_size = 16

        needs_group_size = self.strategy == "group"
        if needs_group_size and self.group_size is None:
            raise ValueError(
                "group_size must be given when using strategy group")

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Allocate and register this scheme's parameters on ``layer``.

        Registers ``weight_packed``, ``scale_packed``, ``weight_shape`` and
        ``meta`` via ``register_parameter`` and assigns ``layer.workspace``.
        All tensors are created with ``requires_grad=False``.

        Args:
            layer: Module that receives the parameters.
            input_size: Unused by this scheme.
            output_partition_sizes: Output sizes of the partitions held by
                this layer; their sum is the local output size.
            input_size_per_partition: Local input size.
            params_dtype: dtype used for the scales.
            weight_loader: Callable attached to each loadable parameter as
                its ``weight_loader`` attribute.
            **kwargs: Ignored.
        """
        values_per_int32 = 32 // self.num_bits
        out_features = sum(output_partition_sizes)
        in_features = input_size_per_partition

        # Packed quantized weight.
        packed_weight = Parameter(
            torch.empty(
                in_features // self.tile_size // 2,
                out_features * self.tile_size // values_per_int32,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        packed_weight_attrs = {
            "input_dim": 0,
            "output_dim": 1,
            "packed_dim": 1,
            "pack_factor": values_per_int32,
            "marlin_tile_size": self.tile_size,
            "weight_loader": weight_loader,
        }
        set_weight_attrs(packed_weight, packed_weight_attrs)
        layer.register_parameter("weight_packed", packed_weight)

        # One scale row per group; a single row when no group size is set.
        if self.group_size is None:
            num_groups = 1
        else:
            num_groups = in_features // self.group_size

        scale = Parameter(
            torch.empty(num_groups, out_features, dtype=params_dtype),
            requires_grad=False,
        )
        scale_attrs = {
            "output_dim": 1,
            # With a single group there is no input dimension to index.
            "input_dim": 0 if num_groups != 1 else None,
            "weight_loader": weight_loader,
        }
        set_weight_attrs(scale, scale_attrs)
        layer.register_parameter("scale_packed", scale)

        # Two-element int64 tensor named "weight_shape".
        shape_param = Parameter(torch.empty(2, dtype=torch.int64),
                                requires_grad=False)
        layer.register_parameter("weight_shape", shape_param)
        set_weight_attrs(shape_param, {"weight_loader": weight_loader})

        # int16 metadata tensor passed to the GEMM op alongside the weight.
        metadata = Parameter(
            torch.empty(
                in_features // 8 // 2 // 2,
                out_features * 2,
                dtype=torch.int16,
            ),
            requires_grad=False,
        )
        metadata_attrs = {
            "input_dim": 0,
            "packed_dim": 1,
            "pack_factor": 1,
            "output_dim": 1,
            "marlin_tile_size": 2,
            "weight_loader": weight_loader,
        }
        set_weight_attrs(metadata, metadata_attrs)
        layer.register_parameter("meta", metadata)

        # Zero-initialised scratch buffer sized from the local output size.
        thread_blocks = out_features // GPTQ_MARLIN_24_MIN_THREAD_N
        workspace_len = thread_blocks * GPTQ_MARLIN_24_MAX_PARALLEL
        layer.workspace = Parameter(
            torch.zeros(workspace_len, dtype=torch.int),
            requires_grad=False,
        )

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        """Run ``ops.gptq_marlin_24_gemm`` on ``x`` with ``layer``'s tensors.

        ``x`` is flattened to 2-D over its leading dimensions, multiplied,
        and the result is reshaped so the leading dimensions of ``x`` are
        restored with the op's output width as the last dimension.

        Args:
            layer: Module holding ``weight_packed``, ``meta``,
                ``scale_packed`` and ``workspace``.
            x: Input activations.

        Returns:
            The reshaped output tensor of the GEMM op.
        """
        scale = layer.scale_packed

        # Collapse all leading dimensions into a single row dimension.
        flat_x = x.view(-1, x.shape[-1])
        rows, in_features = flat_x.shape[0], flat_x.shape[1]
        out_features = scale.shape[1]

        flat_out = ops.gptq_marlin_24_gemm(flat_x, layer.weight_packed,
                                           layer.meta, scale,
                                           layer.workspace, self.num_bits,
                                           rows, out_features, in_features)

        out_shape = x.shape[:-1] + (flat_out.shape[1], )
        return flat_out.view(out_shape)
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
890d8d96
...
...
@@ -6,6 +6,14 @@ from pydantic import BaseModel, Field
from
torch.nn
import
Module
class CompressionFormat(Enum):
    """Closed set of compression-format identifiers.

    Each member's ``value`` is the hyphenated string form of the format
    name. The enum does not mix in ``str``, so compare against
    ``member.value`` when matching raw strings.
    """

    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"
class
QuantizationType
(
str
,
Enum
):
"""
Enum storing quantization type options
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment