Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
852 additions
and
225 deletions
+852
-225
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+2
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+141
-50
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+7
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+4
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+48
-28
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+4
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+14
-6
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+145
-21
vllm/model_executor/layers/quantization/deepspeedfp.py
vllm/model_executor/layers/quantization/deepspeedfp.py
+2
-3
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+182
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+9
-36
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+2
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+23
-15
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+2
-3
vllm/model_executor/layers/quantization/kv_cache.py
vllm/model_executor/layers/quantization/kv_cache.py
+78
-0
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+2
-2
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+2
-2
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+173
-51
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+8
-2
No files found.
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
500b93c8
...
@@ -60,9 +60,8 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -60,9 +60,8 @@ class BitsAndBytesConfig(QuantizationConfig):
target_modules
=
cls
.
get_from_keys
(
config
,
[
"target_modules"
])
target_modules
=
cls
.
get_from_keys
(
config
,
[
"target_modules"
])
return
cls
(
adapter_name
,
target_modules
)
return
cls
(
adapter_name
,
target_modules
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
BitsAndBytesLinearMethod
(
self
)
return
BitsAndBytesLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
500b93c8
...
@@ -5,25 +5,33 @@ from pydantic import BaseModel
...
@@ -5,25 +5,33 @@ from pydantic import BaseModel
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensors
W4A16Sparse24
,
CompressedTensorsScheme
,
CompressedTensors
Unquantized
,
CompressedTensorsW
8A8Fp8
,
CompressedTensorsW8A8
Int
8
,
CompressedTensorsW
4A16Sparse24
,
CompressedTensorsW8A8
Fp
8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A8Int8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
find_first_name_or_class_match
)
QuantizationType
,
find_matched_target
,
is_activation_quantization_format
,
should_ignore_layer
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
class
CompressedTensorsConfig
(
QuantizationConfig
):
class
CompressedTensorsConfig
(
QuantizationConfig
):
def
__init__
(
self
,
layer_quant_details
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
def
__init__
(
self
,
quant_format
:
str
):
target_scheme_map
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
quant_format
:
str
,
kv_cache_scheme
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
self
.
ignore
=
ignore
self
.
ignore
=
ignore
self
.
layer_quant_details
=
layer_quant_details
self
.
quant_format
=
quant_format
self
.
quant_format
=
quant_format
# Map from [target -> scheme]
self
.
target_scheme_map
=
target_scheme_map
self
.
kv_cache_scheme
=
kv_cache_scheme
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
...
@@ -36,21 +44,28 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -36,21 +44,28 @@ class CompressedTensorsConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
7
5
return
7
0
def
get_name
(
self
)
->
str
:
def
get_name
(
self
)
->
str
:
return
"compressed_tensors"
return
"compressed_tensors"
# TODO (@robertgshaw2-neuralmagic): do layer skipping though here
# rather than though create_weights to match other methods
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
self
,
)
->
Optional
[
"CompressedTensorsLinearMethod"
]:
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
None
return
None
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
layer_quant_details
:
Dict
[
str
,
Any
]
=
dict
()
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
...
@@ -62,35 +77,37 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -62,35 +77,37 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
# quant_config and also store the details for later use.
for
key
,
quant_config
in
config
[
"config_groups"
].
items
():
for
_
,
quant_config
in
config
[
"config_groups"
].
items
():
targets
=
quant_config
.
get
(
"targets"
)
targets
=
quant_config
.
get
(
"targets"
)
for
target
in
targets
:
for
target
in
targets
:
layer_quant_details
[
target
]
=
{}
target_scheme_map
[
target
]
=
{}
layer_quant_details
[
target
][
target_scheme_map
[
target
][
"weights"
]
=
QuantizationArgs
.
parse_obj
(
"weights"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"weights"
))
quant_config
.
get
(
"weights"
))
try
:
try
:
layer_quant_details
[
target
][
target_scheme_map
[
target
][
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"input_activations"
))
quant_config
.
get
(
"input_activations"
))
except
Exception
:
except
Exception
:
layer_quant_details
[
target
][
"input_activations"
]
=
None
target_scheme_map
[
target
][
"input_activations"
]
=
None
return
cls
(
layer_quant_details
=
layer_quant_details
,
return
cls
(
target_scheme_map
=
target_scheme_map
,
ignore
=
ignore
,
ignore
=
ignore
,
quant_format
=
quant_format
)
quant_format
=
quant_format
,
kv_cache_scheme
=
config
.
get
(
"kv_cache_scheme"
))
@
classmethod
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
return
[]
def
_check_
gptq_and_marlin_can_run
(
self
):
def
_check_
scheme_supported
(
self
,
min_capability
:
int
):
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
80
:
if
capability
<
min_capability
:
raise
RuntimeError
(
"The quantization config is not supported for "
,
raise
RuntimeError
(
"the current GPU. Minimum capability: 80. "
,
"Quantization scheme is not supported for "
,
f
"Current capability:
{
capability
}
."
)
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
f
"Current capability:
{
capability
}
."
)
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
...
@@ -132,10 +149,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -132,10 +149,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Confirm weight scheme is supported.
# Confirm weight scheme is supported.
is_symmetric_weight
=
weight_quant
.
symmetric
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_weight
=
(
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
])
if
not
(
is_symmetric_weight
and
is_static_weight
if
not
(
is_symmetric_weight
and
is_static_weight
and
is_per_tensor_weight
):
and
is_per_tensor_
or_channel_
weight
):
return
False
return
False
# Dynamic quantization is always supported if weights supported.
# Dynamic quantization is always supported if weights supported.
...
@@ -164,11 +182,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -164,11 +182,12 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_channel_group
and
input_quant_none
and
is_symmetric
return
(
is_channel_group
and
input_quant_none
and
is_symmetric
and
is_static
)
and
is_static
)
def
_get_schema
(
self
,
weight_quant
:
BaseModel
,
def
_get_scheme_from_parts
(
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
# Detect If Mixed Precision
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
self
.
_check_gptq_and_marlin_can_run
()
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
return
CompressedTensorsW4A16Sparse24
(
return
CompressedTensorsW4A16Sparse24
(
...
@@ -182,11 +201,12 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -182,11 +201,12 @@ class CompressedTensorsConfig(QuantizationConfig):
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
group_size
=
weight_quant
.
group_size
)
if
(
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
or
# Detect If Activation Quantization.
self
.
quant_format
==
Compress
ion
F
ormat
.
float_quantized
.
value
):
if
is_activation_quantizat
ion
_f
ormat
(
self
.
quant_format
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8
(
return
CompressedTensorsW8A8Fp8
(
input_dynamic
=
input_quant
.
dynamic
)
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
not
input_quant
.
dynamic
))
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
return
CompressedTensorsW8A8Int8
(
...
@@ -201,26 +221,53 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -201,26 +221,53 @@ class CompressedTensorsConfig(QuantizationConfig):
raise
NotImplementedError
(
raise
NotImplementedError
(
"No compressed-tensors compatible scheme was found."
)
"No compressed-tensors compatible scheme was found."
)
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
)
->
"CompressedTensorsScheme"
:
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
,
layer_name
:
Optional
[
str
]
=
None
)
->
"CompressedTensorsScheme"
:
"""
compressed-tensors supports non uniform in the following way:
ignore: List of layer_names or nn.Module names to be ignored.
targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
an nn.Module name.
layer_type_name
=
find_first_name_or_class_match
(
We first check whether a layer is in the ignore group and use
name
=
""
,
CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
We then detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for infernece.
"""
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
layer_name
,
ignore
=
self
.
ignore
):
return
CompressedTensorsUnquantized
()
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
module
=
layer
,
targets
=
self
.
layer_quant_details
.
keys
(),
targets
=
self
.
target_scheme_map
.
keys
())
check_contains
=
True
)
# Find the quant_scheme
scheme
=
self
.
target_scheme_map
[
matched_target
]
if
layer_type_name
is
None
:
return
self
.
_get_scheme_from_parts
(
raise
ValueError
(
f
"Could not matching target for layer
{
layer
}
"
)
weight_quant
=
scheme
[
"weights"
],
input_quant
=
scheme
[
"input_activations"
])
layer_quant_details
:
Dict
[
str
,
Any
]
=
self
.
layer_quant_details
.
get
(
# Raise error if device does not support the scheme
layer_type_name
,
None
)
# (e.g. fp8 needs ada lovelace)
if
layer_quant_details
is
None
:
self
.
_check_scheme_supported
(
scheme
.
get_min_capability
())
raise
ValueError
(
f
"Could not find quantization details for
{
layer
}
."
)
return
self
.
_get_schema
(
return
scheme
weight_quant
=
layer_quant_details
[
"weights"
],
input_quant
=
layer_quant_details
[
"input_activations"
])
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
...
@@ -240,11 +287,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -240,11 +287,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
Use the CompressedTensorsScheme associated with each layer to create
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
the necessary parameters for the layer. See LinearMethodBase for param
details
details
"""
"""
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer_name
=
extra_weight_attrs
.
get
(
"prefix"
)
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
=
layer
)
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
,
layer
_name
)
scheme
.
create_weights
(
scheme
.
create_weights
(
layer
=
layer
,
layer
=
layer
,
input_size
=
input_size
,
input_size
=
input_size
,
...
@@ -271,3 +318,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -271,3 +318,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
if
scheme
is
None
:
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
raise
ValueError
(
"A scheme must be defined for each layer"
)
return
scheme
.
apply_weights
(
layer
,
x
,
bias
=
bias
)
return
scheme
.
apply_weights
(
layer
,
x
,
bias
=
bias
)
class
CompressedTensorsKVCacheMethod
(
BaseKVCacheMethod
):
"""
Supports loading kv-cache scaling factors from compressed-tensors
checkpoints.
"""
def
__init__
(
self
,
quant_config
:
CompressedTensorsConfig
):
self
.
validate_kv_cache_scheme
(
quant_config
.
kv_cache_scheme
)
super
().
__init__
(
quant_config
)
@
staticmethod
def
validate_kv_cache_scheme
(
kv_cache_scheme
:
Optional
[
Dict
[
str
,
Any
]]):
"""
Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
:param kv_cache_scheme: the compressed-tensors kv cache scheme
"""
if
kv_cache_scheme
is
None
:
return
type_
=
kv_cache_scheme
.
get
(
"type"
)
num_bits
=
kv_cache_scheme
.
get
(
"num_bits"
)
if
type_
!=
"float"
and
num_bits
!=
8
:
raise
NotImplementedError
(
"Currently supported kv cache quantization is "
"num_bits=8, type=float, however "
f
"received num_bits=
{
num_bits
}
, type=
{
type_
}
"
)
strategy
=
kv_cache_scheme
.
get
(
"strategy"
)
if
strategy
!=
"tensor"
:
raise
NotImplementedError
(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f
"Expected strategy: tensor, found strategy:
{
strategy
}
"
)
is_symmetric
=
kv_cache_scheme
.
get
(
"symmetric"
)
if
not
is_symmetric
:
raise
NotImplementedError
(
"Only support symmetric scaling factor "
"for compressed-tensors KV cache. "
f
"However found symmetric:
{
is_symmetric
}
"
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
500b93c8
...
@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
...
@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors.
of different quantization schemes supported by CompressedTensors.
"""
"""
@
abstractmethod
def
get_min_capability
(
self
)
->
int
:
"""
Get minimum device capability.
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
create_weights
(
self
,
*
args
,
**
kwargs
):
def
create_weights
(
self
,
*
args
,
**
kwargs
):
"""
"""
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
View file @
500b93c8
...
@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
...
@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
in a linear transformation.
"""
"""
def
get_min_capability
(
self
)
->
int
:
# volta and up
return
70
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
pass
...
@@ -29,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
...
@@ -29,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
input_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
),
dtype
=
params_dtype
),
requires_grad
=
False
)
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
500b93c8
...
@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
...
@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
raise
ValueError
(
raise
ValueError
(
"group_size must be given when using strategy group"
)
"group_size must be given when using strategy group"
)
def
get_min_capability
(
self
)
->
int
:
# ampere + up
return
80
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
pass
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
500b93c8
from
typing
import
Callable
,
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
apply_fp8_linear
,
create_per_channel_scale_param
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -14,39 +18,49 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
...
@@ -14,39 +18,49 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
input_dynamic
:
bool
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
input_dynamic
=
input_dynamic
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
def
get_min_capability
(
self
)
->
int
:
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# lovelace and up
# scales being passed to the kernel), we requantize with a single scale.
return
89
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
# Dequant -> Quant with max scale.
# If per tensor, when we have a fused module (e.g. QKV) with per
max_w_scale
,
weight
=
requantize_with_max_scale
(
# tensor scales (thus N scales being passed to the kernel),
weight
=
layer
.
weight
,
# requantize so we can always run per tensor
weight_scale
=
layer
.
weight_scale
,
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
logical_widths
=
layer
.
logical_widths
,
max_w_scale
,
weight
=
requantize_with_max_scale
(
)
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
# Update layer with new values.
logical_widths
=
layer
.
logical_widths
,
layer
.
weight
=
torch
.
nn
.
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
if
self
.
input_dynamic
:
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
None
# If channelwise, scales are already lined up, so just transpose.
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
else
:
raise
ValueError
(
f
"Unknown quantization strategy
{
self
.
strategy
}
"
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
else
:
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
max
(),
layer
.
input_scale
=
None
requires_grad
=
False
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
**
kwargs
):
del
params_dtype
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
...
@@ -63,12 +77,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -63,12 +77,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
})
})
# WEIGHT SCALE
# WEIGHT SCALE
weight_scale
=
create_per_tensor_scale_param
(
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
output_partition_sizes
,
weight_loader
=
weight_loader
)
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
# INPUT SCALE
if
not
self
.
i
nput_dynamic
:
if
self
.
i
s_static_input_scheme
:
input_scale
=
create_per_tensor_scale_param
(
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
output_partition_sizes
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
...
@@ -84,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -84,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
True
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
500b93c8
...
@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
def
get_min_capability
(
self
)
->
int
:
# turing and up
return
75
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# WEIGHT
# WEIGHT
# Cutlass kernels need transposed weight.
# Cutlass kernels need transposed weight.
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
500b93c8
...
@@ -7,8 +7,8 @@ from vllm import _custom_ops as ops
...
@@ -7,8 +7,8 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
apply_
gptq_
marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
marlin_permute_scales
,
replace_tensor
,
verify_
gptq_
marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supports_shape
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -38,9 +38,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -38,9 +38,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self
.
group_size
=
group_size
self
.
group_size
=
group_size
# Verify supported on platform.
# Verify supported on platform.
verify_marlin_supported
(
num_bits
=
self
.
num_bits
,
verify_gptq_marlin_supported
(
num_bits
=
self
.
num_bits
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
is_sym
=
True
)
is_sym
=
True
)
def
get_min_capability
(
self
)
->
int
:
# ampere and up
return
80
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
output_partition_sizes
:
List
[
int
],
...
@@ -131,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -131,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
# Repack weights from compressed-tensors format to marlin format.
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
weight_packed
.
t
().
contiguous
(),
layer
.
weight_packed
.
t
().
contiguous
(),
...
@@ -151,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -151,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
apply_marlin_linear
(
return
apply_
gptq_
marlin_linear
(
input
=
x
,
input
=
x
,
weight
=
layer
.
weight_packed
,
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
weight_zp
=
layer
.
weight_zp
,
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
500b93c8
...
@@ -9,6 +9,7 @@ from torch.nn import Module
...
@@ -9,6 +9,7 @@ from torch.nn import Module
class
CompressionFormat
(
Enum
):
class
CompressionFormat
(
Enum
):
dense
=
"dense"
dense
=
"dense"
sparse_bitmask
=
"sparse-bitmask"
sparse_bitmask
=
"sparse-bitmask"
naive_quantized
=
"naive-quantized"
float_quantized
=
"float-quantized"
float_quantized
=
"float-quantized"
int_quantized
=
"int-quantized"
int_quantized
=
"int-quantized"
pack_quantized
=
"pack-quantized"
pack_quantized
=
"pack-quantized"
...
@@ -76,25 +77,115 @@ class QuantizationArgs(BaseModel):
...
@@ -76,25 +77,115 @@ class QuantizationArgs(BaseModel):
)
)
def
find_first_name_or_class_match
(
def
is_activation_quantization_format
(
format
:
str
)
->
bool
:
name
:
str
,
_ACTIVATION_QUANTIZATION_FORMATS
=
[
module
:
Module
,
CompressionFormat
.
naive_quantized
.
value
,
targets
:
Iterable
[
str
],
CompressionFormat
.
int_quantized
.
value
,
check_contains
:
bool
=
False
)
->
Optional
[
str
]:
CompressionFormat
.
float_quantized
.
value
]
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
def
should_ignore_layer
(
layer_name
:
Optional
[
str
],
ignore
:
Iterable
[
str
])
->
bool
:
if
layer_name
is
None
:
return
False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name
=
layer_name
.
split
(
"."
)[
-
1
]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if
proj_name
in
_FUSED_LAYER_NAME_MAPPING
:
shard_proj_names
=
_FUSED_LAYER_NAME_MAPPING
[
proj_name
]
# Convert fused_name --> [shard_names]
shard_names
=
[
layer_name
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer
=
None
for
shard_name
in
shard_names
:
should_ignore_shard
=
check_equal_or_regex_match
(
layer_name
=
shard_name
,
targets
=
ignore
)
# If shard_idx=0, set layer ignore to match shard.
if
should_ignore_layer
is
None
:
should_ignore_layer
=
should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif
should_ignore_shard
!=
should_ignore_layer
:
raise
ValueError
(
f
"Found a different quantization schemes for "
f
"
{
shard_proj_names
}
in
{
layer_name
}
. vLLM "
"requires all to use the same scheme."
)
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else
:
should_ignore_layer
=
check_equal_or_regex_match
(
layer_name
=
layer_name
,
targets
=
ignore
)
assert
should_ignore_layer
is
not
None
return
should_ignore_layer
def
check_equal_or_regex_match
(
layer_name
:
str
,
targets
:
Iterable
[
str
])
->
bool
:
"""
"""
Helper function to map the quantization details listed in the config
Checks whether a layer_name is exactly equal or a regex match for
for a given list of targets against each model layer. First uses the
if target starts with 're:' to any target in list.
layer name to try and find a match. If no name match is found, uses
"""
the layer class name. Returns None otherwise.
for
target
in
targets
:
if
_is_equal_or_regex_match
(
layer_name
,
target
):
return
True
return
False
def
find_matched_target
(
layer_name
:
Optional
[
str
],
module
:
Module
,
targets
:
Iterable
[
str
])
->
str
:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
Recall that a compressed-tensors configs has a concept of
config_groups, where each layer can be quantized with with a different
scheme.
:param name: layer name
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
:param layer_name: layer name
:param module: torch.nn.Module
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
"""
return
_find_first_match
(
name
,
targets
)
or
_find_first_match
(
if
layer_name
is
None
:
module
.
__class__
.
__name__
,
targets
,
check_contains
)
layer_name
=
""
matched_target
=
(
_find_first_match
(
layer_name
,
targets
)
or
_find_first_match
(
module
.
__class__
.
__name__
,
targets
,
True
))
if
matched_target
is
None
:
raise
ValueError
(
f
"Unable to find matching target for
{
module
}
in the "
"compressed-tensors config."
)
return
matched_target
def
_find_first_match
(
value
:
str
,
def
_find_first_match
(
value
:
str
,
...
@@ -111,13 +202,46 @@ def _find_first_match(value: str,
...
@@ -111,13 +202,46 @@ def _find_first_match(value: str,
"""
"""
for
target
in
targets
:
for
target
in
targets
:
if
target
.
startswith
(
"re:"
):
if
_is_equal_or_regex_match
(
value
,
pattern
=
target
[
3
:]
target
,
if
re
.
match
(
pattern
,
value
):
check_contains
=
check_contains
):
return
target
elif
check_contains
:
if
target
.
lower
()
in
value
.
lower
():
return
target
elif
target
==
value
:
return
target
return
target
return
None
return
None
def
get_compressed_tensors_cache_scale
(
name
:
str
)
->
Optional
[
str
]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if
name
.
endswith
(
".output_scale"
)
and
".k_proj"
in
name
:
return
name
.
replace
(
".k_proj.output_scale"
,
".attn.k_scale"
)
if
name
.
endswith
(
".output_scale"
)
and
".v_proj"
in
name
:
return
name
.
replace
(
".v_proj.output_scale"
,
".attn.v_scale"
)
# If no matches, return None
return
None
def
_is_equal_or_regex_match
(
value
:
str
,
target
:
str
,
check_contains
:
bool
=
False
)
->
bool
:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if
target
.
startswith
(
"re:"
):
pattern
=
target
[
3
:]
if
re
.
match
(
pattern
,
value
):
return
True
elif
check_contains
:
if
target
.
lower
()
in
value
.
lower
():
return
True
elif
target
==
value
:
return
True
return
False
vllm/model_executor/layers/quantization/deepspeedfp.py
View file @
500b93c8
...
@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
...
@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
"quantize_config.json"
,
"quantize_config.json"
,
]
]
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"DeepSpeedFPLinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"DeepSpeedFPLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
DeepSpeedFPLinearMethod
(
self
)
return
DeepSpeedFPLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
0 → 100644
View file @
500b93c8
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
class
FBGEMMFp8Config
(
QuantizationConfig
):
"""Config class for FBGEMM Fp8."""
def
__init__
(
self
,
ignore_list
:
List
[
str
],
input_scale_ub
:
float
):
self
.
ignore_list
=
ignore_list
if
ignore_list
else
[]
self
.
input_scale_ub
=
input_scale_ub
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"fbgemm_fp8"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
float16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"FBGEMMFp8Config"
:
ignore_list
=
cls
.
get_from_keys
(
config
,
[
"modules_to_not_convert"
])
input_scale_ub
=
cls
.
get_from_keys
(
config
,
[
"activation_scale_ub"
])
return
cls
(
ignore_list
=
ignore_list
,
input_scale_ub
=
input_scale_ub
)
def
_is_layer_skipped
(
self
,
prefix
:
str
)
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
if
proj_name
in
_FUSED_LAYER_NAME_MAPPING
:
shard_prefixes
=
[
prefix
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
_FUSED_LAYER_NAME_MAPPING
[
proj_name
]
]
is_skipped
=
None
for
shard_prefix
in
shard_prefixes
:
is_shard_skipped
=
shard_prefix
in
self
.
ignore_list
if
is_skipped
is
None
:
is_skipped
=
is_shard_skipped
elif
is_shard_skipped
!=
is_skipped
:
raise
ValueError
(
f
"Detected some but not all shards of
{
prefix
}
"
"are quantized. All shards of fused layers "
"to have the same precision."
)
else
:
is_skipped
=
prefix
in
self
.
ignore_list
assert
is_skipped
is
not
None
return
is_skipped
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
self
.
_is_layer_skipped
(
prefix
):
return
UnquantizedLinearMethod
()
return
FBGEMMFp8LinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
FBGEMMFp8LinearMethod
(
LinearMethodBase
):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
**
extra_weight_attrs
,
})
# WEIGHT SCALE
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
**
extra_weight_attrs
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE UPPER BOUND
input_scale_ub
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
(
self
.
quant_config
.
input_scale_ub
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
input_scale_ub
=
input_scale_ub
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
if
self
.
quant_config
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
# Activations not quantized for marlin.
del
layer
.
input_scale_ub
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
quant_config
.
use_marlin
:
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
,
cutlass_fp8_supported
=
True
,
use_per_token_if_dynamic
=
True
)
vllm/model_executor/layers/quantization/fp8.py
View file @
500b93c8
...
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
...
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
...
@@ -66,8 +67,8 @@ class Fp8Config(QuantizationConfig):
...
@@ -66,8 +67,8 @@ class Fp8Config(QuantizationConfig):
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
)
activation_scheme
=
activation_scheme
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
...
@@ -214,7 +215,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -214,7 +215,8 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
False
)
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
...
@@ -399,39 +401,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -399,39 +401,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_group
=
topk_group
)
topk_group
=
topk_group
)
class
Fp8KVCacheMethod
(
QuantizeMethodBase
):
class
Fp8KVCacheMethod
(
BaseKVCacheMethod
):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
super
().
__init__
(
quant_config
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer
.
kv_scale
=
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
raise
RuntimeError
(
"Fp8KVCacheMethod.apply should not be called."
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if
layer
.
kv_cache_dtype
!=
"auto"
:
kv_scale
=
layer
.
kv_scale
.
to
(
"cpu"
).
tolist
()
if
not
isinstance
(
kv_scale
,
float
):
raise
ValueError
(
"Only support per-tensor scaling factor "
"for fp8 KV cache"
)
layer
.
_kv_scale
=
kv_scale
if
layer
.
_kv_scale
==
1.0
and
"e5m2"
not
in
layer
.
kv_cache_dtype
:
print_warning_once
(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint."
)
del
layer
.
kv_scale
vllm/model_executor/layers/quantization/gptq.py
View file @
500b93c8
...
@@ -69,8 +69,8 @@ class GPTQConfig(QuantizationConfig):
...
@@ -69,8 +69,8 @@ class GPTQConfig(QuantizationConfig):
default
=
False
)
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"GPTQLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
GPTQLinearMethod
(
self
)
return
GPTQLinearMethod
(
self
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
500b93c8
...
@@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
apply_
gptq_
marlin_linear
,
check_
gptq_
marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_
gptq_
marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -37,9 +37,9 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -37,9 +37,9 @@ class GPTQMarlinConfig(QuantizationConfig):
self
.
lm_head_quantized
=
lm_head_quantized
self
.
lm_head_quantized
=
lm_head_quantized
# Verify supported on platform.
# Verify supported on platform.
verify_marlin_supported
(
num_bits
=
self
.
weight_bits
,
verify_
gptq_
marlin_supported
(
num_bits
=
self
.
weight_bits
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
is_sym
=
self
.
is_sym
)
is_sym
=
self
.
is_sym
)
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
return
(
f
"GPTQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
...
@@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_marlin_compatible
(
hf_quant_cfg
)
can_convert
=
cls
.
is_
gptq_
marlin_compatible
(
hf_quant_cfg
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
...
@@ -94,9 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -94,9 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference"
)
" faster inference"
)
return
None
return
None
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
GPTQMarlinLinearMethod
(
self
)
return
GPTQMarlinLinearMethod
(
self
)
...
@@ -106,22 +105,27 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -106,22 +105,27 @@ class GPTQMarlinConfig(QuantizationConfig):
return
[]
return
[]
@
classmethod
@
classmethod
def
is_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
def
is_
gptq_
marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
sym
=
quant_config
.
get
(
"sym"
,
None
)
sym
=
quant_config
.
get
(
"sym"
,
None
)
desc_act
=
quant_config
.
get
(
"desc_act"
,
None
)
desc_act
=
quant_config
.
get
(
"desc_act"
,
None
)
if
quant_method
!=
"gptq"
:
return
False
# If we cannot find the info needed in the config, cannot convert.
# If we cannot find the info needed in the config, cannot convert.
if
(
num_bits
is
None
or
group_size
is
None
or
sym
is
None
if
(
num_bits
is
None
or
group_size
is
None
or
sym
is
None
or
desc_act
is
None
):
or
desc_act
is
None
):
return
False
return
False
return
check_marlin_supported
(
num_bits
=
num_bits
,
return
check_gptq_marlin_supported
(
group_size
=
group_size
,
num_bits
=
num_bits
,
is_sym
=
sym
,
group_size
=
group_size
,
min_capability
=
cls
.
get_min_capability
())
is_sym
=
sym
,
min_capability
=
cls
.
get_min_capability
())
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
...
@@ -279,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -279,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
zp
=
marlin_make_empty_g_idx
(
device
)
# Repack weights from autogptq format to marlin format.
# Repack weights from autogptq format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
layer
.
qweight
,
...
@@ -303,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -303,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
apply_marlin_linear
(
return
apply_
gptq_
marlin_linear
(
input
=
x
,
input
=
x
,
weight
=
layer
.
qweight
,
weight
=
layer
.
qweight
,
weight_scale
=
layer
.
scales
,
weight_scale
=
layer
.
scales
,
weight_zp
=
layer
.
zp
,
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
500b93c8
...
@@ -109,9 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
...
@@ -109,9 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return
None
return
None
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
prefix
:
str
)
->
Optional
[
"GPTQMarlin24LinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlin24LinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
GPTQMarlin24LinearMethod
(
self
)
return
GPTQMarlin24LinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/kv_cache.py
0 → 100644
View file @
500b93c8
import
torch
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.utils
import
print_warning_once
class
BaseKVCacheMethod
(
QuantizeMethodBase
):
"""
Quant method that adds `_k_scale` and `_v_scale` attributes to the
Attention layer to support loading those scaling factors from checkpoints.
The k/v_scale will be used to:
- quantize k/v_cache entries before saving them to the cache
- dequantize k/v_cache entries before fetching them from the cache
:param quant_config: the appropriate QuantizationConfig
"""
def
__init__
(
self
,
quant_config
:
QuantizationConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
):
"""
Create "weight" (aka k_scale and v_scale) for an attention layer.
"""
# Initialize the KV cache scales to -1.0, which is an invalid value.
# If the k/v_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer
.
k_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
-
1.0
),
requires_grad
=
False
)
layer
.
v_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
-
1.0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
raise
RuntimeError
(
f
"
{
self
.
__class__
.
__name__
}
.apply should not be called."
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if
layer
.
kv_cache_dtype
!=
"auto"
:
if
layer
.
k_scale
>
0.0
and
layer
.
v_scale
>
0.0
:
# We prefer to use separate k_scale and v_scale if present
k_scale
=
layer
.
k_scale
.
to
(
"cpu"
).
tolist
()
v_scale
=
layer
.
v_scale
.
to
(
"cpu"
).
tolist
()
elif
layer
.
k_scale
<
0.0
and
layer
.
v_scale
<
0.0
:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
v_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
else
:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert
layer
.
k_scale
>
0.0
scale_to_duplicate
=
max
(
layer
.
k_scale
,
layer
.
v_scale
)
k_scale
=
scale_to_duplicate
.
to
(
"cpu"
).
tolist
()
v_scale
=
scale_to_duplicate
.
to
(
"cpu"
).
tolist
()
if
not
isinstance
(
k_scale
,
float
)
or
not
isinstance
(
v_scale
,
float
):
raise
ValueError
(
"Only support per-tensor scaling factor "
"for fp8 KV cache"
)
# These are used in the final Attention.forward()
layer
.
_k_scale
=
k_scale
layer
.
_v_scale
=
v_scale
if
(
layer
.
_k_scale
==
1.0
and
layer
.
_v_scale
==
1.0
and
"e5m2"
not
in
layer
.
kv_cache_dtype
):
print_warning_once
(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"may cause accuracy issues. Please make sure k/v_scale "
"scaling factors are available in the fp8 checkpoint."
)
del
layer
.
k_scale
del
layer
.
v_scale
vllm/model_executor/layers/quantization/marlin.py
View file @
500b93c8
...
@@ -100,8 +100,8 @@ class MarlinConfig(QuantizationConfig):
...
@@ -100,8 +100,8 @@ class MarlinConfig(QuantizationConfig):
return
None
return
None
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"MarlinLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"MarlinLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
MarlinLinearMethod
(
self
)
return
MarlinLinearMethod
(
self
)
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
500b93c8
...
@@ -52,8 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig):
...
@@ -52,8 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
return
cls
(
weight_bits
)
return
cls
(
weight_bits
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
prefix
:
str
)
->
Optional
[
QuantizeMethodBase
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
SqueezeLLMLinearMethod
(
self
)
return
SqueezeLLMLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
500b93c8
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
numpy
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.quant_utils
import
pack_cols
,
unpack_cols
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER
=
[
-
1
]
def
_check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
Optional
[
int
],
def
check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
has_zp
:
bool
)
->
Tuple
[
bool
,
Optional
[
str
]]:
min_capability
:
int
)
->
bool
:
if
min_capability
is
not
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
# If the capability of the device is too low, cannot convert.
device_capability
=
major
*
10
+
minor
major
,
minor
=
current_platform
.
get_device_capability
()
if
device_capability
<
min_capability
:
device_capability
=
major
*
10
+
minor
return
(
False
,
"Marlin does not support device_capability = {}"
if
device_capability
<
min_capability
:
", the min_capability required is {}"
.
format
(
return
False
device_capability
,
min_capability
))
return
(
device_capability
>=
min_capability
if
num_bits
not
in
MARLIN_SUPPORTED_NUM_BITS
:
and
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
return
(
False
,
"Marlin does not support weight_bits = {}. "
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
"Only weight_bits = {} are supported."
.
format
(
and
is_sym
in
GPTQ_MARLIN_SUPPORTED_SYM
)
num_bits
,
MARLIN_SUPPORTED_NUM_BITS
))
if
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
:
def
verify_marlin_supported
(
num_bits
:
int
,
group_size
:
Optional
[
int
],
return
(
False
,
"Marlin does not support group_size = {}. Only "
is_sym
:
bool
)
->
None
:
"group_sizes = {} are supported."
.
format
(
group_size
,
MARLIN_SUPPORTED_GROUP_SIZES
))
if
num_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
raise
ValueError
(
if
not
has_zp
and
not
is_sym
:
f
"Marlin does not support weight_bits =
{
num_bits
}
. "
return
(
False
,
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
"Marlin without zero_points must have symmetric quantization"
)
"are supported."
)
if
(
group_size
is
None
return
True
,
None
or
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
):
raise
ValueError
(
f
"Marlin does not support group_size =
{
group_size
}
. "
def
check_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
min_capability
:
int
)
->
bool
:
"are supported."
)
cond
,
_
=
_check_marlin_supported
(
num_bits
,
if
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
group_size
,
raise
ValueError
(
is_sym
,
f
"Marlin does not support is_sym = is_sym. "
min_capability
,
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
has_zp
=
False
)
return
cond
def
check_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
min_capability
:
int
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
num_bits
,
group_size
,
False
,
min_capability
,
has_zp
=
has_zp
)
return
cond
def
verify_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
group_size
,
is_sym
,
min_capability
=
None
,
has_zp
=
False
)
if
not
cond
:
assert
err_msg
is
not
None
raise
ValueError
(
"GPTQ"
+
err_msg
)
def
verify_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
group_size
,
False
,
min_capability
=
None
,
has_zp
=
has_zp
)
if
not
cond
:
assert
err_msg
is
not
None
raise
ValueError
(
"AWQ"
+
err_msg
)
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
...
@@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
...
@@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return
s
return
s
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
scale_perm
,
_
=
get_scale_perms
()
zp
=
zp
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
# Interleave column dim (for the dequantize code) and pack it to int32
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
zp
=
zp
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
zp
=
zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
zp
=
pack_cols
(
zp
,
num_bits
,
size_k
,
size_n
)
return
zp
def
awq_to_marlin_zero_points
(
q_zp_packed
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp
=
unpack_cols
(
q_zp_packed
,
num_bits
,
size_k
,
size_n
)
# Undo interleaving (use argsort(..) to get inverse perm)
if
num_bits
==
4
:
undo_interleave
=
numpy
.
argsort
(
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
]))
elif
num_bits
==
8
:
undo_interleave
=
numpy
.
argsort
(
numpy
.
array
([
0
,
2
,
1
,
3
]))
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
q_zp
=
q_zp
.
reshape
((
-
1
,
len
(
undo_interleave
)))[:,
undo_interleave
].
ravel
()
q_zp
=
q_zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
marlin_zp
=
marlin_zero_points
(
q_zp
,
size_k
,
size_n
,
num_bits
)
return
marlin_zp
# Newly generated tensors need to replace existing tensors that are
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
layer
:
torch
.
nn
.
Module
,
name
:
str
,
def
replace_tensor
(
layer
:
torch
.
nn
.
Module
,
name
:
str
,
...
@@ -149,23 +232,61 @@ def replace_tensor(layer: torch.nn.Module, name: str,
...
@@ -149,23 +232,61 @@ def replace_tensor(layer: torch.nn.Module, name: str,
del
new_t
del
new_t
def
apply_marlin_linear
(
input
:
torch
.
Tensor
,
def
apply_gptq_marlin_linear
(
weight
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
weight_zp
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
num_bits
:
int
,
g_idx_sort_indices
:
torch
.
Tensor
,
output_size_per_partition
:
int
,
workspace
:
torch
.
Tensor
,
input_size_per_partition
:
int
,
num_bits
:
int
,
is_k_full
:
bool
,
output_size_per_partition
:
int
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
input_size_per_partition
:
int
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight
,
weight_scale
,
weight_zp
,
g_idx
,
g_idx_sort_indices
,
workspace
,
num_bits
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
has_zp
=
False
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
def
apply_awq_marlin_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
weight_zp
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
weight
,
weight
,
weight_scale
,
weight_scale
,
weight_zp
,
g_idx
,
g_idx
,
g_idx_sort_indices
,
g_idx_sort_indices
,
workspace
,
workspace
,
...
@@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor,
...
@@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor,
size_m
=
reshaped_x
.
shape
[
0
],
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
)
is_k_full
=
True
,
has_zp
=
True
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
500b93c8
...
@@ -76,8 +76,14 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
...
@@ -76,8 +76,14 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
# expand it to channelwise
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
).
to
(
is_channelwise
=
(
len
(
layer
.
weight_scale
.
shape
)
>
0
layer
.
orig_dtype
).
to
(
device
)
and
layer
.
weight_scale
.
shape
[
0
]
==
part_size_n
)
if
is_channelwise
:
scales
=
layer
.
weight_scale
else
:
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
)
scales
=
scales
.
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_k
=
part_size_k
,
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment