Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
62963d12
Unverified
Commit
62963d12
authored
Jul 03, 2024
by
Robert Shaw
Committed by
GitHub
Jul 03, 2024
Browse files
[ Misc ] Clean Up `CompressedTensorsW8A8` (#6113)
parent
d9e98f42
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
44 additions
and
95 deletions
+44
-95
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+5
-4
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+5
-6
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+1
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
...ion/compressed_tensors/schemes/compressed_tensors_w8a8.py
+33
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
...d_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
+0
-33
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
...d_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+0
-47
No files found.
tests/quantization/test_compressed_tensors.py
View file @
62963d12
...
@@ -9,8 +9,7 @@ import torch
...
@@ -9,8 +9,7 @@ import torch
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsLinearMethod
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsLinearMethod
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
,
CompressedTensorsW8A8
,
CompressedTensorsWNA16
)
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationType
)
QuantizationType
)
...
@@ -38,9 +37,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
...
@@ -38,9 +37,10 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
CompressedTensorsLinearMethod
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
down_proj
.
quant_method
,
assert
isinstance
(
down_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8
StaticTensor
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8
)
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
is_static_input_scheme
expected_type
=
(
torch
.
int8
if
quant_type
==
QuantizationType
.
INT
else
expected_type
=
(
torch
.
int8
if
quant_type
==
QuantizationType
.
INT
else
torch
.
float8_e4m3fn
)
torch
.
float8_e4m3fn
)
...
@@ -79,7 +79,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
...
@@ -79,7 +79,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
qkv_proj
=
layer
.
self_attn
.
qkv_proj
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8DynamicToken
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8
)
assert
not
qkv_proj
.
scheme
.
is_static_input_scheme
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
scheme
.
strategy
==
strategy
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
assert
qkv_proj
.
weight
.
dtype
is
torch
.
int8
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
62963d12
...
@@ -9,8 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
...
@@ -9,8 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8DynamicToken
,
CompressedTensorsW8A8StaticTensor
,
CompressedTensorsW8A8
,
CompressedTensorsWNA16
)
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
find_first_name_or_class_match
)
find_first_name_or_class_match
)
...
@@ -150,12 +149,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -150,12 +149,12 @@ class CompressedTensorsConfig(QuantizationConfig):
if
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
:
if
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
:
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8
StaticTensor
(
return
CompressedTensorsW8A8
(
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
)
is_static_input_scheme
=
True
)
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8
DynamicToken
(
return
CompressedTensorsW8A8
(
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
)
is_static_input_scheme
=
False
)
raise
NotImplementedError
(
raise
NotImplementedError
(
"No compressed-tensors compatible scheme was found."
)
"No compressed-tensors compatible scheme was found."
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
62963d12
...
@@ -3,9 +3,6 @@ from .compressed_tensors_unquantized import ( # noqa: F401
...
@@ -3,9 +3,6 @@ from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized
)
CompressedTensorsUnquantized
)
from
.compressed_tensors_w4a16_24
import
(
# noqa: F401
from
.compressed_tensors_w4a16_24
import
(
# noqa: F401
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
from
.compressed_tensors_w8a8_dynamictoken
import
(
# noqa: F401, E501
from
.compressed_tensors_w8a8
import
CompressedTensorsW8A8
# noqa: F401
CompressedTensorsW8A8DynamicToken
)
from
.compressed_tensors_w8a8_statictensor
import
(
# noqa: F401, E501
CompressedTensorsW8A8StaticTensor
)
from
.compressed_tensors_wNa16
import
WNA16_SUPPORTED_BITS
# noqa: F401
from
.compressed_tensors_wNa16
import
WNA16_SUPPORTED_BITS
# noqa: F401
from
.compressed_tensors_wNa16
import
CompressedTensorsWNA16
# noqa: F401
from
.compressed_tensors_wNa16
import
CompressedTensorsWNA16
# noqa: F401
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8.py
View file @
62963d12
...
@@ -3,6 +3,7 @@ from typing import Callable, List, Tuple, Union
...
@@ -3,6 +3,7 @@ from typing import Callable, List, Tuple, Union
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
...
@@ -12,8 +13,9 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -12,8 +13,9 @@ from vllm.model_executor.utils import set_weight_attrs
class
CompressedTensorsW8A8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
# Cutlass kernels support only per-tensor and per-channel cases.
# Cutlass kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
...
@@ -36,6 +38,10 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
...
@@ -36,6 +38,10 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
layer
.
weight_scale
=
Parameter
(
weight_scale_channel
,
layer
.
weight_scale
=
Parameter
(
weight_scale_channel
,
requires_grad
=
False
)
requires_grad
=
False
)
# transpose weights for cutlass.
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
...
@@ -75,3 +81,29 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
...
@@ -75,3 +81,29 @@ class CompressedTensorsW8A8(CompressedTensorsScheme):
"output_dim"
:
0
,
"output_dim"
:
0
,
"weight_loader"
:
weight_loader
,
"weight_loader"
:
weight_loader
,
})
})
# INPUT SCALE
# Static quantization: load from disk.
if
self
.
is_static_input_scheme
:
input_scale
=
Parameter
(
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
set_weight_attrs
(
input_scale
,
{
"weight_loader"
:
weight_loader
,
"ignore_warning"
:
True
,
})
# Dynamic quantization: set to None.
else
:
layer
.
input_scale
=
None
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
    """Quantize ``x`` and multiply it with the layer's weight.

    ``ops.scaled_int8_quant`` serves both input-quantization modes:

    * dynamic: ``layer.input_scale`` is ``None`` and the activation scale
      is computed from ``x``;
    * static: ``layer.input_scale`` is a scalar and is returned as the
      activation scale.

    Args:
        layer: Module providing ``weight``, ``weight_scale`` and
            ``input_scale``.
        x: Activation tensor to quantize and multiply.

    Returns:
        The result of ``ops.cutlass_scaled_mm`` with ``out_dtype`` set to
        ``x.dtype``.
    """
    quantized_input, activation_scale = ops.scaled_int8_quant(
        x, layer.input_scale)

    return ops.cutlass_scaled_mm(
        quantized_input,
        layer.weight,
        scale_a=activation_scale,
        scale_b=layer.weight_scale,
        out_dtype=x.dtype,
    )
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
deleted
100644 → 0
View file @
d9e98f42
from
typing
import
Callable
,
List
import
torch
from
vllm
import
_custom_ops
as
custom_ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8
import
(
# noqa: E501
CompressedTensorsW8A8
)
__all__
=
[
"CompressedTensorsW8A8DynamicToken"
]
class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):
    """W8A8 scheme that derives activation scales from the input at call time.

    Parameter creation is delegated unchanged to ``CompressedTensorsW8A8``;
    this class registers no input-scale parameter of its own.
    """

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Create the layer's parameters through the parent scheme.

        Args:
            layer: Module that receives the parameters.
            output_partition_sizes: Forwarded to the parent scheme.
            input_size_per_partition: Forwarded to the parent scheme.
            params_dtype: Forwarded to the parent scheme.
            weight_loader: Forwarded to the parent scheme.
            **kwargs: Accepted for call compatibility; not forwarded.
        """
        parent_arguments = dict(
            layer=layer,
            output_partition_sizes=output_partition_sizes,
            input_size_per_partition=input_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=weight_loader,
        )
        super().create_weights(**parent_arguments)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        """Quantize ``x`` without a preset scale and run the scaled matmul.

        Args:
            layer: Module providing ``weight`` and ``weight_scale``.
            x: Activation tensor.

        Returns:
            The result of ``custom_ops.cutlass_scaled_mm``, requested in
            ``x.dtype``.
        """
        layer_weight = layer.weight
        layer_weight_scale = layer.weight_scale

        # No scale argument is passed, so the op returns the scales it used.
        quantized, computed_scales = custom_ops.scaled_int8_quant(x)

        # The weight is transposed before being handed to the kernel.
        return custom_ops.cutlass_scaled_mm(
            quantized,
            layer_weight.t(),
            computed_scales,
            layer_weight_scale,
            x.dtype,
        )
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
deleted
100644 → 0
View file @
d9e98f42
from
typing
import
Callable
,
List
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
custom_ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8
import
(
# noqa: E501
CompressedTensorsW8A8
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW8A8StaticTensor"
]
class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):
    """W8A8 scheme that uses an activation scale stored on the layer.

    In addition to the parameters created by ``CompressedTensorsW8A8``, a
    one-element float32 ``input_scale`` parameter is registered on the layer
    and given the supplied weight loader.
    """

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Create the parent's parameters, then register ``input_scale``.

        Args:
            layer: Module that receives the parameters.
            output_partition_sizes: Forwarded to the parent scheme.
            input_size_per_partition: Forwarded to the parent scheme.
            params_dtype: Forwarded to the parent scheme.
            weight_loader: Forwarded to the parent scheme and attached to
                ``input_scale``.
            **kwargs: Accepted for call compatibility; not forwarded.
        """
        parent_arguments = dict(
            layer=layer,
            output_partition_sizes=output_partition_sizes,
            input_size_per_partition=input_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=weight_loader,
        )
        super().create_weights(**parent_arguments)

        # Uninitialized placeholder; the attached weight loader is expected
        # to fill it in.
        scale_placeholder = torch.empty(1, dtype=torch.float32)
        scale_param = Parameter(scale_placeholder, requires_grad=False)
        layer.register_parameter("input_scale", scale_param)

        scale_attrs = {
            "weight_loader": weight_loader,
            "ignore_warning": True,
        }
        set_weight_attrs(scale_param, scale_attrs)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        """Quantize ``x`` with the layer's ``input_scale`` and run the matmul.

        Args:
            layer: Module providing ``weight``, ``weight_scale`` and
                ``input_scale``.
            x: Activation tensor.

        Returns:
            The result of ``custom_ops.cutlass_scaled_mm``, requested in
            ``x.dtype``.
        """
        layer_weight = layer.weight
        layer_weight_scale = layer.weight_scale
        stored_scale = layer.input_scale

        # The scale returned by the op is discarded; the stored scale is
        # passed to the matmul directly.
        quantized, _returned_scale = custom_ops.scaled_int8_quant(
            x, stored_scale)

        # The weight is transposed before being handed to the kernel.
        return custom_ops.cutlass_scaled_mm(
            quantized,
            layer_weight.t(),
            stored_scale,
            layer_weight_scale,
            x.dtype,
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment