Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1043 additions
and
396 deletions
+1043
-396
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+5
-9
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+14
-44
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+35
-15
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+29
-17
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+19
-10
vllm/model_executor/layers/quantization/kv_cache.py
vllm/model_executor/layers/quantization/kv_cache.py
+2
-4
vllm/model_executor/layers/quantization/qqq.py
vllm/model_executor/layers/quantization/qqq.py
+285
-0
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+70
-67
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+3
-11
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+19
-10
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
...xecutor/layers/quantization/utils/marlin_utils_test_24.py
+14
-16
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
...ecutor/layers/quantization/utils/marlin_utils_test_qqq.py
+125
-0
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+152
-52
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+3
-2
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+39
-56
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+3
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+87
-65
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+2
-2
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+130
-16
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+7
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
e661d594
...
@@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
...
@@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
from
torch.nn
import
Module
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
FUSED_LAYER_NAME_MAPPING
)
class
CompressionFormat
(
Enum
):
class
CompressionFormat
(
Enum
):
dense
=
"dense"
dense
=
"dense"
...
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
...
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
def
should_ignore_layer
(
layer_name
:
Optional
[
str
],
def
should_ignore_layer
(
layer_name
:
Optional
[
str
],
ignore
:
Iterable
[
str
])
->
bool
:
ignore
:
Iterable
[
str
])
->
bool
:
if
layer_name
is
None
:
if
layer_name
is
None
:
...
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
...
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
# each shard of the fused layer has the same scheme.
if
proj_name
in
_
FUSED_LAYER_NAME_MAPPING
:
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
:
shard_proj_names
=
_
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
shard_proj_names
=
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
# Convert fused_name --> [shard_names]
# Convert fused_name --> [shard_names]
shard_names
=
[
shard_names
=
[
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
e661d594
...
@@ -9,8 +9,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -9,8 +9,11 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.fp8
import
cutlass_fp8_supported
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
apply_fp8_linear
,
create_per_channel_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -18,14 +21,6 @@ from vllm.platforms import current_platform
...
@@ -18,14 +21,6 @@ from vllm.platforms import current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
class
FBGEMMFp8Config
(
QuantizationConfig
):
class
FBGEMMFp8Config
(
QuantizationConfig
):
"""Config class for FBGEMM Fp8."""
"""Config class for FBGEMM Fp8."""
...
@@ -62,37 +57,10 @@ class FBGEMMFp8Config(QuantizationConfig):
...
@@ -62,37 +57,10 @@ class FBGEMMFp8Config(QuantizationConfig):
input_scale_ub
=
cls
.
get_from_keys
(
config
,
[
"activation_scale_ub"
])
input_scale_ub
=
cls
.
get_from_keys
(
config
,
[
"activation_scale_ub"
])
return
cls
(
ignore_list
=
ignore_list
,
input_scale_ub
=
input_scale_ub
)
return
cls
(
ignore_list
=
ignore_list
,
input_scale_ub
=
input_scale_ub
)
def
_is_layer_skipped
(
self
,
prefix
:
str
)
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
if
proj_name
in
_FUSED_LAYER_NAME_MAPPING
:
shard_prefixes
=
[
prefix
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
_FUSED_LAYER_NAME_MAPPING
[
proj_name
]
]
is_skipped
=
None
for
shard_prefix
in
shard_prefixes
:
is_shard_skipped
=
shard_prefix
in
self
.
ignore_list
if
is_skipped
is
None
:
is_skipped
=
is_shard_skipped
elif
is_shard_skipped
!=
is_skipped
:
raise
ValueError
(
f
"Detected some but not all shards of
{
prefix
}
"
"are quantized. All shards of fused layers "
"to have the same precision."
)
else
:
is_skipped
=
prefix
in
self
.
ignore_list
assert
is_skipped
is
not
None
return
is_skipped
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
self
.
_
is_layer_skipped
(
prefix
):
if
is_layer_skipped
(
prefix
,
self
.
ignore_list
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
return
FBGEMMFp8LinearMethod
(
self
)
return
FBGEMMFp8LinearMethod
(
self
)
return
None
return
None
...
@@ -105,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -105,6 +73,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -172,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -172,11 +141,12 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
bias
=
bias
)
return
apply_fp8_linear
(
input
=
x
,
return
apply_fp8_linear
(
weight
=
layer
.
weight
,
input
=
x
,
weight_scale
=
layer
.
weight_scale
,
weight
=
layer
.
weight
,
input_scale
=
None
,
weight_scale
=
layer
.
weight_scale
,
input_scale_ub
=
layer
.
input_scale_ub
,
input_scale
=
None
,
bias
=
bias
,
input_scale_ub
=
layer
.
input_scale_ub
,
cutlass_fp8_supported
=
True
,
bias
=
bias
,
use_per_token_if_dynamic
=
True
)
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
,
use_per_token_if_dynamic
=
True
)
vllm/model_executor/layers/quantization/fp8.py
View file @
e661d594
...
@@ -6,17 +6,20 @@ from torch.nn.parameter import Parameter
...
@@ -6,17 +6,20 @@ from torch.nn.parameter import Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
fused_moe
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethod
Base
Unquantized
LinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
...
@@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
...
@@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
self
,
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
activation_scheme
:
str
=
"dynamic"
,
ignored_layers
:
Optional
[
List
[
str
]]
=
None
,
)
->
None
:
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
if
is_checkpoint_fp8_serialized
:
...
@@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
...
@@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
raise
ValueError
(
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
self
.
activation_scheme
=
activation_scheme
self
.
ignored_layers
=
ignored_layers
or
[]
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
...
@@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
ignored_layers
=
cls
.
get_from_keys_or
(
config
,
[
"ignored_layers"
],
None
)
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
)
activation_scheme
=
activation_scheme
,
ignored_layers
=
ignored_layers
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
ignored_layers
):
return
UnquantizedLinearMethod
()
return
Fp8LinearMethod
(
self
)
return
Fp8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
return
Fp8MoEMethod
(
self
)
return
Fp8MoEMethod
(
self
)
...
@@ -170,19 +179,29 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -170,19 +179,29 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
None
layer
.
input_scale
=
None
# If checkpoint is fp8,
requantize the separately quantized logical
# If checkpoint is fp8,
handle that there are N scales for N
#
weight
s in
to
a
single fp8 weight with a single weight sca
le
.
#
shard
s in a
fused modu
le
else
:
else
:
# Dequant -> Quant with max scale.
# If using marlin (w8a16), kernel uses channelwise weights,
max_w_scale
,
weight
=
requantize_with_max_scale
(
# so extend the weight scales to be channelwise.
weight
=
layer
.
weight
,
if
self
.
use_marlin
:
weight_scale
=
layer
.
weight_scale
,
weight
=
layer
.
weight
logical_widths
=
layer
.
logical_widths
,
weight_scale
=
convert_to_channelwise
(
layer
.
weight_scale
,
)
layer
.
logical_widths
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else
:
# Dequant -> Quant with max scale so we can run per tensor.
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
# Update layer with new values.
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w
_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight
_scale
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
...
@@ -384,6 +403,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -384,6 +403,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_moe
return
fused_moe
(
x
,
return
fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
layer
.
w2_weight
,
layer
.
w2_weight
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
e661d594
...
@@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -10,11 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
check_
gptq_
marlin_supported
,
marlin_is_k_full
,
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_
gptq_
marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -22,6 +23,12 @@ logger = init_logger(__name__)
...
@@ -22,6 +23,12 @@ logger = init_logger(__name__)
class
GPTQMarlinConfig
(
QuantizationConfig
):
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
"""Config class for GPTQ Marlin"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP
=
{
(
4
,
True
):
scalar_types
.
uint4b8
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
if
desc_act
and
group_size
==
-
1
:
...
@@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -29,20 +36,23 @@ class GPTQMarlinConfig(QuantizationConfig):
# (since we have only one group per output channel)
# (since we have only one group per output channel)
desc_act
=
False
desc_act
=
False
self
.
weight_bits
=
weight_bits
self
.
pack_factor
=
32
//
weight_bits
# packed into int32
self
.
pack_factor
=
32
//
self
.
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
desc_act
=
desc_act
self
.
is_sym
=
is_sym
self
.
lm_head_quantized
=
lm_head_quantized
self
.
lm_head_quantized
=
lm_head_quantized
if
(
weight_bits
,
is_sym
)
not
in
self
.
TYPE_MAP
:
raise
ValueError
(
"Unsupported quantization config: "
f
"bits=
{
weight_bits
}
, sym=
{
is_sym
}
"
)
self
.
quant_type
=
self
.
TYPE_MAP
[(
weight_bits
,
is_sym
)]
# Verify supported on platform.
# Verify supported on platform.
verify_gptq_marlin_supported
(
num_bits
=
self
.
weight_bits
,
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
)
is_sym
=
self
.
is_sym
)
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(
weight_bits=
{
self
.
weight_bits
}
, "
return
(
f
"GPTQMarlinConfig(
quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
...
@@ -79,7 +89,8 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -79,7 +89,8 @@ class GPTQMarlinConfig(QuantizationConfig):
user_quant
)
->
Optional
[
str
]:
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_gptq_marlin_compatible
(
hf_quant_cfg
)
can_convert
=
cls
.
is_gptq_marlin_compatible
(
hf_quant_cfg
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
or
user_quant
==
"gptq_marlin"
)
if
can_convert
and
is_valid_user_quant
:
if
can_convert
and
is_valid_user_quant
:
msg
=
(
"The model is convertible to {} during runtime."
msg
=
(
"The model is convertible to {} during runtime."
...
@@ -121,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -121,11 +132,12 @@ class GPTQMarlinConfig(QuantizationConfig):
or
desc_act
is
None
):
or
desc_act
is
None
):
return
False
return
False
return
check_gptq_marlin_supported
(
if
(
num_bits
,
sym
)
not
in
cls
.
TYPE_MAP
:
num_bits
=
num_bits
,
return
False
group_size
=
group_size
,
is_sym
=
sym
,
return
check_marlin_supported
(
quant_type
=
cls
.
TYPE_MAP
[(
num_bits
,
sym
)],
min_capability
=
cls
.
get_min_capability
())
group_size
=
group_size
,
min_capability
=
cls
.
get_min_capability
())
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
...
@@ -292,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -292,7 +304,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm
=
layer
.
g_idx_sort_indices
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight
_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size
_bits
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
# Permute scales from autogptq format to marlin format.
...
@@ -318,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -318,7 +330,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
wtype
=
self
.
quant_config
.
quant_type
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
,
is_k_full
=
layer
.
is_k_full
,
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
e661d594
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
...
@@ -17,9 +18,10 @@ GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MIN_THREAD_K
=
128
GPTQ_MARLIN_24_MAX_PARALLEL
=
64
GPTQ_MARLIN_24_MAX_PARALLEL
=
64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
GPTQ_MARLIN_24_SUPPORTED_SYM
=
[
True
]
class
GPTQMarlin24Config
(
QuantizationConfig
):
class
GPTQMarlin24Config
(
QuantizationConfig
):
...
@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
...
@@ -31,14 +33,19 @@ class GPTQMarlin24Config(QuantizationConfig):
weight_bits
:
int
,
weight_bits
:
int
,
group_size
:
int
,
group_size
:
int
,
)
->
None
:
)
->
None
:
self
.
weight_bits
=
weight_bits
quant_type
=
{
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
,
}.
get
(
weight_bits
)
self
.
group_size
=
group_size
self
.
group_size
=
group_size
# Verify
# Verify
if
self
.
weight_bits
not
in
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
:
if
quant_type
is
None
or
\
quant_type
not
in
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
:
raise
ValueError
(
raise
ValueError
(
f
"Marlin_24 does not support
weight_bits =
{
self
.
weight_bits
}
. "
f
"Marlin_24 does not support
quant_type =
{
quant_type
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_
NUM_BIT
S
}
"
f
"Only weight_bits =
{
GPTQ_MARLIN_24_SUPPORTED_
QUANT_TYPE
S
}
"
"are supported."
)
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
:
if
self
.
group_size
not
in
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
raise
ValueError
(
...
@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
...
@@ -46,8 +53,10 @@ class GPTQMarlin24Config(QuantizationConfig):
f
"Only group_sizes =
{
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
}
"
f
"Only group_sizes =
{
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
"are supported."
)
self
.
quant_type
=
quant_type
# 4 Bits packed into 32 bit datatype.
# 4 Bits packed into 32 bit datatype.
self
.
pack_factor
=
32
//
self
.
weight
_bits
self
.
pack_factor
=
32
//
self
.
quant_type
.
size
_bits
# Tile size used by marlin kernels.
# Tile size used by marlin kernels.
self
.
tile_size
=
16
self
.
tile_size
=
16
...
@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
...
@@ -66,8 +75,8 @@ class GPTQMarlin24Config(QuantizationConfig):
self
.
perm_len
=
1024
self
.
perm_len
=
1024
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
"Marlin24Config(
weight_bits
={}, group_size={})"
.
format
(
return
"Marlin24Config(
quant_type
={}, group_size={})"
.
format
(
self
.
weight_bits
,
self
.
group_size
)
self
.
quant_type
,
self
.
group_size
)
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
...
@@ -279,7 +288,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
output_2d
=
ops
.
gptq_marlin_24_gemm
(
x_2d
,
qweight
,
meta
,
scales
,
output_2d
=
ops
.
gptq_marlin_24_gemm
(
x_2d
,
qweight
,
meta
,
scales
,
workspace
,
workspace
,
self
.
quant_config
.
weight_bits
,
self
.
quant_config
.
quant_type
,
size_m
,
size_n
,
size_k
)
size_m
,
size_n
,
size_k
)
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
output
=
output_2d
.
view
(
x
.
shape
[:
-
1
]
+
(
output_2d
.
shape
[
1
],
))
...
...
vllm/model_executor/layers/quantization/kv_cache.py
View file @
e661d594
...
@@ -46,10 +46,8 @@ class BaseKVCacheMethod(QuantizeMethodBase):
...
@@ -46,10 +46,8 @@ class BaseKVCacheMethod(QuantizeMethodBase):
elif
layer
.
k_scale
<
0.0
and
layer
.
v_scale
<
0.0
:
elif
layer
.
k_scale
<
0.0
and
layer
.
v_scale
<
0.0
:
# If no scales were loaded (both scales are invalid negative
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
# values), use the default value of 1.0
k_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
),
k_scale
=
1.0
requires_grad
=
False
)
v_scale
=
1.0
v_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
else
:
else
:
# If we find a single kv_scale in the checkpoint, we remap
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# kv_scale to k_scale during weight loading, and duplicate
...
...
vllm/model_executor/layers/quantization/qqq.py
0 → 100644
View file @
e661d594
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
MARLIN_QQQ_TILE
=
16
MARLIN_QQQ_MIN_THREAD_N
=
64
MARLIN_QQQ_MIN_THREAD_K
=
128
MARLIN_QQQ_MAX_PARALLEL
=
16
MARLIN_QQQ_SUPPORTED_NUM_BITS
=
[
4
]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES
=
[
-
1
,
128
]
MARLIN_QQQ_SUPPORTED_SYM
=
[
True
]
class QQQConfig(QuantizationConfig):
    """Quantization settings for QQQ checkpoints.

    Reference: https://arxiv.org/pdf/2406.09904

    Attributes set by the constructor:
        weight_bits / group_size / is_sym: the values passed in.
        pack_factor: number of quantized values per 32-bit word
            (``32 // weight_bits``).
        tile_size, min_n_threads, min_k_threads, max_parallel: copied from
            the module-level ``MARLIN_QQQ_*`` constants.
        perm_len: permutation length used by the QQQ kernels (1024).
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        is_sym: bool = True,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.is_sym = is_sym

        # Verify each setting against the supported lists, in the order
        # weight_bits, group_size, is_sym.
        if weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"QQQ does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} "
                "are supported.")
        if group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"QQQ does not support group_size = {self.group_size}. "
                f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if is_sym not in MARLIN_QQQ_SUPPORTED_SYM:
            raise ValueError(
                f"QQQ does not support is_sym = {self.is_sym}. "
                f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.")

        # Quantized values are packed into a 32 bit datatype.
        self.pack_factor = 32 // weight_bits

        # Tile size used by QQQ kernels.
        self.tile_size = MARLIN_QQQ_TILE

        # Min out_features dim
        self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = MARLIN_QQQ_MAX_PARALLEL

        # Permutation length used by the QQQ kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        template = "QQQConfig(weight_bits={}, group_size={})"
        return template.format(self.weight_bits, self.group_size)

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "qqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts (half only)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        candidates = ["quant_config.json", "quantize_config.json"]
        return candidates

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QQQConfig":
        """Build a config from the ``wbits`` and ``group_size`` entries."""
        return cls(
            cls.get_from_keys(config, ["wbits"]),
            cls.get_from_keys(config, ["group_size"]),
        )

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QQQLinearMethod"]:
        """Return a QQQLinearMethod for LinearBase layers, else None."""
        if not isinstance(layer, LinearBase):
            return None
        return QQQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the scaled activation names (none for QQQ)."""
        return []
class QQQLinearMethod(LinearMethodBase):
    """Linear method for QQQ.

    Args:
        quant_config: The QQQ quantization config.
    """

    def __init__(self, quant_config: QQQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the shard shape and register the QQQ parameters.

        Registers four parameters on ``layer``: ``B`` (packed int32
        weights), ``s_channel`` (float scales of shape (1, out)),
        ``s_group`` (half scales, empty when group_size == -1) and
        ``workspace`` (zero-initialised int buffer).

        Raises:
            ValueError: if ``params_dtype`` is not float16 or a partition
                size is not divisible as the checks below require.
        """
        cfg = self.quant_config

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % cfg.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")

        has_groups = cfg.group_size != -1
        if has_groups and input_size_per_partition % cfg.group_size != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {self.quant_config.group_size}.")

        # The output size must be a whole multiple of the number of tiles
        # covered by one permutation.
        num_tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized weights packed into Int32.
        packed_rows = input_size_per_partition // cfg.tile_size
        packed_cols = (output_size_per_partition * cfg.tile_size //
                       cfg.pack_factor)
        qweight = Parameter(
            torch.empty(
                packed_rows,
                packed_cols,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": cfg.pack_factor,
                "marlin_tile_size": cfg.tile_size,
            },
        )

        # Scales with one entry per output column.
        s_channel = Parameter(
            torch.empty(
                1,
                output_size_per_partition,
                device="cuda",
                dtype=torch.float,
            ),
            requires_grad=False,
        )
        set_weight_attrs(s_channel, {"input_dim": None, "output_dim": 1})

        # Group scales: an empty tensor when no grouping is configured.
        if has_groups:
            s_group_data = torch.empty(
                input_size_per_partition // cfg.group_size,
                output_size_per_partition,
                device="cuda",
                dtype=torch.half,
            )
            s_group_dims = {"input_dim": 0, "output_dim": 1}
        else:
            s_group_data = torch.tensor(
                [],
                device="cuda",
                dtype=torch.half,
            )
            s_group_dims = {"input_dim": None, "output_dim": None}
        s_group = Parameter(s_group_data, requires_grad=False)
        set_weight_attrs(s_group, s_group_dims)

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (output_size_per_partition //
                              cfg.min_n_threads) * cfg.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        # Register in a fixed order and attach the caller-supplied attrs.
        named_params = (
            ("B", qweight),
            ("s_channel", s_channel),
            ("s_group", s_group),
            ("workspace", workspace),
        )
        for param_name, param in named_params:
            layer.register_parameter(param_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Quantize ``x`` to int8, run the QQQ GEMM and add ``bias``.

        The input is flattened to 2-D for the kernel call and the result
        is viewed back to ``x``'s leading dimensions.
        """
        s_ch = layer.s_channel

        x_2d = x.view(-1, x.shape[-1])
        size_m, size_k = x_2d.shape[0], x_2d.shape[1]
        size_n = s_ch.shape[1]

        x_int8, s_tok = ops.scaled_int8_quant(x_2d)

        output_2d = ops.marlin_qqq_gemm(x_int8, layer.B, s_tok, s_ch,
                                        layer.s_group, layer.workspace,
                                        size_m, size_n, size_k)

        out_shape = x.shape[:-1] + (output_2d.shape[1], )
        output = output_2d.view(out_shape)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
e661d594
...
@@ -5,6 +5,7 @@ import torch
...
@@ -5,6 +5,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
.quant_utils
import
pack_cols
,
unpack_cols
from
.quant_utils
import
pack_cols
,
unpack_cols
...
@@ -13,80 +14,78 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
...
@@ -13,80 +14,78 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_MAX_PARALLEL
=
16
MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT
=
True
def
_check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
Optional
[
int
],
# For binary size and compile time, we don't support the same types for with and
has_zp
:
bool
)
->
Tuple
[
bool
,
Optional
[
str
]]:
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
if
min_capability
is
not
None
:
# TODO: we may want to move this into the C++ so its closer to the actual impl
def
query_marlin_supported_quant_types
(
has_zp
:
bool
,
min_capability
:
Optional
[
int
]
=
None
):
if
min_capability
is
None
:
major
,
minor
=
current_platform
.
get_device_capability
()
major
,
minor
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
min_capability
=
major
*
10
+
minor
if
device_capability
<
min_capability
:
return
(
False
,
"Marlin does not support device_capability = {}"
", the min_capability required is {}"
.
format
(
device_capability
,
min_capability
))
if
num_bits
not
in
MARLIN_SUPPORTED_NUM_BITS
:
return
(
False
,
"Marlin does not support weight_bits = {}. "
"Only weight_bits = {} are supported."
.
format
(
num_bits
,
MARLIN_SUPPORTED_NUM_BITS
))
if
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
:
return
(
False
,
"Marlin does not support group_size = {}. Only "
"group_sizes = {} are supported."
.
format
(
group_size
,
MARLIN_SUPPORTED_GROUP_SIZES
))
if
not
has_zp
and
not
is_sym
:
return
(
False
,
"Marlin without zero_points must have symmetric quantization"
)
return
True
,
None
if
min_capability
<
80
:
return
[]
if
has_zp
:
# AWQ style, unsigned + runtime zero-point
return
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
else
:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
def
check_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
int
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
num_bits
,
group_size
,
is_sym
,
min_capability
,
has_zp
=
False
)
return
cond
def
_check_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
Optional
[
int
],
has_zp
:
bool
,
min_capability
:
Optional
[
int
]
=
None
)
->
Tuple
[
bool
,
Optional
[
str
]]:
def
check_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
if
min_capability
is
None
:
min_capability
:
int
)
->
bool
:
major
,
minor
=
current_platform
.
get_device_capability
()
cond
,
_
=
_check_marlin_supported
(
num_bits
,
min_capability
=
major
*
10
+
minor
group_size
,
False
,
min_capability
,
has_zp
=
has_zp
)
return
cond
supported_types
=
query_marlin_supported_quant_types
(
has_zp
,
min_capability
)
def
verify_gptq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
if
quant_type
not
in
supported_types
:
is_sym
:
bool
)
->
None
:
return
(
False
,
f
"Marlin does not support weight_bits =
{
quant_type
}
. "
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
f
"Only types =
{
supported_types
}
"
group_size
,
f
"are supported (for group_size =
{
group_size
}
, "
is_sym
,
f
"min_capability =
{
min_capability
}
, zp =
{
has_zp
}
)."
)
min_capability
=
None
,
if
(
group_size
is
None
or
group_size
not
in
MARLIN_SUPPORTED_GROUP_SIZES
):
has_zp
=
False
)
return
(
False
,
f
"Marlin does not support group_size =
{
group_size
}
. "
if
not
cond
:
f
"Only group_sizes =
{
MARLIN_SUPPORTED_GROUP_SIZES
}
"
assert
err_msg
is
not
None
"are supported."
)
raise
ValueError
(
"GPTQ"
+
err_msg
)
return
True
,
None
def
check_marlin_supported
(
quant_type
:
ScalarType
,
group_size
:
int
,
has_zp
:
bool
=
False
,
min_capability
:
Optional
[
int
]
=
None
)
->
bool
:
cond
,
_
=
_check_marlin_supported
(
quant_type
,
group_size
,
has_zp
,
min_capability
)
return
cond
def
verify_awq_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
def
verify_marlin_supported
(
quant_type
:
ScalarType
,
has_zp
:
bool
)
->
None
:
group_size
:
int
,
cond
,
err_msg
=
_check_marlin_supported
(
num_bits
,
has_zp
:
bool
=
False
)
->
None
:
group_size
,
cond
,
err_msg
=
_check_marlin_supported
(
quant_type
,
group_size
,
has_zp
)
False
,
min_capability
=
None
,
has_zp
=
has_zp
)
if
not
cond
:
if
not
cond
:
assert
err_msg
is
not
None
assert
err_msg
is
not
None
raise
ValueError
(
"AWQ"
+
err_msg
)
raise
ValueError
(
err_msg
)
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
...
@@ -240,11 +239,12 @@ def apply_gptq_marlin_linear(
...
@@ -240,11 +239,12 @@ def apply_gptq_marlin_linear(
g_idx
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
wtype
:
ScalarType
,
output_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
is_k_full
:
bool
,
is_k_full
:
bool
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
...
@@ -255,12 +255,13 @@ def apply_gptq_marlin_linear(
...
@@ -255,12 +255,13 @@ def apply_gptq_marlin_linear(
g_idx
,
g_idx
,
g_idx_sort_indices
,
g_idx_sort_indices
,
workspace
,
workspace
,
num_bits
,
wtype
,
size_m
=
reshaped_x
.
shape
[
0
],
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
)
has_zp
=
False
,
use_fp32_reduce
=
use_fp32_reduce
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
@@ -276,10 +277,11 @@ def apply_awq_marlin_linear(
...
@@ -276,10 +277,11 @@ def apply_awq_marlin_linear(
g_idx
:
torch
.
Tensor
,
g_idx
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
num_bits
:
int
,
quant_type
:
ScalarType
,
output_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
use_fp32_reduce
:
bool
=
USE_FP32_REDUCE_DEFAULT
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
...
@@ -290,12 +292,13 @@ def apply_awq_marlin_linear(
...
@@ -290,12 +292,13 @@ def apply_awq_marlin_linear(
g_idx
,
g_idx
,
g_idx_sort_indices
,
g_idx_sort_indices
,
workspace
,
workspace
,
num_bits
,
quant_type
,
size_m
=
reshaped_x
.
shape
[
0
],
size_m
=
reshaped_x
.
shape
[
0
],
size_n
=
output_size_per_partition
,
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
size_k
=
input_size_per_partition
,
is_k_full
=
True
,
is_k_full
=
True
,
has_zp
=
True
)
has_zp
=
True
,
use_fp32_reduce
=
use_fp32_reduce
)
if
bias
is
not
None
:
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
output
.
add_
(
bias
)
# In-place add
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
e661d594
...
@@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
...
@@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
return
output
.
reshape
(
out_shape
)
return
output
.
reshape
(
out_shape
)
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
)
->
None
:
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
strategy
:
str
=
"tensor"
)
->
None
:
print_warning_once
(
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"FP8 quantization is being used. Weight-only FP8 compression will "
...
@@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
...
@@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
scales
=
layer
.
weight_scale
.
to
(
layer
.
orig_dtype
)
# expand it to channelwise
is_channelwise
=
(
len
(
layer
.
weight_scale
.
shape
)
>
0
and
layer
.
weight_scale
.
shape
[
0
]
==
part_size_n
)
if
is_channelwise
:
scales
=
layer
.
weight_scale
else
:
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
)
scales
=
scales
.
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_k
=
part_size_k
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
e661d594
...
@@ -5,10 +5,12 @@ from typing import List
...
@@ -5,10 +5,12 @@ from typing import List
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.scalar_type
import
ScalarType
from
.marlin_utils
import
(
GPTQ_MARLIN_TILE
,
marlin_permute_scales
,
from
.marlin_utils
import
(
GPTQ_MARLIN_TILE
,
marlin_permute_scales
,
marlin_zero_points
)
marlin_zero_points
)
from
.quant_utils
import
(
get_pack_factor
,
quantize_weights
,
from
.quant_utils
import
(
get_pack_factor
,
gptq_
quantize_weights
,
quantize_weights
_with_zp
,
sort_weights
)
quantize_weights
,
sort_weights
)
class
MarlinWorkspace
:
class
MarlinWorkspace
:
...
@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
...
@@ -90,9 +92,10 @@ def get_weight_perm(num_bits: int):
return
perm
return
perm
def
marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
def
marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
act_order
:
bool
):
size_k
,
size_n
=
w
.
shape
size_k
,
size_n
=
w
.
shape
num_bits
=
quant_type
.
size_bits
# Normalize group_size
# Normalize group_size
if
group_size
==
-
1
:
if
group_size
==
-
1
:
...
@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -100,8 +103,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
assert
group_size
<=
size_k
assert
group_size
<=
size_k
# Quantize (and apply act_order if provided)
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_
quantize_weights
(
act_order
)
w
,
quant_type
,
group_size
,
act_order
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
# increasing
...
@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -122,7 +125,8 @@ def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
return
res_list
return
res_list
def
awq_marlin_quantize
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
def
awq_marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
):
size_k
,
size_n
=
w
.
shape
size_k
,
size_n
=
w
.
shape
# Normalize group_size
# Normalize group_size
...
@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
...
@@ -135,13 +139,18 @@ def awq_marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int):
num_groups
=
size_k
//
group_size
num_groups
=
size_k
//
group_size
# Quantize with zp
# Quantize with zp
w_ref
,
q_w
,
s
,
zp
=
quantize_weights_with_zp
(
w
,
num_bits
,
group_size
)
w_ref
,
q_w
,
s
,
zp
=
quantize_weights
(
w
,
quant_type
,
group_size
,
zero_points
=
True
)
# Reformat to marlin
# Reformat to marlin
weight_perm
=
get_weight_perm
(
num_bits
)
weight_perm
=
get_weight_perm
(
quant_type
.
size_bits
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
weight_perm
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
quant_type
.
size_bits
,
weight_perm
)
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
)
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
)
marlin_zp
=
marlin_zero_points
(
zp
,
num_groups
,
size_n
,
num_bits
)
marlin_zp
=
marlin_zero_points
(
zp
,
num_groups
,
size_n
,
quant_type
.
size_bits
)
# Create result
# Create result
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
]
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
]
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
View file @
e661d594
...
@@ -6,8 +6,10 @@ from typing import List
...
@@ -6,8 +6,10 @@ from typing import List
import
numpy
import
numpy
import
torch
import
torch
from
vllm.scalar_type
import
ScalarType
from
.marlin_utils_test
import
marlin_weights
from
.marlin_utils_test
import
marlin_weights
from
.quant_utils
import
quantize_weights
from
.quant_utils
import
gptq_
quantize_weights
# This is PyTorch implementation of main part of reorder_meta()
# This is PyTorch implementation of main part of reorder_meta()
...
@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
...
@@ -348,13 +350,11 @@ def check_24(w, num_rows_to_sample=50, _verbose=False):
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
print
(
f
"
{
non_24_segments
}
/
{
total_segments
}
do not have 2:4 structure."
)
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
num_bits
):
def
compress_quantized_24_weight
(
q_24
,
size_k
,
size_n
,
wtype
:
ScalarType
):
assert
q_24
.
shape
==
(
size_k
,
size_n
)
assert
q_24
.
shape
==
(
size_k
,
size_n
)
# Remove zp to normalize over 0
# Remove bias to normalize over 0
max_q_val
=
(
1
<<
num_bits
)
-
1
q_24_no_zp
=
q_24
-
wtype
.
bias
zp
=
(
max_q_val
+
1
)
//
2
q_24_no_zp
=
q_24
-
zp
# Compress
# Compress
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
q_24_no_zp
=
q_24_no_zp
.
t
().
contiguous
()
...
@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
...
@@ -362,8 +362,8 @@ def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
q_24_no_zp
)
q_24_no_zp
)
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
q_24_no_zp_comp
=
q_24_no_zp_comp
.
t
().
contiguous
()
# Restore
zp
# Restore
bias
q_24_comp
=
q_24_no_zp_comp
+
zp
q_24_comp
=
q_24_no_zp_comp
+
wtype
.
bias
# Resize meta to its actual shape (without moving any data)
# Resize meta to its actual shape (without moving any data)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
meta
=
meta
.
resize_
(
meta
.
shape
[
1
]
//
2
,
meta
.
shape
[
0
]
*
2
)
...
@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
...
@@ -427,7 +427,7 @@ def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
def
marlin_24_quantize
(
def
marlin_24_quantize
(
w
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
num_bits
:
int
,
quant_type
:
ScalarType
,
group_size
:
int
,
group_size
:
int
,
):
):
size_k
,
size_n
=
w
.
shape
size_k
,
size_n
=
w
.
shape
...
@@ -441,20 +441,18 @@ def marlin_24_quantize(
...
@@ -441,20 +441,18 @@ def marlin_24_quantize(
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
w_24
,
mask_24
=
inject_24
(
w
,
size_k
,
size_n
)
# Quantize
# Quantize
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w_24
,
w_24_ref
,
q_w_24
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
num_bits
,
w_24
,
quant_type
,
group_size
,
act_order
=
False
)
group_size
,
act_order
=
False
)
# Compress quantized weight
# Compress quantized weight
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
q_w_24_comp
,
meta
=
compress_quantized_24_weight
(
q_w_24
,
size_k
,
size_n
,
num_bits
)
quant_type
)
size_k_comp
=
size_k
//
2
size_k_comp
=
size_k
//
2
# Reformat to marlin
# Reformat to marlin
weight_perm
=
get_weight_perm_24
(
num
_bits
)
weight_perm
=
get_weight_perm_24
(
quant_type
.
size
_bits
)
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
marlin_24_q_w_comp
=
marlin_weights
(
q_w_24_comp
,
size_k_comp
,
size_n
,
num
_bits
,
weight_perm
)
quant_type
.
size
_bits
,
weight_perm
)
marlin_24_s
=
marlin_permute_scales_24
(
s
,
size_k
,
size_n
,
group_size
)
marlin_24_s
=
marlin_permute_scales_24
(
s
,
size_k
,
size_n
,
group_size
)
# Create result
# Create result
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py
0 → 100644
View file @
e661d594
from
typing
import
List
import
numpy
import
torch
from
.marlin_utils_test
import
marlin_permute_weights
from
.quant_utils
import
get_pack_factor
,
qqq_quantize_weights
def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size):
    """Permute quantized weights and pack them into 32-bit words.

    The weights are permuted with ``marlin_permute_weights`` and packed on
    the CPU with numpy. Column ``i`` of every group of ``pack_factor``
    columns is shifted left by ``num_bits * i`` and OR-ed into the output
    word. When ``group_size == size_k`` each value is first masked to its
    low 4 bits. The result is an int32 tensor on the device of the
    permuted weights.
    """
    # Permute
    permuted = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    vals_per_word = get_pack_factor(num_bits)
    target_device = permuted.device

    unpacked = permuted.cpu().numpy().astype(numpy.uint32)
    packed = numpy.zeros(
        (unpacked.shape[0], unpacked.shape[1] // vals_per_word),
        dtype=numpy.uint32)

    mask_low_bits = group_size == size_k
    for slot in range(vals_per_word):
        lane = unpacked[:, slot::vals_per_word]
        if mask_low_bits:
            lane = lane & 0xF
        packed |= lane << num_bits * slot

    return torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
def get_qqq_scale_perms():
    """Return the two index permutations applied to QQQ scales.

    Returns:
        A ``(scale_perm, scale_perm_single)`` tuple of integer lists.
        ``scale_perm`` has 64 entries, ``i + 8 * j`` for ``i`` then ``j``
        in ``range(8)``. ``scale_perm_single`` has 32 entries,
        ``2 * i + j`` for ``i`` in ``range(4)`` and ``j`` over a fixed
        offset pattern.
    """
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]

    single_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single: List[int] = [
        2 * i + offset for i in range(4) for offset in single_offsets
    ]
    return scale_perm, scale_perm_single
# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501
def get_qqq_weight_perm(num_bits: int, quant_type: str):
    """Build the 1024-entry weight permutation for the QQQ layout.

    Args:
        num_bits: weight bit width; only 4 is accepted.
        quant_type: "per-channel" or "per-group". Each selects a different
            interleave of every run of 8 indices.

    Returns:
        A 1-D torch tensor holding the permutation indices.

    Raises:
        AssertionError: if ``quant_type`` is not one of the two names.
        Exception: if ``num_bits`` is not 4.
    """
    indices: List[int] = []
    for i in range(32):
        col = i // 4
        first_row = 4 * (i % 4)
        # Eight base indices: four consecutive rows for block 0, then the
        # same rows for block 1.
        tile = [
            16 * (first_row + r) + col + 8 * block for block in (0, 1)
            for r in range(4)
        ]
        # Repeat the pattern four times at offsets of 256.
        for j in range(4):
            indices.extend(p + 256 * j for p in tile)

    perm = numpy.array(indices)

    assert quant_type in ["per-channel",
                          "per-group"], "not supported quantization type"
    if num_bits != 4:
        raise Exception("num_bits must be 4, got {}".format(num_bits))

    if quant_type == "per-channel":
        interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3])
    else:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])

    # Reorder every run of len(interleave) indices, then flatten again.
    grouped = perm.reshape((-1, len(interleave)))
    perm = grouped[:, interleave].ravel()
    return torch.from_numpy(perm)
def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size):
    """Reorder QQQ scales with the permutations from get_qqq_scale_perms.

    ``s_channel`` is always permuted with the "single" permutation and
    reshaped to ``(-1, size_n)``. ``s_group`` is permuted and reshaped only
    when ``group_size`` is neither -1 nor as large as ``size_k``; otherwise
    it is returned unchanged.

    Returns:
        The ``(s_group, s_channel)`` pair.
    """
    group_perm, channel_perm = get_qqq_scale_perms()

    uses_groups = group_size != -1 and group_size < size_k
    if uses_groups:
        s_group = s_group.reshape((-1, len(group_perm)))[:, group_perm]
        s_group = s_group.reshape((-1, size_n)).contiguous()

    s_channel = s_channel.reshape((-1, len(channel_perm)))[:, channel_perm]
    s_channel = s_channel.reshape((-1, size_n)).contiguous()

    return s_group, s_channel
def marlin_qqq_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Quantize ``w`` with QQQ and convert the result to the Marlin layout.

    Args:
        w: 2-D weight tensor; its shape is read as ``(size_k, size_n)``.
        num_bits: weight bit width passed to the quantizer and packer.
        group_size: quantization group size; -1 is treated as ``size_k``.

    Returns:
        A list ``[w_ref, packed_weights, s_group, s_channel]`` with every
        entry moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    if group_size == size_k:
        quant_type = "per-channel"
    else:
        quant_type = "per-group"

    # Quantize
    w_ref, q_w, s_group, s_channel = qqq_quantize_weights(
        w, num_bits, group_size)

    # Reformat to marlin_qqq
    weight_perm = get_qqq_weight_perm(num_bits, quant_type)
    packed_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, weight_perm,
                                  group_size)
    perm_s_group, perm_s_channel = marlin_qqq_permute_scales(
        s_group, s_channel, size_k, size_n, group_size)

    # Create result
    results = (w_ref, packed_w, perm_s_group, perm_s_channel)
    return [tensor.to(w.device) for tensor in results]
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
e661d594
"""This file is used for /tests and /benchmarks"""
"""This file is used for /tests and /benchmarks"""
from
typing
import
List
import
numpy
import
numpy
import
torch
import
torch
SUPPORTED_NUM_BITS
=
[
4
,
8
]
from
vllm.model_executor.layers.quantization.qqq
import
(
MARLIN_QQQ_SUPPORTED_NUM_BITS
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
SUPPORTED_GPTQ_QUANT_TYPES
=
[
scalar_types
.
uint4b8
,
scalar_types
.
uint8b128
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    """Return whether the layer named ``prefix`` is in ``ignored_layers``.

    For a fused layer (last name component is a key of
    ``FUSED_LAYER_NAME_MAPPING``) the check is made on each of its shard
    names instead, and all shards must agree.

    Args:
        prefix: dotted layer name, e.g. ``model.layers.0.self_attn.q_proj``.
        ignored_layers: names of the layers to skip.

    Returns:
        True if the layer (or every shard of a fused layer) is ignored.

    Raises:
        ValueError: if only some shards of a fused layer are ignored.
    """
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]

    shard_proj_names = FUSED_LAYER_NAME_MAPPING.get(proj_name)
    if shard_proj_names is None:
        return prefix in ignored_layers

    # Swap only the final component. str.replace would also rewrite any
    # earlier occurrence of the fused name in the path and produce shard
    # prefixes that never match ignored_layers.
    parent = prefix[:len(prefix) - len(proj_name)]
    shard_skipped = [
        parent + shard_proj_name in ignored_layers
        for shard_proj_name in shard_proj_names
    ]

    if any(shard_skipped) != all(shard_skipped):
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. All shards of fused layers "
            "to have the same precision.")

    return shard_skipped[0]
def
get_pack_factor
(
num_bits
):
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
32
%
num_bits
==
0
,
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
return
32
//
num_bits
...
@@ -36,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
...
@@ -36,24 +78,23 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
)
)
def
quantize_weights
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
,
def
quantize_weights
(
w
:
torch
.
Tensor
,
act_order
:
bool
):
quant_type
:
ScalarType
,
group_size
:
int
,
zero_points
:
bool
=
False
):
assert
quant_type
.
is_integer
(),
\
"Floating point quantization may work but has not been tested"
orig_device
=
w
.
device
orig_device
=
w
.
device
orig_type
=
w
.
dtype
size_k
,
size_n
=
w
.
shape
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
if
group_size
==
-
1
:
if
group_size
==
-
1
:
group_size
=
size_k
group_size
=
size_k
assert
group_size
<=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
half_q_val
=
(
max_q_val
+
1
)
//
2
# Reshape to [groupsize, -1]
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
if
group_size
<
size_k
:
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
...
@@ -61,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -61,16 +102,34 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w
=
w
.
reshape
((
group_size
,
-
1
))
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
# Compute scale for each group
s
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
max_val
=
torch
.
max
(
w
,
0
,
keepdim
=
True
).
values
s
*=
2
/
max_q_val
# 2 => symmetric
min_val
=
torch
.
min
(
w
,
0
,
keepdim
=
True
).
values
max_q_val
=
quant_type
.
max
()
min_q_val
=
quant_type
.
min
()
if
zero_points
:
assert
not
quant_type
.
is_signed
()
and
quant_type
.
max
()
>
0
w_s
=
(
max_val
-
min_val
).
clamp
(
min
=
1e-5
)
/
quant_type
.
max
()
maybe_w_zp
=
torch
.
round
(
torch
.
abs
(
min_val
/
w_s
))
\
.
clamp
(
min_q_val
,
max_q_val
).
int
()
else
:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s
=
torch
.
max
(
abs
(
max_val
/
(
max_q_val
if
max_q_val
!=
0
else
torch
.
inf
)),
abs
(
min_val
/
(
min_q_val
if
min_q_val
!=
0
else
torch
.
inf
)))
maybe_w_zp
=
None
# Quantize
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
w_q
=
torch
.
round
(
w
/
w_s
).
int
()
+
(
maybe_w_zp
if
zero_points
else
0
)
q_w
+=
half_q_val
w_q
=
torch
.
clamp
(
w_q
,
min_q_val
,
max_q_val
)
q_w
=
torch
.
clamp
(
q_w
,
0
,
max_q_val
)
# Compute ref (dequantized)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
half_q_val
).
half
()
*
s
w_ref
=
(
w_q
-
(
maybe_w_zp
if
zero_points
else
0
)).
to
(
orig_type
)
*
w_s
if
quant_type
.
has_bias
():
w_q
+=
quant_type
.
bias
# Restore original shapes
# Restore original shapes
if
group_size
<
size_k
:
if
group_size
<
size_k
:
...
@@ -81,10 +140,35 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -81,10 +140,35 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
w
=
w
.
reshape
((
size_k
,
size_n
)).
contiguous
()
return
w
return
w
q_w
=
reshape_w
(
q_w
)
w_q
=
reshape_w
(
w_q
)
w_ref
=
reshape_w
(
w_ref
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
w_s
=
w_s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
if
zero_points
:
maybe_w_zp
=
maybe_w_zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
maybe_w_zp
=
maybe_w_zp
.
to
(
device
=
orig_device
)
return
(
w_ref
.
to
(
device
=
orig_device
),
w_q
.
to
(
device
=
orig_device
),
w_s
.
to
(
device
=
orig_device
),
maybe_w_zp
,
)
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
size_k
,
_
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
quant_type
in
SUPPORTED_GPTQ_QUANT_TYPES
,
\
f
"Unsupported gptq type =
{
quant_type
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
w_ref
,
w_q
,
w_s
,
_
=
quantize_weights
(
w
,
quant_type
,
group_size
)
# Apply act_order
# Apply act_order
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
...
@@ -95,23 +179,20 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
...
@@ -95,23 +179,20 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
group_size
,
size_k
)
w_ref
,
q_w
,
g_idx
,
rand_perm
=
permute_rows
(
q_w
,
w_ref
,
group_size
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
)
return
(
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
g_idx
.
to
(
device
=
orig_device
),
rand_perm
.
to
(
device
=
orig_device
),
)
def
quantize_weights_with_zp
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def
qqq_quantize_weights
(
w
:
torch
.
Tensor
,
num_bits
:
int
,
group_size
:
int
):
orig_device
=
w
.
device
orig_device
=
w
.
device
size_k
,
size_n
=
w
.
shape
size_k
,
size_n
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
w
.
is_floating_point
(),
"w must be float"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
num_bits
in
MARLIN_QQQ_SUPPORTED_NUM_BITS
,
\
f
"Unsupported num_bits =
{
num_bits
}
"
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
assert
group_size
in
SUPPORTED_GROUP_SIZES
+
[
size_k
size_k
],
f
"Unsupported groupsize =
{
group_size
}
"
],
f
"Unsupported groupsize =
{
group_size
}
"
...
@@ -120,33 +201,27 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
...
@@ -120,33 +201,27 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
group_size
=
size_k
group_size
=
size_k
assert
group_size
<=
size_k
assert
group_size
<=
size_k
max_q_val
=
2
**
num_bits
-
1
min_q_val
=
0
# Reshape to [groupsize, -1]
if
group_size
<
size_k
:
if
group_size
<
size_k
:
# Reshape to [groupsize, -1]
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
reshape
((
-
1
,
group_size
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
reshape
((
group_size
,
-
1
))
w
=
w
.
reshape
((
group_size
,
-
1
))
# Compute scale for each group
max_q_val
=
2
**
num_bits
-
1
max
=
torch
.
max
(
w
,
0
,
keepdim
=
True
)[
0
]
half_q_val
=
(
max_q_val
+
1
)
//
2
min
=
torch
.
min
(
w
,
0
,
keepdim
=
True
)[
0
]
s
=
(
max
-
min
).
clamp
(
min
=
1e-5
)
/
max_q_val
# Compute zero-point for each group
zp
=
(
-
torch
.
round
(
min
/
s
)).
clamp
(
min_q_val
,
max_q_val
).
int
()
# Quantize
q_w
=
torch
.
round
(
w
/
s
).
int
()
+
zp
q_w
=
torch
.
clamp
(
q_w
,
min_q_val
,
max_q_val
)
# Compute ref (dequantized)
# Compute scale for each group
w_ref
=
(
q_w
-
zp
).
half
()
*
s
s_group
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s_group
*=
2
/
max_q_val
# 2 => symmetric
# Restore original shapes
# Quantize
if
group_size
<
size_k
:
q_w
=
torch
.
round
(
w
/
s_group
).
int
()
q_w
+=
half_q_val
q_w
=
torch
.
clamp
(
q_w
,
0
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
(
q_w
-
half_q_val
).
half
()
*
s_group
# Restore original shapes
def
reshape_w
(
w
):
def
reshape_w
(
w
):
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
reshape
((
group_size
,
-
1
,
size_n
))
w
=
w
.
permute
(
1
,
0
,
2
)
w
=
w
.
permute
(
1
,
0
,
2
)
...
@@ -156,14 +231,39 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
...
@@ -156,14 +231,39 @@ def quantize_weights_with_zp(w: torch.Tensor, num_bits: int, group_size: int):
q_w
=
reshape_w
(
q_w
)
q_w
=
reshape_w
(
q_w
)
w_ref
=
reshape_w
(
w_ref
)
w_ref
=
reshape_w
(
w_ref
)
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
# Compute int8 quantization scale for each channel
zp
=
zp
.
reshape
((
-
1
,
size_n
)).
contiguous
()
s_channel
=
torch
.
max
(
torch
.
abs
(
w_ref
),
0
,
keepdim
=
True
)[
0
]
s_channel
/=
127.0
t_int8
=
(
w_ref
/
s_channel
).
round
().
clamp
(
-
128
,
127
).
to
(
torch
.
int8
)
w_ref
=
t_int8
.
half
()
*
s_channel
s_channel
=
s_channel
.
reshape
(
1
,
-
1
).
to
(
dtype
=
torch
.
float
)
# Fuse scales
s_group
=
(
s_group
.
reshape
(
-
1
,
size_n
).
contiguous
()
/
s_channel
).
to
(
dtype
=
torch
.
half
)
else
:
max_q_val
=
2
**
(
num_bits
-
1
)
-
1
# Compute scale for each channel
s_channel
=
torch
.
max
(
torch
.
abs
(
w
),
0
,
keepdim
=
True
)[
0
]
s_channel
/=
max_q_val
# Quantize
q_w
=
torch
.
round
(
w
/
s_channel
).
int
()
q_w
=
torch
.
clamp
(
q_w
,
-
max_q_val
,
max_q_val
)
# Compute ref (dequantized)
w_ref
=
q_w
.
half
()
*
s_channel
s_group
=
torch
.
tensor
([],
dtype
=
torch
.
half
)
# div 2 ** (8 - self.bits)) to offset right shift in unpacking
s_channel
/=
(
2
**
(
8
-
num_bits
))
s_channel
=
s_channel
.
reshape
(
-
1
,
size_n
).
contiguous
().
to
(
torch
.
float
)
return
(
return
(
w_ref
.
to
(
device
=
orig_device
),
w_ref
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
q_w
.
to
(
device
=
orig_device
),
s
.
to
(
device
=
orig_device
),
s
_group
.
to
(
device
=
orig_device
),
zp
.
to
(
device
=
orig_device
),
s_channel
.
to
(
device
=
orig_device
),
)
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
e661d594
...
@@ -139,7 +139,7 @@ def apply_fp8_linear(
...
@@ -139,7 +139,7 @@ def apply_fp8_linear(
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
input
,
input
,
input_scale
,
input_scale
,
batch_dim
_padding
=
17
,
num_token
_padding
=
17
,
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
use_per_token_if_dynamic
=
use_per_token_if_dynamic
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
per_tensor_weights
=
(
weight_scale
.
numel
()
==
1
)
...
@@ -177,8 +177,9 @@ def apply_fp8_linear(
...
@@ -177,8 +177,9 @@ def apply_fp8_linear(
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
weight
,
out_dtype
=
torch
.
float32
)
out_dtype
=
torch
.
float32
)
# Unpad (undo
batch_dim
_padding)
# Unpad (undo
num_token
_padding)
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
output
=
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
x_scale
=
torch
.
narrow
(
x_scale
,
0
,
0
,
input
.
shape
[
0
])
# DQ
# DQ
# C = sw * sx * (X * W) + bias
# C = sw * sx * (X * W) + bias
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
e661d594
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.jit
import
torch.jit
...
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Sample token ids using rejection sampling. This accepts or rejects
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
tokens proposed by the draft model using the probability of each token
...
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
probabilities.
probabilities.
shape = [batch_size, num_speculative_tokens]
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
Returns:
output_token_ids: The token ids sampled via rejection sampling,
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
or -1 if unable to sample a token because the previous token
...
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
,
target_probs
,
draft_probs
,
draft_probs
,
draft_token_ids
,
draft_token_ids
,
generator
s
,
seeded_seq
s
,
))
))
output_token_ids
=
self
.
_create_output
(
output_token_ids
=
self
.
_create_output
(
...
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Perform modified rejection sampling on each sequence.
"""Perform modified rejection sampling on each sequence.
...
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# shape [batch_size, k]
# shape [batch_size, k]
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
accepted
=
self
.
_get_accepted
(
target_probs
,
draft_probs
,
draft_token_ids
,
generator
s
)
draft_token_ids
,
seeded_seq
s
)
recovered_probs
=
self
.
_get_recovered_probs
(
recovered_probs
=
self
.
_get_recovered_probs
(
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
target_probs
,
draft_probs
).
reshape
(
batch_size
*
k
,
vocab_size
)
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
generators
,
k
=
k
)
# NOTE: the recovered_probs are overwritten by this method.
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids
=
_multinomial
(
recovered_token_ids
=
_multinomial
(
recovered_probs
,
recovered_probs
,
num_samples
=
1
,
num_samples
=
1
,
k
=
k
,
k
=
k
,
generators
=
generators
,
seeded_seqs
=
seeded_seqs
or
{},
seed_indices
=
seed_indices
,
# this arg is unused when None but torch.jit requires a list
non_seed_indices
=
non_seed_indices
or
[],
).
reshape
(
batch_size
,
k
)
).
reshape
(
batch_size
,
k
)
return
accepted
,
recovered_token_ids
return
accepted
,
recovered_token_ids
...
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
draft_token_ids
:
torch
.
Tensor
,
# [batch_size, k]
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
"""Create bool matrix over the proposed draft tokens. If
r
"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
True, then a token can be accepted, else it should be
...
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
draft_token_ids
]
seed_indices
,
non_seed_indices
=
self
.
_split_batch_by_seeded
(
if
not
seeded_seqs
:
generators
)
if
len
(
seed_indices
)
==
0
:
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
else
:
else
:
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
for
idx
in
seed_indices
:
non_seeded_indices
=
[]
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
for
idx
in
range
(
batch_size
):
k
,
generator
=
seeded_seqs
.
get
(
idx
)
dtype
=
self
.
probs_dtype
,
if
generator
is
None
:
device
=
target_probs
.
device
,
non_seeded_indices
.
append
(
idx
)
generator
=
generators
[
idx
])
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
if
non_seed_indices
:
1
,
uniform_rand
[
non_seed_indices
,
:]
=
torch
.
rand
(
k
,
len
(
non_seed_indices
),
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
,
k
,
dtype
=
self
.
probs_dtype
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
device
=
target_probs
.
device
)
...
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
...
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@
staticmethod
def
_split_batch_by_seeded
(
generators
:
List
[
Optional
[
torch
.
Generator
]],
k
:
int
=
1
,
)
->
Tuple
[
List
[
int
],
Optional
[
List
[
int
]]]:
if
all
(
generator
is
None
for
generator
in
generators
):
seed_indices
:
List
[
int
]
=
[]
non_seed_indices
:
Optional
[
List
[
int
]]
=
None
else
:
seed_indices
,
non_seed_indices
=
[],
[]
for
i
,
generator
in
enumerate
(
generators
):
if
generator
is
None
:
non_seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
else
:
seed_indices
.
extend
(
range
(
k
*
i
,
k
*
(
i
+
1
)))
return
seed_indices
,
non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Therefore, we use an optimized implementation instead that skips the sync.
...
@@ -304,9 +282,7 @@ def _multinomial(
...
@@ -304,9 +282,7 @@ def _multinomial(
probs
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
num_samples
:
int
,
num_samples
:
int
,
k
:
int
,
k
:
int
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Dict
[
int
,
torch
.
Generator
],
seed_indices
:
List
[
int
],
non_seed_indices
:
List
[
int
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
num_samples
>
1
:
if
num_samples
>
1
:
...
@@ -315,13 +291,20 @@ def _multinomial(
...
@@ -315,13 +291,20 @@ def _multinomial(
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
=
probs
[:,
None
,
:].
expand
(
probs
.
shape
[
0
],
num_samples
,
probs
.
shape
[
1
]).
contiguous
().
view
(
probs
.
shape
[
1
]).
contiguous
().
view
(
-
1
,
probs
.
shape
[
1
])
-
1
,
probs
.
shape
[
1
])
q
=
torch
.
empty_like
(
probs
)
q
=
torch
.
empty_like
(
probs
)
if
len
(
seed_indices
)
==
0
:
if
not
seeded_seqs
:
q
.
exponential_
(
1.0
)
q
.
exponential_
(
1.0
)
else
:
else
:
q
[
non_seed_indices
].
exponential_
(
1.0
)
non_seeded_indices
:
List
[
int
]
=
[]
for
idx
in
seed_indices
:
start
=
0
q
[
idx
].
exponential_
(
1.0
,
generator
=
generators
[
idx
//
k
])
for
idx
in
range
(
len
(
q
)
//
k
):
end
=
start
+
k
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
extend
(
list
(
range
(
start
,
end
)))
else
:
q
[
start
:
end
].
exponential_
(
1.0
,
generator
=
generator
)
start
=
end
q
[
non_seeded_indices
].
exponential_
(
1.0
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
return
probs
.
div_
(
q
).
argmax
(
dim
=
1
).
view
(
-
1
,
num_samples
)
vllm/model_executor/layers/rotary_embedding.py
View file @
e661d594
...
@@ -774,6 +774,7 @@ def get_rope(
...
@@ -774,6 +774,7 @@ def get_rope(
is_neox_style
:
bool
=
True
,
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
rotary_percent
:
float
=
1.0
,
)
->
RotaryEmbedding
:
)
->
RotaryEmbedding
:
if
dtype
is
None
:
if
dtype
is
None
:
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
...
@@ -786,6 +787,8 @@ def get_rope(
...
@@ -786,6 +787,8 @@ def get_rope(
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
rope_scaling_args
=
tuple
(
rope_scaling_tuple
.
items
())
else
:
else
:
rope_scaling_args
=
None
rope_scaling_args
=
None
if
rotary_percent
<
1.0
:
rotary_dim
=
int
(
rotary_dim
*
rotary_percent
)
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
key
=
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
rope_scaling_args
,
dtype
)
rope_scaling_args
,
dtype
)
if
key
in
_ROPE_DICT
:
if
key
in
_ROPE_DICT
:
...
...
vllm/model_executor/layers/sampler.py
View file @
e661d594
"""A layer that samples the next tokens from the model's outputs."""
"""A layer that samples the next tokens from the model's outputs."""
import
itertools
import
itertools
from
math
import
inf
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.model_executor.layers.ops.sample
import
sample
as
sample_triton
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
,
SamplingTensors
,
SequenceGroupToSample
)
SequenceGroupToSample
)
...
@@ -220,7 +225,7 @@ def _apply_min_tokens_penalty(
...
@@ -220,7 +225,7 @@ def _apply_min_tokens_penalty(
seqs_to_penalize
:
List
[
int
]
=
[]
seqs_to_penalize
:
List
[
int
]
=
[]
for
j
,
seq_id
in
enumerate
(
seq_ids
):
for
j
,
seq_id
in
enumerate
(
seq_ids
):
seq_data
=
seq_group
.
seq_data
[
seq_id
]
seq_data
=
seq_group
.
seq_data
[
seq_id
]
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
if
len
(
seq_data
.
output_token_ids
_array
)
<
min_tokens
:
seqs_to_penalize
.
append
(
j
)
seqs_to_penalize
.
append
(
j
)
if
seqs_to_penalize
:
if
seqs_to_penalize
:
...
@@ -774,8 +779,11 @@ def _get_logprobs(
...
@@ -774,8 +779,11 @@ def _get_logprobs(
# The next token ids to get the logprob value from.
# The next token ids to get the logprob value from.
next_token_ids
:
List
[
int
]
=
[]
next_token_ids
:
List
[
int
]
=
[]
# The largest requested number of logprobs. We find logprobs as many as the
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
# largest num logprobs in this API. If every logprobs is None, it will be
largest_num_logprobs
=
1
# set to -1.
largest_num_logprobs
=
-
1
# If beam search is enabled.
use_beam_search
=
False
# Select indices to compute logprob from, ranks of token ids, and the top
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
# k token ids from logprobs.
...
@@ -808,6 +816,8 @@ def _get_logprobs(
...
@@ -808,6 +816,8 @@ def _get_logprobs(
largest_num_logprobs
=
max
(
largest_num_logprobs
,
largest_num_logprobs
=
max
(
largest_num_logprobs
,
sampling_params
.
logprobs
)
sampling_params
.
logprobs
)
use_beam_search
=
use_beam_search
or
sampling_params
.
use_beam_search
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
assert
len
(
next_token_ids
)
==
len
(
query_indices
)
if
len
(
query_indices
)
==
0
:
if
len
(
query_indices
)
==
0
:
...
@@ -815,35 +825,40 @@ def _get_logprobs(
...
@@ -815,35 +825,40 @@ def _get_logprobs(
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
empty_prompt_logprob
:
Optional
[
PromptLogprobs
]
=
None
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
return
[
empty_prompt_logprob
],
[
empty_sampled_logprob
]
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
selected_logprobs
,
ranks
=
None
,
None
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
device
=
logprobs
.
device
)
top_logprobs
,
top_token_ids
=
None
,
None
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# contain duplicates if beam search is enabled.
# skip the whole logprob calculation.
selected_logprobs
=
logprobs
[[
if
largest_num_logprobs
>=
0
or
use_beam_search
:
query_indices_gpu
,
query_indices_gpu
=
torch
.
tensor
(
query_indices
,
device
=
logprobs
.
device
)
next_token_ids_gpu
,
next_token_ids_gpu
=
torch
.
tensor
(
next_token_ids
,
]]
device
=
logprobs
.
device
)
ranks
=
_get_ranks
(
logprobs
[
query_indices_gpu
],
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
next_token_ids_gpu
,
# contain duplicates if beam search is enabled.
)
selected_logprobs
=
logprobs
[[
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
query_indices_gpu
,
next_token_ids_gpu
,
# Logprobs of topk tokens for a batch of sequence groups.
]]
# (num_query_tokens_across_batch).
ranks
=
_get_ranks
(
if
largest_num_logprobs
>
0
:
logprobs
[
query_indices_gpu
],
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
next_token_ids_gpu
,
largest_num_logprobs
,
)
dim
=-
1
)
assert
selected_logprobs
.
shape
[
0
]
==
ranks
.
shape
[
0
]
else
:
top_logprobs
,
top_token_ids
=
None
,
None
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
# We need to compute top k only if there exists logprobs > 0.
ranks
=
ranks
.
to
(
'cpu'
)
if
largest_num_logprobs
>
0
:
if
top_logprobs
is
not
None
and
top_token_ids
is
not
None
:
# Logprobs of topk tokens for a batch of sequence groups.
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
# (num_query_tokens_across_batch).
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
top_logprobs
,
top_token_ids
=
torch
.
topk
(
logprobs
,
largest_num_logprobs
,
dim
=-
1
)
top_logprobs
=
top_logprobs
.
to
(
'cpu'
)
top_token_ids
=
top_token_ids
.
to
(
'cpu'
)
selected_logprobs
=
selected_logprobs
.
to
(
'cpu'
)
ranks
=
ranks
.
to
(
'cpu'
)
# Find prompt/sample logprobs.
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
prompt_logprobs_per_seq_group
:
List
[
Optional
[
PromptLogprobs
]]
=
[]
...
@@ -940,46 +955,53 @@ def _get_sampled_logprob_if_needed(
...
@@ -940,46 +955,53 @@ def _get_sampled_logprob_if_needed(
):
):
"""Compute the sample logprob if needed."""
"""Compute the sample logprob if needed."""
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
or
0
num_logprobs
=
seq_group
.
sampling_params
.
logprobs
use_beam_search
=
seq_group
.
sampling_params
.
use_beam_search
sampled_logprobs
:
SampleLogprobs
=
[]
sampled_logprobs
:
SampleLogprobs
=
[]
next_token_ids
,
parent_seq_ids
=
sample_result
next_token_ids
,
parent_seq_ids
=
sample_result
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
assert
len
(
next_token_ids
)
>
0
assert
len
(
next_token_ids
)
>
0
# Pre-select items from tensor. tolist() is faster than repetitive
if
num_logprobs
is
None
and
not
use_beam_search
:
# `.item()` calls.
for
next_token_id
in
next_token_ids
:
selected_logprob_items
=
selected_logprobs
[
# Use a dummy logprob
selected_logprobs_idx
:
selected_logprobs_idx
+
sampled_logprobs
.
append
({
next_token_id
:
Logprob
(
inf
)})
len
(
next_token_ids
)].
tolist
()
else
:
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
# Pre-select items from tensor. tolist() is faster than repetitive
len
(
next_token_ids
)].
tolist
()
# `.item()` calls.
for
idx
,
(
next_token_id
,
selected_logprob_items
=
selected_logprobs
[
parent_id
)
in
enumerate
(
zip
(
next_token_ids
,
parent_seq_ids
)):
selected_logprobs_idx
:
selected_logprobs_idx
+
# Get the logprob of a sampled token.
len
(
next_token_ids
)].
tolist
()
sampled_logprobs_dict
=
{
rank_items
=
ranks
[
selected_logprobs_idx
:
selected_logprobs_idx
+
next_token_id
:
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
len
(
next_token_ids
)].
tolist
()
}
for
idx
,
(
next_token_id
,
parent_id
)
in
enumerate
(
# Get top K logprobs.
zip
(
next_token_ids
,
parent_seq_ids
)):
if
num_logprobs
>
0
:
# Get the logprob of a sampled token.
top_ids
=
top_token_ids
[
top_logprob_idx
+
sampled_logprobs_dict
=
{
parent_id
,
:
num_logprobs
].
tolist
()
next_token_id
:
top_probs
=
top_logprobs
[
top_logprob_idx
+
(
selected_logprob_items
[
idx
],
rank_items
[
idx
])
parent_id
,
:
num_logprobs
].
tolist
()
}
# Top K is already sorted by rank, so we can use 1 ~
if
num_logprobs
is
not
None
and
num_logprobs
>
0
:
# num_logprobs + 1 for rank.
# Get top K logprobs.
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
top_ids
=
top_token_ids
[
top_logprob_idx
+
sampled_logprobs_dict
.
update
({
parent_id
,
:
num_logprobs
].
tolist
()
top_id
:
(
top_prob
,
rank
)
top_probs
=
top_logprobs
[
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_logprob_idx
+
parent_id
,
:
num_logprobs
].
tolist
()
top_ranks
)
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks
=
range
(
1
,
num_logprobs
+
1
)
sampled_logprobs_dict
.
update
({
top_id
:
(
top_prob
,
rank
)
for
top_id
,
top_prob
,
rank
in
zip
(
top_ids
,
top_probs
,
top_ranks
)
})
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
})
sampled_logprobs
.
append
({
token_id
:
Logprob
(
*
logprob_and_rank
)
for
token_id
,
logprob_and_rank
in
sampled_logprobs_dict
.
items
()
})
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
# per sequence group. `logprobs` includes logprobs from the previous
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
e661d594
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Lis
t
,
Optional
from
typing
import
Dic
t
,
Optional
import
torch
import
torch
import
torch.jit
import
torch.jit
...
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
...
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
generators
:
List
[
Optional
[
torch
.
Generator
]],
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
vllm/model_executor/model_loader/loader.py
View file @
e661d594
...
@@ -7,6 +7,7 @@ import json
...
@@ -7,6 +7,7 @@ import json
import
math
import
math
import
os
import
os
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
huggingface_hub
import
huggingface_hub
...
@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
...
@@ -37,7 +38,49 @@ from vllm.model_executor.models.interfaces import (has_inner_state,
supports_vision
)
supports_vision
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_tpu
from
vllm.utils
import
is_pin_memory_available
,
is_tpu
@
contextmanager
def
device_loading_context
(
module
:
torch
.
nn
.
Module
,
target_device
:
torch
.
device
):
if
target_device
.
type
==
"cpu"
:
# If target is CPU, no need to move anything
yield
module
return
original_device_states
:
Dict
[
str
,
torch
.
device
]
=
{}
# Store original device states and move parameters to GPU if they're on CPU
for
name
,
p
in
module
.
named_parameters
():
if
p
.
device
.
type
==
"cpu"
:
original_device_states
[
name
]
=
p
.
device
p
.
data
=
p
.
data
.
to
(
target_device
)
# Parameters already on target device are not touched
try
:
yield
module
finally
:
# Restore parameters to their original devices, ignoring new parameters
pin_memory
=
is_pin_memory_available
()
for
name
,
p
in
module
.
named_parameters
():
if
name
in
original_device_states
:
original_device
:
torch
.
device
=
original_device_states
[
name
]
if
original_device
.
type
==
"cpu"
:
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
stride
=
p
.
data
.
stride
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
else
:
p
.
data
=
p
.
data
.
to
(
original_device
)
# New parameters or parameters already on target device are untouched
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -164,7 +207,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -164,7 +207,7 @@ class DefaultModelLoader(BaseModelLoader):
cache_dir
=
self
.
load_config
.
download_dir
,
cache_dir
=
self
.
load_config
.
download_dir
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
revision
=
revision
,
revision
=
revision
,
ignore_pattern
s
=
self
.
load_config
.
ignore_patterns
,
ignore_
file_
pattern
=
self
.
load_config
.
ignore_patterns
,
)
)
else
:
else
:
model_path
=
model
model_path
=
model
...
@@ -278,8 +321,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -278,8 +321,9 @@ class DefaultModelLoader(BaseModelLoader):
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
cache_config
:
CacheConfig
)
->
nn
.
Module
:
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
t
orch
.
device
(
device_config
.
device
)
:
with
t
arget_
device
:
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
model
=
_initialize_model
(
model_config
,
self
.
load_config
,
lora_config
,
multimodal_config
,
lora_config
,
multimodal_config
,
cache_config
,
scheduler_config
)
cache_config
,
scheduler_config
)
...
@@ -294,7 +338,13 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -294,7 +338,13 @@ class DefaultModelLoader(BaseModelLoader):
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
and
quant_method
!=
"awq"
and
quant_method
!=
"gptq"
:
if
quant_method
is
not
None
and
quant_method
!=
"awq"
and
quant_method
!=
"gptq"
:
quant_method
.
process_weights_after_loading
(
module
)
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with
device_loading_context
(
module
,
target_device
):
quant_method
.
process_weights_after_loading
(
module
)
return
model
.
eval
()
return
model
.
eval
()
...
@@ -705,8 +755,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -705,8 +755,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
hf_weights_files
,
matched_pattern
==
"*.safetensors"
return
hf_weights_files
,
matched_pattern
==
"*.safetensors"
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
if
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
else
:
return
pt_weights_iterator
(
hf_weights_files
)
def
_get_quantized_weights_iterator
(
def
_get_quantized_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
,
pre_quant
:
bool
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
Any
]]:
Any
]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
"""Get an iterator to the model weights with bitsandbytes quantization,
...
@@ -715,6 +771,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -715,6 +771,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
# only load the bitsandbytes module when needed
try
:
try
:
import
bitsandbytes
import
bitsandbytes
from
bitsandbytes.functional
import
QuantState
if
bitsandbytes
.
__version__
<
"0.42.0"
:
if
bitsandbytes
.
__version__
<
"0.42.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0."
)
"install bitsandbytes>=0.42.0."
)
...
@@ -728,17 +785,63 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -728,17 +785,63 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path
,
revision
)
model_name_or_path
,
revision
)
quant_state_dict
=
{}
quant_state_dict
=
{}
if
use_safetensors
:
weight_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weight_iterator
=
pt_weights_iterator
(
hf_weights_files
)
def
generator
():
def
quantized_checkpoint
()
->
Generator
:
# First iterate over all quant state weights
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
continue
# TODO: only nf4 quantization is supported for now
if
weight_name
.
endswith
(
".quant_state.bitsandbytes__fp4"
):
raise
NotImplementedError
(
"Only bitsandbytes_nf4 quantization"
f
"is supported for now.
{
weight_name
}
is fp4 quantized"
)
temp_state_dict
[
weight_name
]
=
weight_tensor
# Closure to parse quant_state for each prequant weight
def
_parse_quant_state
(
param_name
:
str
,
temp_state_dict
:
Dict
)
->
QuantState
:
quant_state
=
{}
for
k
in
temp_state_dict
:
if
param_name
+
"."
in
k
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
]
=
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
].
cpu
().
data
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
continue
if
weight_name
+
".quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
:
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
),
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
generator
()
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
#
bitsandbytes requires data in GPU
# bitsandbytes requires data in GPU
loaded_weight
=
weight_tensor
.
cuda
().
data
loaded_weight
=
weight_tensor
.
cuda
().
data
with
set_default_torch_dtype
(
torch
.
float32
):
with
set_default_torch_dtype
(
torch
.
float32
):
processed_weight
,
quant_state
=
quantize_4bit
(
processed_weight
,
quant_state
=
quantize_4bit
(
...
@@ -752,6 +855,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -752,6 +855,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield
weight_name
,
processed_weight
yield
weight_name
,
processed_weight
if
pre_quant
:
return
quantized_checkpoint
(),
quant_state_dict
return
generator
(),
quant_state_dict
return
generator
(),
quant_state_dict
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
...
@@ -769,12 +874,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -769,12 +874,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
" May take a while ..."
)
" May take a while ..."
)
qweight_iterator
,
quant_state_dict
=
(
is_quantized_checkpoint
=
False
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
model_config
.
revision
))
None
)
if
quant_config
is
not
None
and
quant_config
.
get
(
'quant_method'
)
==
"bitsandbytes"
:
is_quantized_checkpoint
=
True
qweight_iterator
,
quant_state_dict
=
\
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
is_quantized_checkpoint
)
model
.
load_weights
(
qweight_iterator
)
model
.
load_weights
(
qweight_iterator
)
torch
.
cuda
.
empty_cache
()
param_dict
=
dict
(
model
.
named_parameters
())
param_dict
=
dict
(
model
.
named_parameters
())
stacked_quant_state_dict
:
Dict
[
str
,
Dict
[
int
,
Any
]]
=
{}
stacked_quant_state_dict
:
Dict
[
str
,
Dict
[
int
,
Any
]]
=
{}
for
quant_param_name
in
quant_state_dict
:
for
quant_param_name
in
quant_state_dict
:
...
@@ -812,9 +926,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -812,9 +926,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f
"pack_factor not set for parameter
{
param_name
}
."
)
f
"pack_factor not set for parameter
{
param_name
}
."
)
num_elements
=
[
0
]
*
len
(
quant_states
)
num_elements
=
[
0
]
*
len
(
quant_states
)
for
seq
,
quant_state
in
enumerate
(
quant_states
.
items
()
)
:
for
seq
,
quant_state
in
quant_states
.
items
():
num_elements
[
seq
]
=
math
.
prod
(
num_elements
[
seq
]
=
math
.
prod
(
quant_state
[
1
]
.
shape
)
//
pack_ratio
quant_state
.
shape
)
//
pack_ratio
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
e661d594
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
get_quantization_config
)
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -118,6 +119,7 @@ def convert_bin_to_safetensor_file(
...
@@ -118,6 +119,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
# TODO(woosuk): Move this to other place.
def
get_quant_config
(
model_config
:
ModelConfig
,
def
get_quant_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
...
@@ -489,6 +491,11 @@ def initialize_dummy_weights(
...
@@ -489,6 +491,11 @@ def initialize_dummy_weights(
"""
"""
for
param
in
model
.
state_dict
().
values
():
for
param
in
model
.
state_dict
().
values
():
if
torch
.
is_floating_point
(
param
):
if
torch
.
is_floating_point
(
param
):
if
current_platform
.
is_tpu
():
# XLA device does not support torch.Generator()
param
.
uniform_
(
low
,
high
)
continue
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
generator
=
torch
.
Generator
(
device
=
param
.
data
.
device
)
generator
.
manual_seed
(
seed
)
generator
.
manual_seed
(
seed
)
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
if
torch
.
finfo
(
param
.
data
.
dtype
).
bits
<
16
:
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment