Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1043 additions
and
396 deletions
+1043
-396
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+5
-9
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+14
-44
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+35
-15
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+29
-17
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+19
-10
vllm/model_executor/layers/quantization/kv_cache.py
vllm/model_executor/layers/quantization/kv_cache.py
+2
-4
vllm/model_executor/layers/quantization/qqq.py
vllm/model_executor/layers/quantization/qqq.py
+285
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+70
-67
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+3
-11
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+19
-10
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
...xecutor/layers/quantization/utils/marlin_utils_test_24.py
+14
-16
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
...ecutor/layers/quantization/utils/marlin_utils_test_qqq.py
+125
-0
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+152
-52
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+3
-2
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+39
-56
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+3
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+87
-65
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+2
-2
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+130
-16
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+7
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
e661d594
...
...
@@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
from
pydantic
import
BaseModel
,
Field
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
FUSED_LAYER_NAME_MAPPING
)
class
CompressionFormat
(
Enum
):
dense
=
"dense"
...
...
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
def
should_ignore_layer
(
layer_name
:
Optional
[
str
],
ignore
:
Iterable
[
str
])
->
bool
:
if
layer_name
is
None
:
...
...
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if
proj_name
in
_
FUSED_LAYER_NAME_MAPPING
:
shard_proj_names
=
_
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
:
shard_proj_names
=
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
# Convert fused_name --> [shard_names]
shard_names
=
[
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
e661d594
...
...
@@ -9,8 +9,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.fp8
import
cutlass_fp8_supported
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
...
...
@@ -18,14 +21,6 @@ from vllm.platforms import current_platform
logger
=
init_logger
(
__name__
)
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
class
FBGEMMFp8Config
(
QuantizationConfig
):
"""Config class for FBGEMM Fp8."""
...
...
@@ -62,37 +57,10 @@ class FBGEMMFp8Config(QuantizationConfig):
input_scale_ub
=
cls
.
get_from_keys
(
config
,
[
"activation_scale_ub"
])
return
cls
(
ignore_list
=
ignore_list
,
input_scale_ub
=
input_scale_ub
)
def
_is_layer_skipped
(
self
,
prefix
:
str
)
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
if
proj_name
in
_FUSED_LAYER_NAME_MAPPING
:
shard_prefixes
=
[
prefix
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
_FUSED_LAYER_NAME_MAPPING
[
proj_name
]
]
is_skipped
=
None
for
shard_prefix
in
shard_prefixes
:
is_shard_skipped
=
shard_prefix
in
self
.
ignore_list
if
is_skipped
is
None
:
is_skipped
=
is_shard_skipped
elif
is_shard_skipped
!=
is_skipped
:
raise
ValueError
(
f
"Detected some but not all shards of
{
prefix
}
"
"are quantized. All shards of fused layers "
"to have the same precision."
)
else
:
is_skipped
=
prefix
in
self
.
ignore_list
assert
is_skipped
is
not
None
return
is_skipped
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
self
.
_
is_layer_skipped
(
prefix
):
if
is_layer_skipped
(
prefix
,
self
.
ignore_list
):
return
UnquantizedLinearMethod
()
return
FBGEMMFp8LinearMethod
(
self
)
return
None
...
...
@@ -105,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
def
create_weights
(
self
,
...
...
@@ -172,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
,
cutlass_fp8_supported
=
True
,
use_per_token_if_dynamic
=
True
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
True
)
vllm/model_executor/layers/quantization/fp8.py
View file @
e661d594
...
...
@@ -6,17 +6,20 @@ from torch.nn.parameter import Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
fused_moe
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethod
Base
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
Unquantized
LinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
...
...
@@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
ignored_layers
:
Optional
[
List
[
str
]]
=
None
,
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
...
...
@@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
self
.
ignored_layers
=
ignored_layers
or
[]
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
ignored_layers
=
cls
.
get_from_keys_or
(
config
,
[
"ignored_layers"
],
None
)
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
)
activation_scheme
=
activation_scheme
,
ignored_layers
=
ignored_layers
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
ignored_layers
):
return
UnquantizedLinearMethod
()
return
Fp8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
Fp8MoEMethod
(
self
)
...
...
@@ -170,19 +179,29 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
None
# If checkpoint is fp8,
requantize the separately quantized logical
#
weight
s in
to
a
single fp8 weight with a single weight sca
le
.
# If checkpoint is fp8,
handle that there are N scales for N
#
shard
s in a
fused modu
le
else
:
# Dequant -> Quant with max scale.
max_w_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if
self
.
use_marlin
:
weight
=
layer
.
weight
weight_scale
=
convert_to_channelwise
(
layer
.
weight_scale
,
layer
.
logical_widths
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else
:
# Dequant -> Quant with max scale so we can run per tensor.
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w
_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight
_scale
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
...
...
@@ -384,6 +403,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_moe
return
fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
e661d594
...
...
@@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
check_
gptq_
marlin_supported
,
marlin_is_k_full
,
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_
gptq_
marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -22,6 +23,12 @@ logger = init_logger(__name__)
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP
=
{
(
4
,
True
):
scalar_types
.
uint4b8
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
...
...
@@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
# (since we have only one group per output channel)
desc_act
=
False
self
.
weight_bits
=
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
# packed into int32
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
is_sym
=
is_sym
self
.
lm_head_quantized
=
lm_head_quantized
if
(
weight_bits
,
is_sym
)
not
in
self
.
TYPE_MAP
:
raise
ValueError
(
"Unsupported quantization config: "
f
"bits=
{
weight_bits
}
, sym=
{
is_sym
}
"
)
self
.
quant_type
=
self
.
TYPE_MAP
[(
weight_bits
,
is_sym
)]
# Verify supported on platform.
verify_gptq_marlin_supported
(
num_bits
=
self
.
weight_bits
,
group_size
=
self
.
group_size
,
is_sym
=
self
.
is_sym
)
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
)
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(
weight_bits=
{
self
.
weight_bits
}
, "
return
(
f
"GPTQMarlinConfig(
quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
...
...
@@ -79,7 +89,8 @@ class GPTQMarlinConfig(QuantizationConfig):
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_gptq_marlin_compatible
(
hf_quant_cfg
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
or
user_quant
==
"gptq_marlin"
)
if
can_convert
and
is_valid_user_quant
:
msg
=
(
"The model is convertible to {} during runtime."
...
...
@@ -121,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
or
desc_act
is
None
):
return
False
return
check_gptq_marlin_supported
(
num_bits
=
num_bits
,
group_size
=
group_size
,
is_sym
=
sym
,
min_capability
=
cls
.
get_min_capability
())
if
(
num_bits
,
sym
)
not
in
cls
.
TYPE_MAP
:
return
False
return
check_marlin_supported
(
quant_type
=
cls
.
TYPE_MAP
[(
num_bits
,
sym
)],
group_size
=
group_size
,
min_capability
=
cls
.
get_min_capability
())
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
...
...
@@ -292,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight
_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size
_bits
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
...
...
@@ -318,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
wtype
=
self
.
quant_config
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
e661d594
...
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MAX_PARALLEL
=
64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
GPTQ_MARLIN_24_SUPPORTED_SYM
=
[
True
]
class
GPTQMarlin24Config
(
QuantizationConfig
):
...
...
@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits
:
int
,
group_size
:
int
,
)
->
None
:
self
.
weight_bits
=
weight_bits
quant_type
=
{
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
,
}.
get
(
weight_bits
)
self
.
group_size
=
group_size
# Verify
if
self
.
weight_bits
not
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
:
if
quant_type
is
None
or
\
quant_type
not
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
:
raise
ValueError
(
f
"Marlin_24 does not support
weight_bits =
{
self
.
weight_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_
NUM_BIT
S
}
"
f
"Marlin_24 does not support
quant_type =
{
quant_type
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_
QUANT_TYPE
S
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
...
...
@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
f
"Only group_sizes =
{
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
self
.
quant_type
=
quant_type
# 4 Bits packed into 32 bit datatype.
self
.
pack_factor
=
32
//
self
.
weight
_bits
self
.
pack_factor
=
32
//
self
.
quant_type
.
size
_bits
# Tile size used by marlin kernels.
self
.
tile_size
=
16
...
...
@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
self
.
perm_len
=
1024
def
__repr__
(
self
)
->
str
:
return
"Marlin24Config(
weight_bits
={}, group_size={})"
.
format
(
self
.
weight_bits
,
self
.
group_size
)
return
"Marlin24Config(
quant_type
={}, group_size={})"
.
format
(
self
.
quant_type
,
self
.
group_size
)
@
classmethod
def
get_name
(
cls
)
->
str
:
...
...
@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
output_2d
=
ops
.
gptq_marlin_24_gemm
(
x_2d
,
qweight
,
meta
,
scales
,
workspace
,
self
.
quant_config
.
weight_bits
,
self
.
quant_config
.
quant_type
,
size_m
,
size_n
,
size_k
)
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
...
...
vllm/model_executor/layers/quantization/kv_cache.py
View file @
e661d594
...
...
@@ -46,10 +46,8 @@ class BaseKVCacheMethod(QuantizeMethodBase):
elif
layer
.
k_scale
<
0.0
and
layer
.
v_scale
<
0.0
:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
v_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
k_scale
=
1.0
v_scale
=
1.0
else
:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
...
...
vllm/model_executor/layers/quantization/qqq.py
0 → 100644
View file @
e661d594
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
MARLIN_QQQ_TILE
=
16
MARLIN_QQQ_MIN_THREAD_N
=
64
MARLIN_QQQ_MIN_THREAD_K
=
128
MARLIN_QQQ_MAX_PARALLEL
=
16
MARLIN_QQQ_SUPPORTED_NUM_BITS
=
[
4
]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
MARLIN_QQQ_SUPPORTED_SYM
=
[
True
]
class
QQQConfig
(
QuantizationConfig
):
"""Config class for QQQ
Reference: https://arxiv.org/pdf/2406.09904
"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
=
True
,
)
->
None
:
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
is_sym
=
is_sym
# Verify
if
self
.
weight_bits
not
in
MARLIN_QQQ_SUPPORTED_NUM_BITS
:
raise
ValueError
(
f
"QQQ does not support weight_bits =
{
self
.
weight_bits
}
. "
f
"Only weight_bits =
{
MARLIN_QQQ_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
self
.
group_size
not
in
MARLIN_QQQ_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
f
"QQQ does not support group_size =
{
self
.
group_size
}
. "
f
"Only group_sizes =
{
MARLIN_QQQ_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
if
self
.
is_sym
not
in
MARLIN_QQQ_SUPPORTED_SYM
:
raise
ValueError
(
f
"QQQ does not support is_sym =
{
self
.
is_sym
}
. "
f
"Only sym =
{
MARLIN_QQQ_SUPPORTED_SYM
}
are supported."
)
# 4 Bits packed into 32 bit datatype.
self
.
pack_factor
=
32
//
self
.
weight_bits
# Tile size used by QQQ kernels.
self
.
tile_size
=
MARLIN_QQQ_TILE
# Min out_features dim
self
.
min_n_threads
=
MARLIN_QQQ_MIN_THREAD_N
# Min in_features dim
self
.
min_k_threads
=
MARLIN_QQQ_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# batch performance)
self
.
max_parallel
=
MARLIN_QQQ_MAX_PARALLEL
# Permutation length used by the QQQ kernels.
self
.
perm_len
=
1024
def
__repr__
(
self
)
->
str
:
return
"QQQConfig(weight_bits={}, group_size={})"
.
format
(
self
.
weight_bits
,
self
.
group_size
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"qqq"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
"""List of filenames to search for in the model directory."""
return
[
"quant_config.json"
,
"quantize_config.json"
,
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"QQQConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
weight_bits
,
group_size
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QQQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
QQQLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
QQQLinearMethod
(
LinearMethodBase
):
"""Linear method for QQQ.
Args:
quant_config: The QQQ quantization config.
"""
def
__init__
(
self
,
quant_config
:
QQQConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
if
output_size_per_partition
%
self
.
quant_config
.
min_n_threads
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
"min_n_threads =
{
self
.
quant_config
.
min_n_threads
}
."
)
if
output_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
f
"
{
output_size_per_partition
}
is not divisible by "
f
"pack_factor =
{
self
.
quant_config
.
pack_factor
}
."
)
# Validate input_size_per_partition
if
input_size_per_partition
%
self
.
quant_config
.
min_k_threads
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"min_k_threads =
{
self
.
quant_config
.
min_k_threads
}
."
)
if
(
self
.
quant_config
.
group_size
!=
-
1
and
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible by "
f
"group_size =
{
self
.
quant_config
.
group_size
}
."
)
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm
=
self
.
quant_config
.
perm_len
//
(
self
.
quant_config
.
tile_size
**
2
)
if
output_size_per_partition
%
num_tiles_per_perm
!=
0
:
raise
ValueError
(
"Each permutation group must reside on the same gpu"
)
# Quantized 4Bit weights packed into Int32.
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
tile_size
,
output_size_per_partition
*
self
.
quant_config
.
tile_size
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"marlin_tile_size"
:
self
.
quant_config
.
tile_size
,
},
)
s_channel
=
Parameter
(
torch
.
empty
(
1
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
float
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
s_channel
,
{
"input_dim"
:
None
,
"output_dim"
:
1
,
},
)
if
self
.
quant_config
.
group_size
==
-
1
:
s_group
=
Parameter
(
torch
.
tensor
(
[],
device
=
"cuda"
,
dtype
=
torch
.
half
,
),
requires_grad
=
False
,
)
else
:
s_group
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
half
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
s_group
,
{
"input_dim"
:
None
if
self
.
quant_config
.
group_size
==
-
1
else
0
,
"output_dim"
:
None
if
self
.
quant_config
.
group_size
==
-
1
else
1
,
},
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_n_threads
)
*
self
.
quant_config
.
max_parallel
workspace
=
Parameter
(
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
requires_grad
=
False
)
layer
.
register_parameter
(
"B"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s_channel"
,
s_channel
)
set_weight_attrs
(
s_channel
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s_group"
,
s_group
)
set_weight_attrs
(
s_group
,
extra_weight_attrs
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
qweight
=
layer
.
B
s_ch
=
layer
.
s_channel
s_group
=
layer
.
s_group
workspace
=
layer
.
workspace
x_2d
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
size_m
=
x_2d
.
shape
[
0
]
size_k
=
x_2d
.
shape
[
1
]
size_n
=
s_ch
.
shape
[
1
]
x_int8
,
s_tok
=
ops
.
scaled_int8_quant
(
x_2d
)
output_2d
=
ops
.
marlin_qqq_gemm
(
x_int8
,
qweight
,
s_tok
,
s_ch
,
s_group
,
workspace
,
size_m
,
size_n
,
size_k
)
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
e661d594
...
...
@@ -5,6 +5,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
.quant_utils
import
pack_cols
,
unpack_cols
...
...
@@ -13,80 +14,78 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT
=
True
def
_check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
Optional
[
int
],
has_zp
:
bool
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
min_capability
is
not
None
:
# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def
query_marlin_supported_quant_types
(
has_zp
:
bool
,
min_capability
:
Optional
[
int
]
=
None
):
if
min_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
min_capability
:
return
(
False
,
"Marlin does not support device_capability = {}"
", the min_capability required is {}"
.
format
(
device_capability
,
min_capability
))
if
num_bits
not
in
MARLIN_SUPPORTED_NUM_BITS
:
return
(
False
,
"Marlin does not support weight_bits = {}. "
"Only weight_bits = {} are supported."
.
format
(
num_bits
,
MARLIN_SUPPORTED_NUM_BITS
))
if
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
:
return
(
False
,
"Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported."
.
format
(
group_size
,
MARLIN_SUPPORTED_GROUP_SIZES
))
if
not
has_zp
and
not
is_sym
:
return
(
False
,
"Marlin without zero_points must have symmetric quantization"
)
min_capability
=
major
*
10
+
minor
return
True
,
None
if
min_capability
<
80
:
return
[]
if
has_zp
:
# AWQ style, unsigned + runtime zero-point
return
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
else
:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
def
check_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
int
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
num_bits
,
group_size
,
is_sym
,
min_capability
,
has_zp
=
False
)
return
cond
def
_check_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
Optional
[
int
],
has_zp
:
bool
,
min_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
def
check_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
min_capability
:
int
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
num_bits
,
group_size
,
False
,
min_capability
,
has_zp
=
has_zp
)
return
cond
if
min_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
min_capability
=
major
*
10
+
minor
supported_types
=
query_marlin_supported_quant_types
(
has_zp
,
min_capability
)
def
verify_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
group_size
,
is_sym
,
min_capability
=
None
,
has_zp
=
False
)
if
not
cond
:
assert
err_msg
is
not
None
raise
ValueError
(
"GPTQ"
+
err_msg
)
if
quant_type
not
in
supported_types
:
return
(
False
,
f
"Marlin does not support weight_bits =
{
quant_type
}
. "
f
"Only types =
{
supported_types
}
"
f
"are supported (for group_size =
{
group_size
}
, "
f
"min_capability =
{
min_capability
}
, zp =
{
has_zp
}
)."
)
if
(
group_size
is
None
or
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
):
return
(
False
,
f
"Marlin does not support group_size =
{
group_size
}
. "
f
"Only group_sizes =
{
MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
return
True
,
None
def
check_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
int
,
has_zp
:
bool
=
False
,
min_capability
:
Optional
[
int
]
=
None
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
quant_type
,
group_size
,
has_zp
,
min_capability
)
return
cond
def
verify_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
group_size
,
False
,
min_capability
=
None
,
has_zp
=
has_zp
)
def
verify_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
int
,
has_zp
:
bool
=
False
)
->
None
:
cond
,
err_msg
=
_check_marlin_supported
(
quant_type
,
group_size
,
has_zp
)
if
not
cond
:
assert
err_msg
is
not
None
raise
ValueError
(
"AWQ"
+
err_msg
)
raise
ValueError
(
err_msg
)
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
...
...
@@ -240,11 +239,12 @@ def apply_gptq_marlin_linear(
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
wtype
:
ScalarType
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
...
...
@@ -255,12 +255,13 @@ def apply_gptq_marlin_linear(
g_idx
,
g_idx_sort_indices
,
workspace
,
num_bits
,
wtype
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
has_zp
=
False
)
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
...
...
@@ -276,10 +277,11 @@ def apply_awq_marlin_linear(
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
quant_type
:
ScalarType
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
...
...
@@ -290,12 +292,13 @@ def apply_awq_marlin_linear(
g_idx
,
g_idx_sort_indices
,
workspace
,
num_bits
,
quant_type
,
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
True
,
has_zp
=
True
)
has_zp
=
True
,
use_fp32_reduce
=
use_fp32_reduce
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
e661d594
...
...
@@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
return
output
.
reshape
(
out_shape
)
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
)
->
None
:
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
strategy
:
str
=
"tensor"
)
->
None
:
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
...
...
@@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
is_channelwise
=
(
len
(
layer
.
weight_scale
.
shape
)
>
0
and
layer
.
weight_scale
.
shape
[
0
]
==
part_size_n
)
if
is_channelwise
:
scales
=
layer
.
weight_scale
else
:
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
)
scales
=
scales
.
to
(
layer
.
orig_dtype
).
to
(
device
)
scales
=
layer
.
weight_scale
.
to
(
layer
.
orig_dtype
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
e661d594
...
...
@@ -5,10 +5,12 @@ from typing import List
import
numpy
as
np
import
torch
from
vllm.scalar_type
import
ScalarType
from
.marlin_utils
import
(
GPTQ_MARLIN_TILE
,
marlin_permute_scales
,
marlin_zero_points
)
from
.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
quantize_weights
_with_zp
,
sort_weights
)
from
.quant_utils
import
(
get_pack_factor
,
gptq_
quantize_weights
,
quantize_weights
,
sort_weights
)
class
MarlinWorkspace
:
...
...
@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
return
perm
def
marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
def
marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
size_k
,
size_n
=
w
.
shape
num_bits
=
quant_type
.
size_bits
# Normalize group_size
if
group_size
==
-
1
:
...
...
@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
assert
group_size
<=
size_k
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
act_order
)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_
quantize_weights
(
w
,
quant_type
,
group_size
,
act_order
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
...
...
@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
return
res_list
def
awq_marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
def
awq_marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
...
...
@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
num_groups
=
size_k
//
group_size
# Quantize with zp
w_ref
,
q_w
,
s
,
zp
=
quantize_weights_with_zp
(
w
,
num_bits
,
group_size
)
w_ref
,
q_w
,
s
,
zp
=
quantize_weights
(
w
,
quant_type
,
group_size
,
zero_points
=
True
)
# Reformat to marlin
weight_perm
=
get_weight_perm
(
num_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
)
weight_perm
=
get_weight_perm
(
quant_type
.
size_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
)
marlin_zp
=
marlin_zero_points
(
zp
,
num_groups
,
size_n
,
num_bits
)
marlin_zp
=
marlin_zero_points
(
zp
,
num_groups
,
size_n
,
quant_type
.
size_bits
)
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
]
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
View file @
e661d594
...
...
@@ -6,8 +6,10 @@ from typing import List
import
numpy
import
torch
from
vllm.scalar_type
import
ScalarType
from
.marlin_utils_test
import
marlin_weights
from
.quant_utils
import
quantize_weights
from
.quant_utils
import
gptq_
quantize_weights
# This is PyTorch implementation of main part of reorder_meta()
...
...
@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
wtype
:
ScalarType
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Remove bias to normalize over 0
q_24_no_zp
=
q_24
-
wtype
.
bias
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
...
...
@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore
zp
q_24_comp
=
q_24_no_zp_comp
+
zp
# Restore
bias
q_24_comp
=
q_24_no_zp_comp
+
wtype
.
bias
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
...
...
@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
quant_type
:
ScalarType
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
...
...
@@ -441,20 +441,18 @@ def marlin_24_quantize(
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
num_bits
,
group_size
,
act_order
=
False
)
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
w_24
,
quant_type
,
group_size
,
act_order
=
False
)
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
quant_type
)
size_k_comp
=
size_k
//
2
# Reformat to marlin
weight_perm
=
get_weight_perm_24
(
num
_bits
)
weight_perm
=
get_weight_perm_24
(
quant_type
.
size
_bits
)
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num
_bits
,
weight_perm
)
quant_type
.
size
_bits
,
weight_perm
)
marlin_24_s
=
marlin_permute_scales_24
(
s
,
size_k
,
size_n
,
group_size
)
# Create result
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
0 → 100644
View file @
e661d594
from
typing
import
List
import
numpy
import
torch
from
.marlin_utils_test
import
marlin_permute_weights
from
.quant_utils
import
get_pack_factor
,
qqq_quantize_weights
def
marlin_qqq_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
perm
,
group_size
):
# Permute
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
)
# Pack
pack_factor
=
get_pack_factor
(
num_bits
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
numpy
.
uint32
)
if
group_size
==
size_k
:
for
i
in
range
(
pack_factor
):
q_packed
|=
(
q_w
[:,
i
::
pack_factor
]
&
0xF
)
<<
num_bits
*
i
else
:
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
return
q_packed
def
get_qqq_scale_perms
():
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def
get_qqq_weight_perm
(
num_bits
:
int
,
quant_type
:
str
):
perm_list
:
List
[
int
]
=
[]
for
i
in
range
(
32
):
perm1
:
List
[
int
]
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
4
*
(
i
%
4
),
4
*
(
i
%
4
)
+
1
,
4
*
(
i
%
4
)
+
2
,
4
*
(
i
%
4
)
+
3
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
assert
quant_type
in
[
"per-channel"
,
"per-group"
],
"not supported quantization type"
if
num_bits
==
4
:
if
quant_type
==
"per-channel"
:
interleave
=
numpy
.
array
([
4
,
0
,
5
,
1
,
6
,
2
,
7
,
3
])
else
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
else
:
raise
Exception
(
"num_bits must be 4, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
return
perm
def
marlin_qqq_permute_scales
(
s_group
,
s_channel
,
size_k
,
size_n
,
group_size
):
scale_perm
,
scale_perm_single
=
get_qqq_scale_perms
()
if
group_size
<
size_k
and
group_size
!=
-
1
:
s_group
=
s_group
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
s_channel
=
s_channel
.
reshape
(
(
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s_group
=
s_group
.
reshape
((
-
1
,
size_n
)).
contiguous
()
else
:
s_channel
=
s_channel
.
reshape
(
(
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s_channel
=
s_channel
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s_group
,
s_channel
def
marlin_qqq_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
):
size_k
,
size_n
=
w
.
shape
# Normalize group_size
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
quant_type
=
"per-channel"
if
group_size
==
size_k
else
"per-group"
# Quantize
w_ref
,
q_w
,
s_group
,
s_channel
=
qqq_quantize_weights
(
w
,
num_bits
,
group_size
)
# Reformat to marlin_qqq
weight_perm
=
get_qqq_weight_perm
(
num_bits
,
quant_type
)
marlin_qqq_q_w
=
marlin_qqq_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
,
group_size
)
marlin_qqq_s_group
,
marlin_qqq_s_channel
=
marlin_qqq_permute_scales
(
s_group
,
s_channel
,
size_k
,
size_n
,
group_size
)
# Create result
res_list
=
[
w_ref
,
marlin_qqq_q_w
,
marlin_qqq_s_group
,
marlin_qqq_s_channel
]
for
i
in
range
(
len
(
res_list
)):
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
return
res_list
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
e661d594
"""This file is used for /tests and /benchmarks"""
from
typing
import
List
import
numpy
import
torch
SUPPORTED_NUM_BITS
=
[
4
,
8
]
from
vllm.model_executor.layers.quantization.qqq
import
(
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
SUPPORTED_GPTQ_QUANT_TYPES
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
def
is_layer_skipped
(
prefix
:
str
,
ignored_layers
:
List
[
str
])
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
:
shard_prefixes
=
[
prefix
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
]
is_skipped
=
None
for
shard_prefix
in
shard_prefixes
:
is_shard_skipped
=
shard_prefix
in
ignored_layers
if
is_skipped
is
None
:
is_skipped
=
is_shard_skipped
elif
is_shard_skipped
!=
is_skipped
:
raise
ValueError
(
f
"Detected some but not all shards of
{
prefix
}
"
"are quantized. All shards of fused layers "
"to have the same precision."
)
else
:
is_skipped
=
prefix
in
ignored_layers
assert
is_skipped
is
not
None
return
is_skipped
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
32
%
num_bits
==
0
,
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
...
...
@@ -36,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
)
def
quantize_weights
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
act_order
:
bool
):
def
quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
zero_points
:
bool
=
False
):
assert
quant_type
.
is_integer
(),
\
"Floating point quantization may work but has not been tested"
orig_device
=
w
.
device
orig_type
=
w
.
dtype
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
half_q_val
=
(
max_q_val
+
1
)
//
2
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
...
...
@@ -61,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
s
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s
*=
2
/
max_q_val
# 2 => symmetric
max_val
=
torch
.
max
(
w
,
0
,
keepdim
=
True
).
values
min_val
=
torch
.
min
(
w
,
0
,
keepdim
=
True
).
values
max_q_val
=
quant_type
.
max
()
min_q_val
=
quant_type
.
min
()
if
zero_points
:
assert
not
quant_type
.
is_signed
()
and
quant_type
.
max
()
>
0
w_s
=
(
max_val
-
min_val
).
clamp
(
min
=
1e-5
)
/
quant_type
.
max
()
maybe_w_zp
=
torch
.
round
(
torch
.
abs
(
min_val
/
w_s
))
\
.
clamp
(
min_q_val
,
max_q_val
).
int
()
else
:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s
=
torch
.
max
(
abs
(
max_val
/
(
max_q_val
if
max_q_val
!=
0
else
torch
.
inf
)),
abs
(
min_val
/
(
min_q_val
if
min_q_val
!=
0
else
torch
.
inf
)))
maybe_w_zp
=
None
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
q_w
+=
half_q_val
q_w
=
torch
.
clamp
(
q_w
,
0
,
max_q_val
)
w_q
=
torch
.
round
(
w
/
w_s
).
int
()
+
(
maybe_w_zp
if
zero_points
else
0
)
w_q
=
torch
.
clamp
(
w_q
,
min_q_val
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
half_q_val
).
half
()
*
s
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
if
quant_type
.
has_bias
():
w_q
+=
quant_type
.
bias
# Restore original shapes
if
group_size
<
size_k
:
...
...
@@ -81,10 +140,35 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
return
w
q_w
=
reshape_w
(
q_w
)
w_q
=
reshape_w
(
w_q
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
w_s
=
w_s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
if
zero_points
:
maybe_w_zp
=
maybe_w_zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
maybe_w_zp
=
maybe_w_zp
.
to
(
device
=
orig_device
)
return
(
w_ref
.
to
(
device
=
orig_device
),
w_q
.
to
(
device
=
orig_device
),
w_s
.
to
(
device
=
orig_device
),
maybe_w_zp
,
)
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
size_k
,
_
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
quant_type
in
SUPPORTED_GPTQ_QUANT_TYPES
,
\
f
"Unsupported gptq type =
{
quant_type
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
w_ref
,
w_q
,
w_s
,
_
=
quantize_weights
(
w
,
quant_type
,
group_size
)
# Apply act_order
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
...
...
@@ -95,23 +179,20 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
w_ref
,
q_w
,
g_idx
,
rand_perm
=
permute_rows
(
q_w
,
w_ref
,
group_size
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
)
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
rand_perm
.
to
(
device
=
orig_device
),
)
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
def
quantize_weights_with_zp
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def
qqq_quantize_weights
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
orig_device
=
w
.
device
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
num_bits
in
MARLIN_QQQ_SUPPORTED_NUM_BITS
,
\
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
...
...
@@ -120,33 +201,27 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
group_size
=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
min_q_val
=
0
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
# Reshape to [groupsize, -1]
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
max
=
torch
.
max
(
w
,
0
,
keepdim
=
True
)[
0
]
min
=
torch
.
min
(
w
,
0
,
keepdim
=
True
)[
0
]
s
=
(
max
-
min
).
clamp
(
min
=
1e-5
)
/
max_q_val
# Compute zero-point for each group
zp
=
(
-
torch
.
round
(
min
/
s
)).
clamp
(
min_q_val
,
max_q_val
).
int
()
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
+
zp
q_w
=
torch
.
clamp
(
q_w
,
min_q_val
,
max_q_val
)
max_q_val
=
2
**
num_bits
-
1
half_q_val
=
(
max_q_val
+
1
)
//
2
# Compute ref (dequantized)
w_ref
=
(
q_w
-
zp
).
half
()
*
s
# Compute scale for each group
s_group
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s_group
*=
2
/
max_q_val
# 2 => symmetric
# Restore original shapes
if
group_size
<
size_k
:
# Quantize
q_w
=
torch
.
round
(
w
/
s_group
).
int
()
q_w
+=
half_q_val
q_w
=
torch
.
clamp
(
q_w
,
0
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
half_q_val
).
half
()
*
s_group
# Restore original shapes
def
reshape_w
(
w
):
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
...
...
@@ -156,14 +231,39 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
q_w
=
reshape_w
(
q_w
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
zp
=
zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
# Compute int8 quantization scale for each channel
s_channel
=
torch
.
max
(
torch
.
abs
(
w_ref
),
0
,
keepdim
=
True
)[
0
]
s_channel
/=
127.0
t_int8
=
(
w_ref
/
s_channel
).
round
().
clamp
(
-
128
,
127
).
to
(
torch
.
int8
)
w_ref
=
t_int8
.
half
()
*
s_channel
s_channel
=
s_channel
.
reshape
(
1
,
-
1
).
to
(
dtype
=
torch
.
float
)
# Fuse scales
s_group
=
(
s_group
.
reshape
(
-
1
,
size_n
).
contiguous
()
/
s_channel
).
to
(
dtype
=
torch
.
half
)
else
:
max_q_val
=
2
**
(
num_bits
-
1
)
-
1
# Compute scale for each channel
s_channel
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s_channel
/=
max_q_val
# Quantize
q_w
=
torch
.
round
(
w
/
s_channel
).
int
()
q_w
=
torch
.
clamp
(
q_w
,
-
max_q_val
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
q_w
.
half
()
*
s_channel
s_group
=
torch
.
tensor
([],
dtype
=
torch
.
half
)
# div 2 ** (8 - self.bits)) to offset right shift in unpacking
s_channel
/=
(
2
**
(
8
-
num_bits
))
s_channel
=
s_channel
.
reshape
(
-
1
,
size_n
).
contiguous
().
to
(
torch
.
float
)
return
(
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
zp
.
to
(
device
=
orig_device
),
s
_group
.
to
(
device
=
orig_device
),
s_channel
.
to
(
device
=
orig_device
),
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
e661d594
...
...
@@ -139,7 +139,7 @@ def apply_fp8_linear(
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input_scale
,
batch_dim
_padding
=
17
,
num_token
_padding
=
17
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
...
...
@@ -177,8 +177,9 @@ def apply_fp8_linear(
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
torch
.
float32
)
# Unpad (undo
batch_dim
_padding)
# Unpad (undo
num_token
_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input
.
shape
[
0
])
# DQ
# C = sw * sx * (X * W) + bias
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
e661d594
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.jit
...
...
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
...
...
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
probabilities.
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
...
...
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
,
draft_probs
,
draft_token_ids
,
generator
s
,
seeded_seq
s
,
))
output_token_ids
=
self
.
_create_output
(
...
...
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform modified rejection sampling on each sequence.
...
...
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# shape [batch_size, k]
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
draft_token_ids
,
generator
s
)
draft_token_ids
,
seeded_seq
s
)
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
generators
,
k
=
k
)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids
=
_multinomial
(
recovered_probs
,
num_samples
=
1
,
k
=
k
,
generators
=
generators
,
seed_indices
=
seed_indices
,
# this arg is unused when None but torch.jit requires a list
non_seed_indices
=
non_seed_indices
or
[],
seeded_seqs
=
seeded_seqs
or
{},
).
reshape
(
batch_size
,
k
)
return
accepted
,
recovered_token_ids
...
...
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
torch
.
Tensor
:
r
"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
...
...
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
generators
)
if
len
(
seed_indices
)
==
0
:
if
not
seeded_seqs
:
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
else
:
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
for
idx
in
seed_indices
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
generator
=
generators
[
idx
])
if
non_seed_indices
:
uniform_rand
[
non_seed_indices
,
:]
=
torch
.
rand
(
len
(
non_seed_indices
),
non_seeded_indices
=
[]
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
...
...
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@
staticmethod
def
_split_batch_by_seeded
(
generators
:
List
[
Optional
[
torch
.
Generator
]],
k
:
int
=
1
,
)
->
Tuple
[
List
[
int
],
Optional
[
List
[
int
]]]:
if
all
(
generator
is
None
for
generator
in
generators
):
seed_indices
:
List
[
int
]
=
[]
non_seed_indices
:
Optional
[
List
[
int
]]
=
None
else
:
seed_indices
,
non_seed_indices
=
[],
[]
for
i
,
generator
in
enumerate
(
generators
):
if
generator
is
None
:
non_seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
else
:
seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
return
seed_indices
,
non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
...
...
@@ -304,9 +282,7 @@ def _multinomial(
probs
:
torch
.
Tensor
,
num_samples
:
int
,
k
:
int
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seed_indices
:
List
[
int
],
non_seed_indices
:
List
[
int
],
seeded_seqs
:
Dict
[
int
,
torch
.
Generator
],
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
...
...
@@ -315,13 +291,20 @@ def _multinomial(
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
q
=
torch
.
empty_like
(
probs
)
if
len
(
seed_indices
)
==
0
:
if
not
seeded_seqs
:
q
.
exponential_
(
1.0
)
else
:
q
[
non_seed_indices
].
exponential_
(
1.0
)
for
idx
in
seed_indices
:
q
[
idx
].
exponential_
(
1.0
,
generator
=
generators
[
idx
//
k
])
non_seeded_indices
:
List
[
int
]
=
[]
start
=
0
for
idx
in
range
(
len
(
q
)
//
k
):
end
=
start
+
k
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
extend
(
list
(
range
(
start
,
end
)))
else
:
q
[
start
:
end
].
exponential_
(
1.0
,
generator
=
generator
)
start
=
end
q
[
non_seeded_indices
].
exponential_
(
1.0
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
vllm/model_executor/layers/rotary_embedding.py
View file @
e661d594
...
...
@@ -774,6 +774,7 @@ def get_rope(
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
rotary_percent
:
float
=
1.0
,
)
->
RotaryEmbedding
:
if
dtype
is
None
:
dtype
=
torch
.
get_default_dtype
()
...
...
@@ -786,6 +787,8 @@ def get_rope(
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
rope_scaling_args
=
None
if
rotary_percent
<
1.0
:
rotary_dim
=
int
(
rotary_dim
*
rotary_percent
)
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
...
...
vllm/model_executor/layers/sampler.py
View file @
e661d594
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
,
SequenceGroupToSample
)
...
...
@@ -220,7 +225,7 @@ def _apply_min_tokens_penalty(
seqs_to_penalize
:
List
[
int
]
=
[]
for
j
,
seq_id
in
enumerate
(
seq_ids
):
seq_data
=
seq_group
.
seq_data
[
seq_id
]
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
if
len
(
seq_data
.
output_token_ids
_array
)
<
min_tokens
:
seqs_to_penalize
.
append
(
j
)
if
seqs_to_penalize
:
...
...
@@ -774,8 +779,11 @@ def _get_logprobs(
# The next token ids to get the logprob value from.
next_token_ids
:
List
[
int
]
=
[]
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs
=
1
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
largest_num_logprobs
=
-
1
# If beam search is enabled.
use_beam_search
=
False
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
...
...
@@ -808,6 +816,8 @@ def _get_logprobs(
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
logprobs
)
use_beam_search
=
use_beam_search
or
sampling_params
.
use_beam_search
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
if
len
(
query_indices
)
==
0
:
...
...
@@ -815,35 +825,40 @@ def _get_logprobs(
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs
=
logprobs
[[
query_indices_gpu
,
next_token_ids_gpu
,
]]
ranks
=
_get_ranks
(
logprobs
[
query_indices_gpu
],
next_token_ids_gpu
,
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
if
largest_num_logprobs
>
0
:
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
dim
=-
1
)
else
:
top_logprobs
,
top_token_ids
=
None
,
None
selected_logprobs
,
ranks
=
None
,
None
top_logprobs
,
top_token_ids
=
None
,
None
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation.
if
largest_num_logprobs
>=
0
or
use_beam_search
:
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs
=
logprobs
[[
query_indices_gpu
,
next_token_ids_gpu
,
]]
ranks
=
_get_ranks
(
logprobs
[
query_indices_gpu
],
next_token_ids_gpu
,
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
ranks
=
ranks
.
to
(
'cpu'
)
if
top_logprobs
is
not
None
and
top_token_ids
is
not
None
:
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
# We need to compute top k only if there exists logprobs > 0.
if
largest_num_logprobs
>
0
:
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
dim
=-
1
)
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
ranks
=
ranks
.
to
(
'cpu'
)
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
...
...
@@ -940,46 +955,53 @@ def _get_sampled_logprob_if_needed(
):
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
or
0
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
use_beam_search
=
seq_group
.
sampling_params
.
use_beam_search
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items
=
selected_logprobs
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_token_ids
)].
tolist
()
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_token_ids
)].
tolist
()
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
# Get the logprob of a sampled token.
sampled_logprobs_dict
=
{
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
}
# Get top K logprobs.
if
num_logprobs
>
0
:
top_ids
=
top_token_ids
[
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
top_probs
=
top_logprobs
[
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
sampled_logprobs_dict
.
update
({
top_id
:
(
top_prob
,
rank
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
if
num_logprobs
is
None
and
not
use_beam_search
:
for
next_token_id
in
next_token_ids
:
# Use a dummy logprob
sampled_logprobs
.
append
({
next_token_id
:
Logprob
(
inf
)})
else
:
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items
=
selected_logprobs
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_token_ids
)].
tolist
()
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
len
(
next_token_ids
)].
tolist
()
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
# Get the logprob of a sampled token.
sampled_logprobs_dict
=
{
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
}
if
num_logprobs
is
not
None
and
num_logprobs
>
0
:
# Get top K logprobs.
top_ids
=
top_token_ids
[
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
top_probs
=
top_logprobs
[
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
sampled_logprobs_dict
.
update
({
top_id
:
(
top_prob
,
rank
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
})
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
e661d594
from
abc
import
abstractmethod
from
typing
import
Lis
t
,
Optional
from
typing
import
Dic
t
,
Optional
import
torch
import
torch.jit
...
...
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/model_executor/model_loader/loader.py
View file @
e661d594
...
...
@@ -7,6 +7,7 @@ import json
import
math
import
os
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
huggingface_hub
...
...
@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_tpu
from
vllm.utils
import
is_pin_memory_available
,
is_tpu
@
contextmanager
def
device_loading_context
(
module
:
torch
.
nn
.
Module
,
target_device
:
torch
.
device
):
if
target_device
.
type
==
"cpu"
:
# If target is CPU, no need to move anything
yield
module
return
original_device_states
:
Dict
[
str
,
torch
.
device
]
=
{}
# Store original device states and move parameters to GPU if they're on CPU
for
name
,
p
in
module
.
named_parameters
():
if
p
.
device
.
type
==
"cpu"
:
original_device_states
[
name
]
=
p
.
device
p
.
data
=
p
.
data
.
to
(
target_device
)
# Parameters already on target device are not touched
try
:
yield
module
finally
:
# Restore parameters to their original devices, ignoring new parameters
pin_memory
=
is_pin_memory_available
()
for
name
,
p
in
module
.
named_parameters
():
if
name
in
original_device_states
:
original_device
:
torch
.
device
=
original_device_states
[
name
]
if
original_device
.
type
==
"cpu"
:
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
stride
=
p
.
data
.
stride
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
else
:
p
.
data
=
p
.
data
.
to
(
original_device
)
# New parameters or parameters already on target device are untouched
logger
=
init_logger
(
__name__
)
...
...
@@ -164,7 +207,7 @@ class DefaultModelLoader(BaseModelLoader):
cache_dir
=
self
.
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
ignore_pattern
s
=
self
.
load_config
.
ignore_patterns
,
ignore_
file_
pattern
=
self
.
load_config
.
ignore_patterns
,
)
else
:
model_path
=
model
...
...
@@ -278,8 +321,9 @@ class DefaultModelLoader(BaseModelLoader):
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
t
orch
.
device
(
device_config
.
device
)
:
with
t
arget_
device
:
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
cache_config
,
scheduler_config
)
...
...
@@ -294,7 +338,13 @@ class DefaultModelLoader(BaseModelLoader):
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
and
quant_method
!=
"awq"
and
quant_method
!=
"gptq"
:
quant_method
.
process_weights_after_loading
(
module
)
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
return
model
.
eval
()
...
...
@@ -705,8 +755,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
hf_weights_files
,
matched_pattern
==
"*.safetensors"
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
if
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
else
:
return
pt_weights_iterator
(
hf_weights_files
)
def
_get_quantized_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
,
pre_quant
:
bool
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
Any
]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
...
...
@@ -715,6 +771,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
try
:
import
bitsandbytes
from
bitsandbytes.functional
import
QuantState
if
bitsandbytes
.
__version__
<
"0.42.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0."
)
...
...
@@ -728,17 +785,63 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path
,
revision
)
quant_state_dict
=
{}
if
use_safetensors
:
weight_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weight_iterator
=
pt_weights_iterator
(
hf_weights_files
)
def
generator
():
def
quantized_checkpoint
()
->
Generator
:
# First iterate over all quant state weights
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
continue
# TODO: only nf4 quantization is supported for now
if
weight_name
.
endswith
(
".quant_state.bitsandbytes__fp4"
):
raise
NotImplementedError
(
"Only bitsandbytes_nf4 quantization"
f
"is supported for now.
{
weight_name
}
is fp4 quantized"
)
temp_state_dict
[
weight_name
]
=
weight_tensor
# Closure to parse quant_state for each prequant weight
def
_parse_quant_state
(
param_name
:
str
,
temp_state_dict
:
Dict
)
->
QuantState
:
quant_state
=
{}
for
k
in
temp_state_dict
:
if
param_name
+
"."
in
k
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
]
=
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
].
cpu
().
data
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
continue
if
weight_name
+
".quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
:
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
),
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
generator
()
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
#
bitsandbytes requires data in GPU
# bitsandbytes requires data in GPU
loaded_weight
=
weight_tensor
.
cuda
().
data
with
set_default_torch_dtype
(
torch
.
float32
):
processed_weight
,
quant_state
=
quantize_4bit
(
...
...
@@ -752,6 +855,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield
weight_name
,
processed_weight
if
pre_quant
:
return
quantized_checkpoint
(),
quant_state_dict
return
generator
(),
quant_state_dict
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
...
...
@@ -769,12 +874,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
" May take a while ..."
)
qweight_iterator
,
quant_state_dict
=
(
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
))
is_quantized_checkpoint
=
False
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
quant_config
is
not
None
and
quant_config
.
get
(
'quant_method'
)
==
"bitsandbytes"
:
is_quantized_checkpoint
=
True
qweight_iterator
,
quant_state_dict
=
\
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
is_quantized_checkpoint
)
model
.
load_weights
(
qweight_iterator
)
torch
.
cuda
.
empty_cache
()
param_dict
=
dict
(
model
.
named_parameters
())
stacked_quant_state_dict
:
Dict
[
str
,
Dict
[
int
,
Any
]]
=
{}
for
quant_param_name
in
quant_state_dict
:
...
...
@@ -812,9 +926,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f
"pack_factor not set for parameter
{
param_name
}
."
)
num_elements
=
[
0
]
*
len
(
quant_states
)
for
seq
,
quant_state
in
enumerate
(
quant_states
.
items
()
)
:
for
seq
,
quant_state
in
quant_states
.
items
():
num_elements
[
seq
]
=
math
.
prod
(
quant_state
[
1
]
.
shape
)
//
pack_ratio
quant_state
.
shape
)
//
pack_ratio
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
e661d594
...
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
...
...
@@ -118,6 +119,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def
get_quant_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
...
...
@@ -489,6 +491,11 @@ def initialize_dummy_weights(
"""
for
param
in
model
.
state_dict
().
values
():
if
torch
.
is_floating_point
(
param
):
if
current_platform
.
is_tpu
():
# XLA device does not support torch.Generator()
param
.
uniform_
(
low
,
high
)
continue
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
generator
.
manual_seed
(
seed
)
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment