Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9e12416
Commit
b9e12416
authored
May 31, 2024
by
zhuwenwen
Browse files
merge v0.4.3
parents
e5d707db
e9d3aa04
Changes
345
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1679 additions
and
21 deletions
+1679
-21
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+56
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+13
-2
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+1
-1
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+11
-0
vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
...ecutor/layers/quantization/compressed_tensors/__init__.py
+0
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+151
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+5
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+33
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+39
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
...d_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+130
-0
vllm/model_executor/layers/quantization/deepspeedfp.py
vllm/model_executor/layers/quantization/deepspeedfp.py
+194
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+54
-6
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+31
-12
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+291
-0
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+22
-0
vllm/model_executor/layers/quantization/utils/__init__.py
vllm/model_executor/layers/quantization/utils/__init__.py
+0
-0
vllm/model_executor/layers/quantization/utils/format_24.py
vllm/model_executor/layers/quantization/utils/format_24.py
+308
-0
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
...del_executor/layers/quantization/utils/marlin_24_perms.py
+58
-0
vllm/model_executor/layers/quantization/utils/marlin_perms.py
.../model_executor/layers/quantization/utils/marlin_perms.py
+58
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+224
-0
No files found.
Too many changes to show.
To preserve performance only
345 of 345+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/pooler.py
0 → 100644
View file @
b9e12416
from
enum
import
IntEnum
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.pooling_metadata
import
(
PoolingMetadata
,
PoolingTensors
)
from
vllm.sequence
import
EmbeddingSequenceGroupOutput
,
PoolerOutput
class
PoolingType
(
IntEnum
):
"""Enumeration for different types of pooling methods."""
LAST
=
0
class
Pooler
(
nn
.
Module
):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
normalize: Whether to normalize the pooled data.
"""
def
__init__
(
self
,
pooling_type
:
PoolingType
,
normalize
:
bool
):
super
().
__init__
()
self
.
pooling_type
=
pooling_type
self
.
normalize
=
normalize
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
"""Pools specific information from hidden states based on metadata."""
prompt_lens
=
PoolingTensors
.
from_pooling_metadata
(
pooling_metadata
,
hidden_states
.
device
).
prompt_lens
if
self
.
pooling_type
==
PoolingType
.
LAST
:
last_token_flat_indices
=
torch
.
cumsum
(
prompt_lens
,
dim
=
0
)
-
1
pooled_data
=
hidden_states
[
last_token_flat_indices
]
else
:
raise
ValueError
(
f
"Invalid pooling type:
{
self
.
pooling_type
}
"
)
if
self
.
normalize
:
pooled_data
=
nn
.
functional
.
normalize
(
pooled_data
,
p
=
2
,
dim
=
1
)
pooled_outputs
=
[
EmbeddingSequenceGroupOutput
(
data
.
tolist
())
for
data
in
pooled_data
]
return
PoolerOutput
(
outputs
=
pooled_outputs
)
vllm/model_executor/layers/quantization/__init__.py
View file @
b9e12416
...
...
@@ -4,21 +4,32 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"fp8"
:
Fp8Config
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"marlin"
:
MarlinConfig
,
"sparseml"
:
CompressedTensorsConfig
,
}
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
b9e12416
...
...
@@ -192,7 +192,7 @@ class AQLMConfig(QuantizationConfig):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
7
0
return
6
0
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
b9e12416
...
...
@@ -66,6 +66,17 @@ class QuantizationConfig(ABC):
"""Create a config class from the model's quantization config."""
raise
NotImplementedError
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
this method should only be overwritten by subclasses in exceptional
circumstances
"""
return
None
@
staticmethod
def
get_from_keys
(
config
:
Dict
[
str
,
Any
],
keys
:
List
[
str
])
->
Any
:
"""Get a value from the model's quantization config."""
...
...
vllm/model_executor/layers/quantization/compressed_tensors/__init__.py
0 → 100644
View file @
b9e12416
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
0 → 100644
View file @
b9e12416
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsW8A8StaticTensor
)
class
CompressedTensorsConfig
(
QuantizationConfig
):
def
__init__
(
self
,
layer_quant_details
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
]):
self
.
ignore
=
ignore
self
.
layer_quant_details
=
layer_quant_details
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
]
# Need to figure it out
def
get_min_capability
(
self
)
->
int
:
return
60
def
get_name
(
self
)
->
str
:
return
"compressed_tensors"
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"CompressedTensorsLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
CompressedTensorsLinearMethod
(
self
)
return
None
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
layer_quant_details
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
for
key
,
quant_config
in
config
[
"config_groups"
].
items
():
targets
=
quant_config
.
get
(
"targets"
)
for
target
in
targets
:
layer_quant_details
[
target
]
=
{}
layer_quant_details
[
target
][
"weight"
]
=
quant_config
.
get
(
"weights"
)
layer_quant_details
[
target
][
"input"
]
=
quant_config
.
get
(
"input_activations"
)
return
cls
(
layer_quant_details
=
layer_quant_details
,
ignore
=
ignore
)
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
def
_get_schema
(
self
,
weight_quant
:
Dict
,
input_quant
:
Dict
):
# TODO: Refactor as additional cases are supported
weight_bit
=
weight_quant
.
get
(
"num_bits"
)
input_bit
=
input_quant
.
get
(
"num_bits"
)
weight_strategy
=
weight_quant
.
get
(
"strategy"
)
input_strategy
=
input_quant
.
get
(
"strategy"
)
weight_symmetric
=
weight_quant
.
get
(
"symmetric"
)
input_symmetric
=
input_quant
.
get
(
"symmetric"
)
is_8_bits
=
weight_bit
==
input_bit
==
8
is_tensor
=
weight_strategy
==
input_strategy
==
"tensor"
is_symmetric
=
weight_symmetric
and
input_symmetric
if
is_8_bits
and
is_tensor
and
is_symmetric
and
\
torch
.
cuda
.
is_available
():
# CompressedTensorsW8A8StaticTensor only supports CUDA path for
# now.
return
CompressedTensorsW8A8StaticTensor
()
raise
NotImplementedError
(
"Scheme not supported. Only CUDA, 8-bit static symmtetric "
"per tensor quantization is currently supported"
)
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
)
->
"CompressedTensorsScheme"
:
# TODO: update with matching function from `compressed_tensors`
layer_type_name
=
None
layer_name_class
=
type
(
layer
).
__name__
.
lower
()
for
target
in
self
.
layer_quant_details
:
if
target
.
lower
()
in
layer_name_class
:
layer_type_name
=
target
break
if
layer_type_name
is
None
:
raise
ValueError
(
f
"Could not matching target for layer
{
layer
}
"
)
layer_quant_details
:
Dict
[
str
,
Any
]
=
self
.
layer_quant_details
.
get
(
layer_type_name
,
None
)
if
layer_quant_details
is
None
:
raise
ValueError
(
f
"Could not find quantization details for
{
layer
}
."
)
return
self
.
_get_schema
(
weight_quant
=
layer_quant_details
[
"weight"
],
input_quant
=
layer_quant_details
[
"input"
])
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
def
__init__
(
self
,
quantization_config
:
CompressedTensorsConfig
):
self
.
quantization_config
=
quantization_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer.
"""
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
=
layer
)
scheme
.
create_weights
(
layer
=
layer
,
input_size_per_partition
=
input_size_per_partition
,
output_partition_sizes
=
output_partition_sizes
,
output_size
=
output_size
,
params_dtype
=
params_dtype
,
weight_loader
=
weight_loader
)
layer
.
scheme
=
scheme
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input.
"""
if
bias
is
not
None
:
raise
ValueError
(
"bias is not supported for this linear method"
)
scheme
=
layer
.
scheme
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
return
scheme
.
apply_weights
(
layer
,
x
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
0 → 100644
View file @
b9e12416
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
# noqa: F401
from
.compressed_tensors_unquantized
import
(
# noqa: F401
CompressedTensorsUnquantized
)
from
.compressed_tensors_w8a8_statictensor
import
(
# noqa: F401, E501
CompressedTensorsW8A8StaticTensor
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
0 → 100644
View file @
b9e12416
from
abc
import
ABC
,
abstractmethod
import
torch
__all__
=
[
"CompressedTensorsScheme"
]
class
CompressedTensorsScheme
(
ABC
):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""
@
abstractmethod
def
create_weights
(
self
,
*
args
,
**
kwargs
):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise
NotImplementedError
@
abstractmethod
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: toch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
"""
raise
NotImplementedError
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
0 → 100644
View file @
b9e12416
from
typing
import
Callable
,
List
import
torch
import
torch.nn.functional
as
F
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsUnquantized"
]
class
CompressedTensorsUnquantized
(
CompressedTensorsScheme
):
"""
Implements the scheme for all layers which are ignored
in the CompressedTensors config. The input and loaded weight are used
in a linear transformation.
"""
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
weight_loader
})
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
weight
=
layer
.
weight
return
F
.
linear
(
x
,
weight
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
0 → 100644
View file @
b9e12416
from
typing
import
Callable
,
List
,
Tuple
,
Union
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
custom_ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW8A8StaticTensor"
]
class
CompressedTensorsW8A8StaticTensor
(
CompressedTensorsScheme
):
def
_shard_id_as_int
(
self
,
shard_id
:
Union
[
str
,
int
])
->
int
:
if
isinstance
(
shard_id
,
int
):
return
shard_id
assert
isinstance
(
shard_id
,
str
)
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
assert
shard_id
in
qkv_idxs
return
qkv_idxs
[
shard_id
]
def
scales_shard_splitter
(
self
,
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
Union
[
str
,
int
],
logical_widths
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
shard_id
=
self
.
_shard_id_as_int
(
shard_id
)
offset
=
sum
(
logical_widths
[:
shard_id
])
size
=
logical_widths
[
shard_id
]
# update loaded weight with copies for broadcast.
loaded_weight
=
loaded_weight
.
repeat
(
size
)
return
param
[
offset
:
offset
+
size
],
loaded_weight
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
# TODO: remove zero_point parameters once the configs given remove them
# Note on input/weight scales and zero_points
#
# When the scales have a single value, it is required that they be
# on the CPU for 2 reasons,
# 1. Performance:
# When the scales (input_scale/weight_scales) have only a single
# value, we perform a scalar broadcast of that value during the
# quant/dequant operations. The "quant" and the "gemm+dequant"
# kernels accept the Scalar by-value. These tensors are allocated
# on the CPU in order to avoid the GPU-to-CPU copy when passing
# by-value.
#
# 2. CUDA Graphs:
# CUDA Graphs don't support GPU-to-CPU copy operations during
# stream capture.
#
# TODO: zero-points are not supported yet. But we expect a similar
# pattern.
is_tensor_partitioned
=
len
(
output_partition_sizes
)
!=
1
weight_scale_dim
=
sum
(
output_partition_sizes
)
if
is_tensor_partitioned
else
1
weight_scale_device
=
"cpu"
if
weight_scale_dim
==
1
else
"cuda"
input_scale
=
Parameter
(
torch
.
empty
(
1
,
device
=
"cpu"
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
input_zero_point
=
Parameter
(
torch
.
empty
(
1
,
device
=
"cpu"
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
weight_scale
=
Parameter
(
torch
.
empty
(
weight_scale_dim
,
device
=
weight_scale_device
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
weight_zero_point
=
Parameter
(
torch
.
empty
(
1
,
device
=
"cpu"
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
weight_loader
,
"input_dim"
:
1
,
"output_dim"
:
0
,
})
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
set_weight_attrs
(
input_scale
,
{
"weight_loader"
:
weight_loader
,
"ignore_warning"
:
True
,
})
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
set_weight_attrs
(
input_zero_point
,
{
"weight_loader"
:
weight_loader
,
"ignore_warning"
:
True
,
})
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
set_weight_attrs
(
weight_scale
,
{
"weight_loader"
:
weight_loader
,
"shard_splitter"
:
self
.
scales_shard_splitter
,
"logical_widths"
:
output_partition_sizes
,
"ignore_warning"
:
True
,
})
layer
.
register_parameter
(
"weight_zero_point"
,
weight_zero_point
)
set_weight_attrs
(
weight_zero_point
,
{
"weight_loader"
:
weight_loader
,
"ignore_warning"
:
True
})
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
act_scale
=
layer
.
input_scale
# Input quantize
x_q
=
custom_ops
.
static_scaled_int8_quant
(
x
,
act_scale
[
0
].
item
())
return
custom_ops
.
cutlass_scaled_mm_dq
(
x_q
,
weight
.
t
(),
act_scale
,
weight_scale
,
x
.
dtype
)
vllm/model_executor/layers/quantization/deepspeedfp.py
0 → 100644
View file @
b9e12416
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
DeepSpeedFPConfig
(
QuantizationConfig
):
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
Args:
weight_bits: the target quantization bits, 6 or 8.
group_size: group size for quantizaiton, default to 128.
"""
def
__init__
(
self
,
weight_bits
:
int
=
8
,
group_size
:
int
=
512
,
)
->
None
:
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
valid_types
=
[
torch
.
bfloat16
,
torch
.
float16
]
if
self
.
weight_bits
not
in
(
6
,
8
):
raise
ValueError
(
"Currently, only 6-bit or 8-bit weight quantization are "
f
"supported for DeepSpeed FP quantizaiton, but got "
f
"
{
self
.
weight_bits
}
bits."
)
def
__repr__
(
self
)
->
str
:
return
(
f
"DeepSpeedFPConfig(weight_bits=
{
self
.
weight_bits
}
), "
f
"group_size=
{
self
.
group_size
}
"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"DeepSpeedFP"
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"DeepSpeedFPConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
weight_bits
=
weight_bits
,
group_size
=
group_size
)
def
get_linear_method
(
self
)
->
"DeepSpeedFPLinearMethod"
:
return
DeepSpeedFPLinearMethod
(
self
)
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
# Need to figure it out
def
get_min_capability
(
cls
)
->
int
:
return
60
@
staticmethod
def
get_config_filenames
()
->
List
[
str
]:
return
[
"quant_config.json"
,
"quantize_config.json"
,
]
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"DeepSpeedFPLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
DeepSpeedFPLinearMethod
(
self
)
return
None
class
DeepSpeedFPLinearMethod
(
LinearMethodBase
):
"""Linear method for DeepSpeedFP quantizer.
Args:
quant_config: the DeepSpeedFP quantization config.
"""
def
__init__
(
self
,
quant_config
:
DeepSpeedFPConfig
):
self
.
quant_config
=
quant_config
self
.
weight
=
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
=
None
,
**
extra_weight_attrs
):
del
output_size
del
input_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight
=
DeepSpeedFPParameter
(
torch
.
Size
((
output_size_per_partition
,
input_size_per_partition
)),
params_dtype
=
params_dtype
,
quant_config
=
self
.
quant_config
,
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
})
layer
.
register_parameter
(
"weight"
,
weight
)
def
quant_weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
):
# Calls the original weight loader (if any), quantizes the result,
# and then loads the quantized parameter.
if
weight_loader
is
not
None
:
orig_param_data
=
param
.
data
param
.
data
=
param
.
ds_dequantize
()
weight_loader
(
param
,
loaded_weight
,
*
args
,
**
kwargs
)
param
.
data
,
loaded_weight
=
orig_param_data
,
param
.
data
param
.
ds_quantize_
(
loaded_weight
.
cuda
())
extra_weight_attrs
[
"weight_loader"
]
=
quant_weight_loader
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight
=
layer
.
weight
y
=
weight
.
ds_dequantize
()
return
F
.
linear
(
x
,
y
,
bias
)
class
DeepSpeedFPParameter
(
nn
.
Parameter
):
"""
DeepSpeedFP quantized parameter class that implements fp8/fp6
quantization deepspeed. Weights are stored in quantized form on
GPUs, and can be dequantized on-the-fly when needed by the model.
"""
def
__new__
(
cls
,
orig_shape
:
torch
.
Size
,
params_dtype
:
torch
.
dtype
,
quant_config
:
DeepSpeedFPConfig
):
try
:
import
deepspeed
if
deepspeed
.
__version__
<
"0.14.2"
:
raise
ImportError
(
"deepspeed version is wrong. Please "
"install deepspeed>=0.14.2."
)
from
deepspeed.ops.fp_quantizer
import
FP_Quantize
except
ImportError
as
err
:
raise
ImportError
(
"Please install deepspeed>=0.14.2 via "
"`pip install deepspeed>=0.14.2` to use "
"deepspeedfp quantizer."
)
from
err
data
=
torch
.
empty
((
orig_shape
.
numel
()
//
quant_config
.
group_size
,
quant_config
.
group_size
*
quant_config
.
weight_bits
//
8
+
4
,
),
dtype
=
torch
.
int8
)
self
=
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
data
.
requires_grad
)
self
.
orig_shape
=
orig_shape
self
.
quant_config
=
quant_config
self
.
fp_quantizer
=
FP_Quantize
(
group_size
=
quant_config
.
group_size
)
self
.
fp_quantizer
.
orig_shape
=
orig_shape
self
.
fp_quantizer
.
orig_dtype
=
params_dtype
return
self
def
ds_quantize_
(
self
,
tensor
:
torch
.
Tensor
):
assert
tensor
.
device
.
type
==
"cuda"
and
tensor
.
dtype
!=
torch
.
int8
return
self
.
data
.
copy_
(
self
.
fp_quantizer
.
quantize
(
tensor
.
data
,
q_bits
=
self
.
quant_config
.
weight_bits
,
))
def
ds_dequantize
(
self
,
fp_out
=
None
)
->
torch
.
Tensor
:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
assert
self
.
data
.
device
.
type
==
"cuda"
and
self
.
data
.
dtype
==
torch
.
int8
return
self
.
fp_quantizer
.
dequantize
(
self
.
data
,
fp_out
=
fp_out
,
q_bits
=
self
.
quant_config
.
weight_bits
)
def
ds_selective_dequantize
(
self
,
indices
,
fp_out
=
None
)
->
torch
.
Tensor
:
"""
Return a tensor where only the weights at `indices` are dequantized
(to save HBM -> SRAM bandwidth).
"""
assert
self
.
data
.
device
.
type
==
"cuda"
and
self
.
data
.
dtype
==
torch
.
int8
return
self
.
fp_quantizer
.
selective_dequantize
(
self
.
data
,
indices
,
fp_out
=
fp_out
,
q_bits
=
self
.
quant_config
.
weight_bits
)
vllm/model_executor/layers/quantization/fp8.py
View file @
b9e12416
...
...
@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
print_warning_once
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
...
@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
activation_scheme
=
activation_scheme
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"Fp8LinearMethod"
]:
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
return
Fp8LinearMethod
(
self
)
if
isinstance
(
layer
,
Attention
):
return
Fp8KVCacheMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
...
@@ -231,9 +236,14 @@ class Fp8LinearMethod(LinearMethodBase):
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
act_scale
)
# Fused GEMM_DQ
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
act_scale
,
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
layer
.
weight
,
...
...
@@ -243,7 +253,45 @@ class Fp8LinearMethod(LinearMethodBase):
bias
=
bias
,
)
return
output
return
torch
.
narrow
(
output
,
0
,
0
,
x
.
shape
[
0
])
class
Fp8KVCacheMethod
(
QuantizeMethodBase
):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer
.
kv_scale
=
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
raise
RuntimeError
(
"Fp8KVCacheMethod.apply should not be called."
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if
layer
.
kv_cache_dtype
!=
"auto"
:
kv_scale
=
layer
.
kv_scale
.
to
(
"cpu"
).
tolist
()
if
not
isinstance
(
kv_scale
,
float
):
raise
ValueError
(
"Only support per-tensor scaling factor "
"for fp8 KV cache"
)
layer
.
_kv_scale
=
kv_scale
if
layer
.
_kv_scale
==
1.0
and
"e5m2"
not
in
layer
.
kv_cache_dtype
:
print_warning_once
(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint."
)
del
layer
.
kv_scale
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
b9e12416
...
...
@@ -6,11 +6,14 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
logger
=
init_logger
(
__name__
)
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
...
...
@@ -99,7 +102,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
...
...
@@ -117,6 +120,26 @@ class GPTQMarlinConfig(QuantizationConfig):
is_sym
=
cls
.
get_from_keys
(
config
,
[
"sym"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_marlin_compatible
(
hf_quant_cfg
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
if
can_convert
and
is_valid_user_quant
:
msg
=
(
"The model is convertible to {} during runtime."
" Using {} kernel."
.
format
(
cls
.
get_name
(),
cls
.
get_name
()))
logger
.
info
(
msg
)
return
cls
.
get_name
()
if
can_convert
and
user_quant
==
"gptq"
:
logger
.
info
(
"Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference"
)
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
...
...
@@ -186,9 +209,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
group_size
=
input_size
# Validate dtype
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be
float16, but got
{
params_dtype
}
"
)
if
params_dtype
not
in
[
torch
.
float16
,
torch
.
bfloat16
]
:
raise
ValueError
(
f
"The params dtype must be float16 "
f
"or b
float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
...
@@ -275,14 +298,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
},
)
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
g_idx
.
shape
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
g_idx_sort_indices
=
torch
.
empty
(
g_idx
.
shape
,
dtype
=
torch
.
int32
,
)
set_weight_attrs
(
g_idx_sort_indices
,
extra_weight_attrs
)
# Scales
scales
=
Parameter
(
...
...
@@ -333,9 +352,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"g_idx_sort_indices"
,
g_idx_sort_indices
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
layer
.
workspace
=
workspace
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
0 → 100644
View file @
b9e12416
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
GPTQ_MARLIN_24_TILE
=
16
GPTQ_MARLIN_24_MIN_THREAD_N
=
128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MAX_PARALLEL
=
64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
GPTQ_MARLIN_24_SUPPORTED_SYM
=
[
True
]
class
GPTQMarlin24Config
(
QuantizationConfig
):
"""Config class for Marlin24.
"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
)
->
None
:
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
# Verify
if
self
.
weight_bits
not
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
:
raise
ValueError
(
f
"Marlin_24 does not support weight_bits =
{
self
.
weight_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
f
"Marlin_24 does not support group_size =
{
self
.
group_size
}
. "
f
"Only group_sizes =
{
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
# 4 Bits packed into 32 bit datatype.
self
.
pack_factor
=
32
//
self
.
weight_bits
# Tile size used by marlin kernels.
self
.
tile_size
=
16
# Min out_features dim
self
.
min_n_threads
=
GPTQ_MARLIN_24_MIN_THREAD_N
# Min in_features dim
self
.
min_k_threads
=
GPTQ_MARLIN_24_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# batch performance)
self
.
max_parallel
=
GPTQ_MARLIN_24_MAX_PARALLEL
# Permutation length used by the marlin kernels.
self
.
perm_len
=
1024
def
__repr__
(
self
)
->
str
:
return
"Marlin24Config(weight_bits={}, group_size={})"
.
format
(
self
.
weight_bits
,
self
.
group_size
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"gptq_marlin_24"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
# Need to figure it out
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"GPTQMarlin24Config"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
weight_bits
,
group_size
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
is_marlin_24_format
=
(
hf_quant_cfg
.
get
(
"checkpoint_format"
)
==
"marlin_24"
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"gptq"
or
user_quant
==
"gptq_marlin_24"
)
if
is_marlin_24_format
and
is_valid_user_quant
:
msg
=
(
"The model is serialized in {} format. "
"Using {} kernel."
.
format
(
cls
.
get_name
(),
cls
.
get_name
()))
logger
.
info
(
msg
)
return
cls
.
get_name
()
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlin24LinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
GPTQMarlin24LinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
GPTQMarlin24LinearMethod
(
LinearMethodBase
):
"""Linear method for Marlin24.
Args:
quant_config: The Marlin24 quantization config.
"""
def
__init__
(
self
,
quant_config
:
GPTQMarlin24Config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
del
output_size
# Unused.
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
min_n_threads
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
"min_n_threads =
{
self
.
quant_config
.
min_n_threads
}
."
)
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
"pack_factor =
{
self
.
quant_config
.
pack_factor
}
."
)
# Validate input_size_per_partition
if
input_size_per_partition
%
self
.
quant_config
.
min_k_threads
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"min_k_threads =
{
self
.
quant_config
.
min_k_threads
}
."
)
if
(
self
.
quant_config
.
group_size
!=
-
1
and
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"group_size =
{
self
.
quant_config
.
group_size
}
."
)
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm
=
self
.
quant_config
.
perm_len
//
(
self
.
quant_config
.
tile_size
**
2
)
if
output_size_per_partition
%
num_tiles_per_perm
!=
0
:
raise
ValueError
(
"Each permutation group must reside on the same gpu"
)
# Quantized 4Bit weights packed into Int32.
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
tile_size
//
2
,
output_size_per_partition
*
self
.
quant_config
.
tile_size
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"marlin_tile_size"
:
self
.
quant_config
.
tile_size
,
},
)
# Meta
meta
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
8
//
2
//
2
,
output_size_per_partition
*
2
,
device
=
"cuda"
,
dtype
=
torch
.
int16
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
meta
,
{
"input_dim"
:
0
,
"packed_dim"
:
1
,
"pack_factor"
:
1
,
"output_dim"
:
1
,
"marlin_tile_size"
:
2
,
},
)
# Determine if channelwise or not
input_groups
=
(
1
if
self
.
quant_config
.
group_size
==
-
1
else
input_size_per_partition
//
self
.
quant_config
.
group_size
)
scales
=
Parameter
(
torch
.
empty
(
input_groups
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
None
if
input_groups
==
1
else
0
,
"output_dim"
:
1
,
},
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_n_threads
)
*
self
.
quant_config
.
max_parallel
workspace
=
Parameter
(
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
requires_grad
=
False
)
layer
.
register_parameter
(
"B_24"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"B_meta"
,
meta
)
set_weight_attrs
(
meta
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
qweight
=
layer
.
B_24
meta
=
layer
.
B_meta
scales
=
layer
.
s
workspace
=
layer
.
workspace
x_2d
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
size_m
=
x_2d
.
shape
[
0
]
size_k
=
x_2d
.
shape
[
1
]
size_n
=
scales
.
shape
[
1
]
output_2d
=
ops
.
gptq_marlin_24_gemm
(
x_2d
,
qweight
,
meta
,
scales
,
workspace
,
self
.
quant_config
.
weight_bits
,
size_m
,
size_n
,
size_k
)
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
vllm/model_executor/layers/quantization/marlin.py
View file @
b9e12416
...
...
@@ -4,11 +4,14 @@ import torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
class
MarlinConfig
(
QuantizationConfig
):
"""Config class for Marlin.
...
...
@@ -72,6 +75,25 @@ class MarlinConfig(QuantizationConfig):
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
group_size
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format
=
(
hf_quant_cfg
.
get
(
"checkpoint_format"
)
==
"marlin"
or
hf_quant_cfg
.
get
(
"is_marlin_format"
,
False
))
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"gptq"
or
user_quant
==
"marlin"
)
if
is_marlin_format
and
is_valid_user_quant
:
msg
=
(
"The model is serialized in {} format. Using {} kernel."
.
format
(
cls
.
get_name
(),
cls
.
get_name
()))
logger
.
info
(
msg
)
return
cls
.
get_name
()
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"MarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
...
...
vllm/model_executor/layers/quantization/utils/__init__.py
0 → 100644
View file @
b9e12416
vllm/model_executor/layers/quantization/utils/format_24.py
0 → 100644
View file @
b9e12416
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import
torch
# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
):
dst_rows
=
torch
.
arange
(
0
,
m
,
device
=
device
)[:,
None
].
repeat
(
1
,
meta_ncols
)
dst_cols
=
torch
.
arange
(
0
,
meta_ncols
,
device
=
device
).
repeat
(
m
,
1
)
# Reorder the rows, then swizzle the 2x2 blocks.
group_x
=
64
group_y
=
32
if
meta_dtype
.
itemsize
==
2
else
16
dst_rows
=
(
dst_rows
//
group_x
*
group_x
+
(
dst_rows
%
2
)
*
2
+
(
dst_rows
%
8
)
//
4
+
((
dst_rows
%
group_y
)
%
4
)
//
2
*
32
+
((
dst_rows
%
group_x
)
//
8
)
*
4
)
topright
=
((
dst_rows
%
2
==
0
)
&
(
dst_cols
%
2
==
1
)).
to
(
torch
.
int8
)
bottomleft
=
((
dst_rows
%
2
==
1
)
&
(
dst_cols
%
2
==
0
)).
to
(
torch
.
int8
)
dst_rows
+=
topright
-
bottomleft
dst_cols
-=
topright
-
bottomleft
# Assumed that meta tensor is to be stored in CUTLASS
# InterleavedColumnMajor layout, and reverse engineered
# corresponding code to store values into this tensor.
interleave
=
2
cols_maj
=
dst_cols
//
interleave
cols_min
=
dst_cols
%
interleave
return
(
cols_maj
*
m
*
interleave
+
dst_rows
*
interleave
+
cols_min
).
view
(
-
1
)
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def
sparse_semi_structured_from_dense_cutlass
(
dense
):
if
dense
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional dense tensor, got
{
dense
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
m
,
k
=
dense
.
shape
device
=
dense
.
device
meta_dtype
=
torch
.
int8
if
dense
.
dtype
==
torch
.
int8
:
meta_dtype
=
torch
.
int32
elif
dense
.
dtype
in
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
,
torch
.
int32
]:
meta_dtype
=
torch
.
int16
else
:
raise
RuntimeError
(
f
"Invalid datatype
{
dense
.
dtype
}
of dense matrix"
)
quadbits_per_meta_elem
=
meta_dtype
.
itemsize
*
8
//
4
if
quadbits_per_meta_elem
not
in
(
4
,
8
):
raise
RuntimeError
(
"Invalid number of elements per meta element calculated"
)
if
meta_dtype
==
torch
.
int32
:
if
m
%
16
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 16"
)
else
:
if
m
%
32
!=
0
:
raise
RuntimeError
(
f
"Number of rows of dense matrix
{
m
}
must be divisible by 32"
)
if
k
%
(
4
*
quadbits_per_meta_elem
)
!=
0
:
raise
RuntimeError
(
f
"Number of columns of dense matrix
{
k
}
must be divisible by
{
4
*
quadbits_per_meta_elem
}
"
# noqa: E501
)
if
dense
.
dtype
!=
torch
.
float
:
ksparse
=
4
dense_4
=
dense
.
view
(
-
1
,
k
//
ksparse
,
ksparse
)
m0
,
m1
,
m2
,
m3
=
(
dense_4
!=
0
).
unbind
(
-
1
)
else
:
ksparse
=
2
dense_2
=
dense
.
view
(
-
1
,
k
//
ksparse
,
ksparse
)
m0
,
m2
=
m1
,
m3
=
(
dense_2
!=
0
).
unbind
(
-
1
)
meta_ncols
=
k
//
(
ksparse
*
quadbits_per_meta_elem
)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
# Thus, lower two bits in the encoding are index of the True value
# at the lowest index in the quadruple, and the higher two bits in
# the encoding are index of the other True value in the quadruple.
# In case there are less than two True values, than False value or
# values at some index or indices are considered True for the
# encoding. In case there are more than two True values, then the
# excess True value(s) at some indices are considered False for
# the encoding. The exact encodings used for these cases are as
# follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1001
# [False, True, True, True ] -> 0b1101
# [True, False, False, False] -> 0b1000
# [True, False, True, True ] -> 0b1100
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b0100
# [True, True, True, True ] -> 0b0100
# These particular encodings are chosen, with the help of Espresso
# logic minimizer software, for the purpose of minimization of
# corresponding Boolean functions, that translate non-zero flags
# into encoding bits. Note also possible choices for the first
# and last of these encodings were limited only to (0b0100,
# 0b1110), in order to produce valid encodings for 1:2 sparsity
# case.
expr0
=
m0
&
m1
expr1
=
~
m0
&
m1
expr2
=
~
m0
&
~
m1
bit0
=
expr1
bit1
=
expr2
bit2
=
expr0
|
expr2
|
m3
bit3
=
expr1
|
~
m1
idxs0
=
bit0
|
(
bit1
.
to
(
torch
.
int64
)
<<
1
)
idxs1
=
bit2
|
(
bit3
.
to
(
torch
.
int64
)
<<
1
)
if
dense
.
dtype
!=
torch
.
float
:
sparse0
=
dense_4
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
))
# type: ignore[possibly-undefined]
sparse1
=
dense_4
.
gather
(
-
1
,
idxs1
.
unsqueeze
(
-
1
))
sparse
=
torch
.
stack
((
sparse0
,
sparse1
),
dim
=-
1
).
view
(
m
,
k
//
2
)
else
:
sparse
=
dense_2
.
gather
(
-
1
,
idxs0
.
unsqueeze
(
-
1
)
//
2
).
view
(
m
,
k
//
2
)
# type: ignore[possibly-undefined]
meta_4
=
idxs0
|
(
idxs1
<<
2
)
meta_n
=
meta_4
.
view
(
(
-
1
,
meta_ncols
,
quadbits_per_meta_elem
)).
to
(
meta_dtype
)
if
quadbits_per_meta_elem
==
4
:
meta
=
(
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
))
elif
quadbits_per_meta_elem
==
8
:
meta
=
(
meta_n
[:,
:,
0
]
|
(
meta_n
[:,
:,
1
]
<<
4
)
|
(
meta_n
[:,
:,
2
]
<<
8
)
|
(
meta_n
[:,
:,
3
]
<<
12
)
|
(
meta_n
[:,
:,
4
]
<<
16
)
|
(
meta_n
[:,
:,
5
]
<<
20
)
|
(
meta_n
[:,
:,
6
]
<<
24
)
|
(
meta_n
[:,
:,
7
]
<<
28
))
# Reorder meta tensor elements.
meta_reordered
=
meta
.
new_empty
(
(
m
*
meta_ncols
,
))
# type: ignore[possibly-undefined]
meta_offsets
=
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
)
meta_reordered
.
scatter_
(
0
,
meta_offsets
,
meta
.
view
(
-
1
))
return
(
sparse
,
meta_reordered
.
view
(
m
,
meta_ncols
))
# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def
sparse_semi_structured_to_dense_cutlass
(
sparse
,
meta_reordered
):
if
sparse
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional sparse tensor, got
{
sparse
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
m
,
k
=
sparse
.
shape
device
=
sparse
.
device
if
meta_reordered
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"Expected 2-dimensional meta tensor, got
{
meta_reordered
.
dim
()
}
-dimensional tensor"
# noqa: E501
)
if
meta_reordered
.
device
!=
device
:
raise
RuntimeError
(
f
"Expected meta matrix to be on
{
device
}
device, got matrix on
{
meta_reordered
.
device
}
device"
# noqa: E501
)
meta_dtype
=
meta_reordered
.
dtype
if
meta_dtype
not
in
(
torch
.
int16
,
torch
.
int32
):
raise
RuntimeError
(
f
"Invalid datatype
{
meta_dtype
}
of meta matrix"
)
quadbits_per_meta_elem
=
meta_dtype
.
itemsize
*
8
//
4
ksparse
=
4
if
sparse
.
dtype
!=
torch
.
float
else
2
meta_nrows
,
meta_ncols
=
meta_reordered
.
shape
if
meta_nrows
!=
m
:
raise
RuntimeError
(
f
"Number of rows of meta matrix
{
meta_nrows
}
must be equal to number of columns of spase matrix
{
m
}
"
# noqa: E501
)
if
meta_ncols
*
ksparse
*
quadbits_per_meta_elem
!=
2
*
k
:
raise
RuntimeError
(
f
"Number of columns of sparse matrix
{
k
}
different from the
{
meta_ncols
*
ksparse
*
quadbits_per_meta_elem
//
2
}
, "
# noqa: E501
"expected according to the number of columns of meta matrix"
)
# Undo meta tensor elements reordering.
meta_offsets
=
_calculate_meta_reordering_scatter_offsets
(
m
,
meta_ncols
,
meta_dtype
,
device
)
meta
=
torch
.
gather
(
meta_reordered
.
view
(
-
1
),
0
,
meta_offsets
).
view
(
m
,
meta_ncols
)
# Unpack sparse tensor back to original dense tensor, using
# information provided by meta tensor. Note that torch.float
# datatype is handled pretty much the same as
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
# value is encoded as if underlying 8 bytes contain four
# torch.half/torch.bfloat16 values, where either first two or last
# two are zeros.
meta_2
=
torch
.
empty
(
(
m
,
meta_ncols
,
2
*
quadbits_per_meta_elem
),
dtype
=
meta_dtype
,
device
=
device
,
)
if
quadbits_per_meta_elem
==
4
:
meta_2
[:,
:,
0
]
=
meta
&
0b11
meta_2
[:,
:,
1
]
=
(
meta
>>
2
)
&
0b11
meta_2
[:,
:,
2
]
=
(
meta
>>
4
)
&
0b11
meta_2
[:,
:,
3
]
=
(
meta
>>
6
)
&
0b11
meta_2
[:,
:,
4
]
=
(
meta
>>
8
)
&
0b11
meta_2
[:,
:,
5
]
=
(
meta
>>
10
)
&
0b11
meta_2
[:,
:,
6
]
=
(
meta
>>
12
)
&
0b11
meta_2
[:,
:,
7
]
=
(
meta
>>
14
)
&
0b11
elif
quadbits_per_meta_elem
==
8
:
meta_2
[:,
:,
0
]
=
meta
&
0b11
meta_2
[:,
:,
1
]
=
(
meta
>>
2
)
&
0b11
meta_2
[:,
:,
2
]
=
(
meta
>>
4
)
&
0b11
meta_2
[:,
:,
3
]
=
(
meta
>>
6
)
&
0b11
meta_2
[:,
:,
4
]
=
(
meta
>>
8
)
&
0b11
meta_2
[:,
:,
5
]
=
(
meta
>>
10
)
&
0b11
meta_2
[:,
:,
6
]
=
(
meta
>>
12
)
&
0b11
meta_2
[:,
:,
7
]
=
(
meta
>>
14
)
&
0b11
meta_2
[:,
:,
8
]
=
(
meta
>>
16
)
&
0b11
meta_2
[:,
:,
9
]
=
(
meta
>>
18
)
&
0b11
meta_2
[:,
:,
10
]
=
(
meta
>>
20
)
&
0b11
meta_2
[:,
:,
11
]
=
(
meta
>>
22
)
&
0b11
meta_2
[:,
:,
12
]
=
(
meta
>>
24
)
&
0b11
meta_2
[:,
:,
13
]
=
(
meta
>>
26
)
&
0b11
meta_2
[:,
:,
14
]
=
(
meta
>>
28
)
&
0b11
meta_2
[:,
:,
15
]
=
(
meta
>>
30
)
&
0b11
dense_offsets
=
meta_2
.
view
(
-
1
)
+
(
torch
.
arange
(
0
,
2
*
m
*
k
//
ksparse
,
device
=
device
)
*
4
).
view
(
-
1
,
1
).
repeat
(
1
,
2
).
view
(
-
1
)
dense
=
torch
.
zeros
((
m
*
2
*
k
,
),
dtype
=
sparse
.
dtype
,
device
=
device
)
if
sparse
.
dtype
!=
torch
.
float
:
# dense.scatter_(0, dense_offsets, sparse.view(-1))
dense
.
scatter_
(
0
,
dense_offsets
,
sparse
.
reshape
(
-
1
))
else
:
dense
.
view
(
torch
.
half
).
scatter_
(
0
,
dense_offsets
,
sparse
.
view
(
torch
.
half
).
view
(
-
1
))
return
dense
.
view
(
m
,
2
*
k
)
def
mask_creator
(
tensor
):
"""
Class for creating N:M sparsity masks.
Masks will be created using the N:M ratio, where for every block of
M weights, N will be pruned based on ranked weight value. Each mask
will correspond to the given tensor.
:param N: The number of weights in a group to keep
:param M: The size of a weight group
"""
N
=
2
M
=
4
mask
=
None
# for i, tensor in enumerate(tensors):
if
tensor
.
numel
()
%
M
!=
0
:
raise
ValueError
(
f
"Tensor of size
{
tensor
.
shape
}
can't be evenly divided into "
f
"
{
M
}
groups"
)
num_groups
=
tensor
.
numel
()
//
M
# N:M sparsity for linear layers
tensor_temp
=
tensor
.
detach
().
abs
().
reshape
(
num_groups
,
M
)
index
=
torch
.
argsort
(
tensor_temp
,
dim
=
1
)[:,
:
int
(
M
-
N
)]
w_b
=
torch
.
ones
(
tensor_temp
.
shape
,
device
=
tensor_temp
.
device
)
mask
=
w_b
.
scatter_
(
dim
=
1
,
index
=
index
,
value
=
0
).
reshape
(
tensor
.
shape
)
return
mask
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
0 → 100644
View file @
b9e12416
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms_24
(
num_bits
):
perm_list
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
col_o
=
col
//
2
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col_o
*
256
+
8
*
(
col
%
2
)
+
4
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
1
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
ValueError
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
*
8
+
j
for
j
in
[
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
]])
scale_perm_single
=
[]
for
i
in
range
(
8
):
scale_perm_single
.
extend
([
8
*
i
+
j
for
j
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_24_perm
=
{}
marlin_24_scale_perm
=
{}
marlin_24_scale_perm_single
=
{}
for
num_bits
in
[
4
,
8
]:
perm_24
,
scale_perm_24
,
scale_perm_single_24
=
get_perms_24
(
num_bits
)
marlin_24_perm
[
num_bits
]
=
perm_24
marlin_24_scale_perm
[
num_bits
]
=
scale_perm_24
marlin_24_scale_perm_single
[
num_bits
]
=
scale_perm_single_24
vllm/model_executor/layers/quantization/utils/marlin_perms.py
0 → 100644
View file @
b9e12416
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms
(
num_bits
):
perm_list
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_perm
=
{}
marlin_scale_perm
=
{}
marlin_scale_perm_single
=
{}
for
num_bits
in
[
4
,
8
]:
perm
,
scale_perm
,
scale_perm_single
=
get_perms
(
num_bits
)
marlin_perm
[
num_bits
]
=
perm
marlin_scale_perm
[
num_bits
]
=
scale_perm
marlin_scale_perm_single
[
num_bits
]
=
scale_perm_single
vllm/model_executor/layers/quantization/utils/marlin_utils.py
0 → 100644
View file @
b9e12416
"""This file is used for /tests and /benchmarks"""
import
random
import
numpy
import
torch
from
vllm.model_executor.layers.quantization.utils.format_24
import
(
mask_creator
,
sparse_semi_structured_from_dense_cutlass
)
from
vllm.model_executor.layers.quantization.utils.marlin_24_perms
import
(
marlin_24_perm
,
marlin_24_scale_perm
,
marlin_24_scale_perm_single
)
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
marlin_perm
,
marlin_scale_perm
,
marlin_scale_perm_single
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
sort_weights
)
__cuda_arch
=
torch
.
cuda
.
get_device_capability
()
MARLIN_TILE
=
16
def
is_marlin_supported
():
return
__cuda_arch
[
0
]
>=
8
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
# Permute weights to 16x64 marlin tiles
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
tile
,
size_n
//
tile
,
tile
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
size_n
*
tile
))
q_w
=
q_w
.
reshape
((
-
1
,
perm
.
numel
()))[:,
perm
].
reshape
(
q_w
.
shape
)
return
q_w
def
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
perm
):
# Permute
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
)
# Pack
pack_factor
=
get_pack_factor
(
num_bits
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_packed
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
scale_perm
,
scale_perm_single
):
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
def
marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
act_order
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
if
act_order
:
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
# Reformat to marlin
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_perm
[
num_bits
])
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_scale_perm
[
num_bits
],
marlin_scale_perm_single
[
num_bits
])
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
rand_perm
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
def
inject_24
(
w
,
size_k
,
size_n
):
assert
w
.
shape
==
(
size_k
,
size_n
)
mask
=
mask_creator
(
w
.
t
()).
t
().
cuda
().
bool
()
return
(
mask
*
w
).
contiguous
(),
mask
.
contiguous
()
def
check_24
(
w
,
num_rows_to_sample
=
50
,
_verbose
=
False
):
BLOCK_SIZE
=
4
MAX_NON_ZEROS
=
2
w
=
w
.
t
().
contiguous
()
print
(
"check_24: w.shape = {}"
.
format
(
w
.
shape
))
num_rows
,
num_cols
=
w
.
shape
sampled_row_idxs
=
random
.
choices
(
range
(
num_rows
),
k
=
num_rows_to_sample
)
if
_verbose
:
print
(
f
"Sampled row idxs =
{
sampled_row_idxs
}
"
)
total_segments
=
0
non_24_segments
=
0
for
i
in
sampled_row_idxs
:
for
j
in
range
(
0
,
num_cols
-
BLOCK_SIZE
,
BLOCK_SIZE
):
total_segments
+=
1
block
=
w
[
i
,
j
:
j
+
BLOCK_SIZE
]
num_nonzero
=
torch
.
count_nonzero
(
block
)
if
num_nonzero
>
MAX_NON_ZEROS
:
print
(
"i = {} j = {} block = {}"
.
format
(
i
,
j
,
block
))
non_24_segments
+=
1
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
q_24_no_zp_comp
,
meta
=
sparse_semi_structured_from_dense_cutlass
(
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore zp
q_24_comp
=
q_24_no_zp_comp
+
zp
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
return
q_24_comp
,
meta
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
# Inject 2:4 sparsity
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
num_bits
,
group_size
,
act_order
=
False
)
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
size_k_comp
=
size_k
//
2
# Reformat to marlin
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num_bits
,
marlin_24_perm
[
num_bits
])
marlin_24_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
marlin_24_scale_perm
[
num_bits
],
marlin_24_scale_perm_single
[
num_bits
])
# Create result
res_list
=
[
w_24_ref
,
marlin_24_q_w_comp
,
meta
,
marlin_24_s
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
def
compute_max_diff
(
output
,
output_ref
):
return
torch
.
mean
(
torch
.
abs
(
output
-
output_ref
))
/
torch
.
mean
(
torch
.
abs
(
output_ref
))
class
MarlinWorkspace
:
def
__init__
(
self
,
out_features
,
min_thread_n
,
max_parallel
):
assert
(
out_features
%
min_thread_n
==
0
),
(
"out_features = {} is undivisible by min_thread_n = {}"
.
format
(
out_features
,
min_thread_n
))
max_workspace_size
=
((
out_features
//
min_thread_n
)
*
max_parallel
)
self
.
scratch
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
Prev
1
…
13
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment