Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
015e6cc2
Unverified
Commit
015e6cc2
authored
Aug 26, 2024
by
Dipika Sikka
Committed by
GitHub
Aug 26, 2024
Browse files
[Misc] Update compressed tensors lifecycle to remove `prefix` from `create_weights` (#7825)
parent
760e9f71
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
75 deletions
+17
-75
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+3
-6
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+14
-18
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+0
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+0
-49
No files found.
vllm/model_executor/layers/linear.py
View file @
015e6cc2
...
@@ -208,8 +208,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -208,8 +208,7 @@ class ReplicatedLinear(LinearBase):
self
.
input_size
,
self
.
input_size
,
self
.
output_size
,
self
.
output_size
,
self
.
params_dtype
,
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
,
weight_loader
=
self
.
weight_loader
)
prefix
=
prefix
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
...
@@ -307,8 +306,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -307,8 +306,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
(
weight_loader
=
(
self
.
weight_loader_v2
if
self
.
quant_method
.
__class__
.
__name__
self
.
weight_loader_v2
if
self
.
quant_method
.
__class__
.
__name__
in
WEIGHT_LOADER_V2_SUPPORTED
else
self
.
weight_loader
),
in
WEIGHT_LOADER_V2_SUPPORTED
else
self
.
weight_loader
))
prefix
=
prefix
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
torch
.
empty
(
self
.
output_size_per_partition
,
...
@@ -976,8 +974,7 @@ class RowParallelLinear(LinearBase):
...
@@ -976,8 +974,7 @@ class RowParallelLinear(LinearBase):
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
(
weight_loader
=
(
self
.
weight_loader_v2
if
self
.
quant_method
.
__class__
.
__name__
self
.
weight_loader_v2
if
self
.
quant_method
.
__class__
.
__name__
in
WEIGHT_LOADER_V2_SUPPORTED
else
self
.
weight_loader
),
in
WEIGHT_LOADER_V2_SUPPORTED
else
self
.
weight_loader
))
prefix
=
prefix
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
raise
ValueError
(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
"results can lead to incorrect results"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
015e6cc2
...
@@ -3,15 +3,15 @@ from typing import Any, Dict, List, Optional
...
@@ -3,15 +3,15 @@ from typing import Any, Dict, List, Optional
import
torch
import
torch
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
find_matched_target
,
is_activation_quantization_format
,
QuantizationType
,
find_matched_target
,
is_activation_quantization_format
,
...
@@ -52,15 +52,20 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -52,15 +52,20 @@ class CompressedTensorsConfig(QuantizationConfig):
def
get_name
(
self
)
->
str
:
def
get_name
(
self
)
->
str
:
return
"compressed_tensors"
return
"compressed_tensors"
# TODO (@robertgshaw2-neuralmagic): do layer skipping though here
# rather than though create_weights to match other methods
def
get_quant_method
(
def
get_quant_method
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
prefix
,
ignore
=
self
.
ignore
):
return
UnquantizedLinearMethod
()
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
scheme
=
self
.
get_scheme
(
layer
=
layer
,
layer_name
=
prefix
)
layer
.
scheme
=
scheme
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
...
@@ -281,15 +286,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -281,15 +286,11 @@ class CompressedTensorsConfig(QuantizationConfig):
to select the CompressedTensorsScheme used for infernece.
to select the CompressedTensorsScheme used for infernece.
"""
"""
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
layer_name
,
ignore
=
self
.
ignore
):
return
CompressedTensorsUnquantized
()
# Find the "target" in the compressed-tensors config
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
matched_target
=
find_matched_target
(
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
layer_name
=
layer_name
,
module
=
layer
,
module
=
layer
,
...
@@ -327,10 +328,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -327,10 +328,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
details
details
"""
"""
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer_name
=
extra_weight_attrs
.
get
(
"prefix"
)
layer
.
scheme
.
create_weights
(
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
,
layer_name
)
scheme
.
create_weights
(
layer
=
layer
,
layer
=
layer
,
input_size
=
input_size
,
input_size
=
input_size
,
input_size_per_partition
=
input_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
...
@@ -339,8 +337,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -339,8 +337,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
layer
.
scheme
=
scheme
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
015e6cc2
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_unquantized
import
CompressedTensorsUnquantized
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
CompressedTensorsW4A16Sparse24
)
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
...
@@ -10,7 +9,6 @@ from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
...
@@ -10,7 +9,6 @@ from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
__all__
=
[
__all__
=
[
"CompressedTensorsScheme"
,
"CompressedTensorsScheme"
,
"CompressedTensorsUnquantized"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW4A16Sparse24"
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
deleted
100644 → 0
View file @
760e9f71
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.parameter
import
ModelWeightParameter
__all__
=
[
"CompressedTensorsUnquantized"
]
class
CompressedTensorsUnquantized
(
CompressedTensorsScheme
):
"""
Implements the scheme for all layers which are ignored
in the CompressedTensors config. The input and loaded weight are used
in a linear transformation.
"""
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# volta and up
return
70
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# required by torch.compile to be torch.nn.Parameter
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment