Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
705f6a35
Commit
705f6a35
authored
Jul 16, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1
parents
af837396
4cf256ae
Changes
439
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
1806 additions
and
920 deletions
+1806
-920
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+165
-0
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+9
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+240
-151
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+10
-3
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+96
-236
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+10
-3
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+2
-1
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
...del_executor/layers/quantization/utils/marlin_24_perms.py
+0
-58
vllm/model_executor/layers/quantization/utils/marlin_perms.py
.../model_executor/layers/quantization/utils/marlin_perms.py
+0
-58
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+167
-210
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+109
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+120
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py
...xecutor/layers/quantization/utils/marlin_utils_test_24.py
+160
-3
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+162
-0
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+17
-171
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+120
-13
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+16
-13
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+219
-0
vllm/model_executor/layers/typical_acceptance_sampler.py
vllm/model_executor/layers/typical_acceptance_sampler.py
+184
-0
No files found.
Too many changes to show.
To preserve performance only
439 of 439+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
0 → 100644
View file @
705f6a35
from
typing
import
Callable
,
List
,
Optional
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.utils
import
set_weight_attrs
# Names exported by `from <this module> import *`.
__all__
=
[
"CompressedTensorsWNA16"
]
# NOTE(review): presumably the weight bit-widths this scheme accepts; the
# constant is not referenced by the class below, which delegates validation
# to verify_marlin_supported() -- confirm whether it is still needed.
WNA16_SUPPORTED_BITS
=
[
4
,
8
]
class CompressedTensorsWNA16(CompressedTensorsScheme):
    """Weight-only N-bit scheme that runs compressed-tensors checkpoints
    through the Marlin kernels.

    Weights are loaded in the packed compressed-tensors layout and repacked
    into the Marlin layout once loading has finished.
    """

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        """Record the quantization settings and validate them for Marlin.

        Args:
            strategy: Quantization strategy name; must be "channel" when
                ``group_size`` is not given.
            num_bits: Bit-width of each quantized weight.
            group_size: Quantization group size, or ``None`` for
                channelwise quantization.

        Raises:
            ValueError: If ``group_size`` is ``None`` and ``strategy`` is
                not "channel".
        """
        self.num_bits = num_bits
        # Number of quantized values stored in one int32 word.
        self.pack_factor = 32 // self.num_bits
        self.strategy = strategy

        self.group_size: int
        if group_size is None:
            if self.strategy != "channel":
                raise ValueError(
                    "Marlin kernels require group quantization or "
                    "channelwise quantization, but found no group "
                    "size and strategy is not channelwise.")
            # -1 is the sentinel for channelwise quantization.
            self.group_size = -1
        else:
            self.group_size = group_size

        # Verify supported on platform.
        verify_marlin_supported(num_bits=self.num_bits,
                                group_size=self.group_size,
                                is_sym=True)

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register the packed weight, scale and shape parameters on ``layer``.

        Args:
            layer: Module that receives the parameters and size attributes.
            input_size: Full (unsharded) input dimension.
            output_partition_sizes: Output sizes of the logical weights held
                by this partition; their sum is the partition's output size.
            input_size_per_partition: Input dimension held by this partition.
            params_dtype: dtype used for the weight scales.
            weight_loader: Loader callable attached to every parameter.
            **kwargs: Accepted for interface compatibility; unused.
        """
        output_size_per_partition = sum(output_partition_sizes)

        # If group_size is -1, we are in channelwise case: one group spans
        # the whole input dimension.
        channelwise = self.group_size == -1
        group_size = input_size if channelwise else self.group_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        # Scales are only split along the input dimension when the layer
        # itself is input-sharded AND quantization is grouped. The previous
        # guard (`self.group_size is not None`) was always true because
        # __init__ maps None to -1, so channelwise scales on an
        # input-sharded layer were sized
        # `input_size_per_partition // input_size == 0`.
        row_parallel = input_size != input_size_per_partition
        partition_scales = row_parallel and not channelwise

        if partition_scales:
            weight_scale_dim = 1
            scales_and_zp_size = input_size_per_partition // group_size
        else:
            weight_scale_dim = None
            scales_and_zp_size = input_size // group_size

        weight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            weight, {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": self.pack_factor,
                "weight_loader": weight_loader
            })
        layer.register_parameter("weight_packed", weight)

        weight_scale = Parameter(
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            weight_scale, {
                "weight_loader": weight_loader,
                "input_dim": weight_scale_dim,
                "output_dim": 0
            })
        layer.register_parameter("weight_scale", weight_scale)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
                                 requires_grad=False)
        layer.register_parameter("weight_shape", weight_shape)
        set_weight_attrs(weight_shape, {
            "weight_loader": weight_loader,
            "ignore_warning": True,
        })

        # Sizes are stashed on the layer for the repacking and apply steps.
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.group_size = group_size

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from marlin format. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded weights and scales on ``layer`` to Marlin layout.

        Also attaches the Marlin workspace and empty ``g_idx`` tensors that
        ``apply_weights`` passes to the kernel.
        """
        device = layer.weight_packed.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Act-order not supported in compressed-tensors yet, so set to empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # Repack weights from compressed-tensors format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.weight_packed.t().contiguous(),
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.num_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.weight_scale.squeeze().t().contiguous(),
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the Marlin linear kernel on ``x`` using the tensors prepared
        on ``layer``, adding ``bias`` when one is given."""
        return apply_marlin_linear(
            input=x,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.num_bits,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True,
            bias=bias)
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
705f6a35
...
@@ -6,6 +6,15 @@ from pydantic import BaseModel, Field
...
@@ -6,6 +6,15 @@ from pydantic import BaseModel, Field
from
torch.nn
import
Module
from
torch.nn
import
Module
class
CompressionFormat
(
Enum
):
dense
=
"dense"
sparse_bitmask
=
"sparse-bitmask"
float_quantized
=
"float-quantized"
int_quantized
=
"int-quantized"
pack_quantized
=
"pack-quantized"
marlin_24
=
"marlin-24"
class
QuantizationType
(
str
,
Enum
):
class
QuantizationType
(
str
,
Enum
):
"""
"""
Enum storing quantization type options
Enum storing quantization type options
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
705f6a35
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch.nn
import
Module
from
torch.nn
import
Module
...
@@ -6,10 +6,18 @@ from torch.nn.parameter import Parameter
...
@@ -6,10 +6,18 @@ from torch.nn.parameter import Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
fused_moe
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
@@ -17,24 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
...
@@ -17,24 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
cutlass_fp8_supported
()
->
bool
:
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
major
,
minor
=
torch
.
version
.
cuda
.
split
(
"."
)
version
=
int
(
major
)
*
10
+
int
(
minor
)
# CUTLASS FP8 kernels need at least
# CUDA 12.0 on SM90 systems (Hopper)
# CUDA 12.4 on SM89 systems (Lovelace)
gpu_is_supported
=
False
if
capability
>=
90
:
gpu_is_supported
=
version
>
120
elif
capability
>=
89
:
gpu_is_supported
=
version
>
124
return
gpu_is_supported
class
Fp8Config
(
QuantizationConfig
):
class
Fp8Config
(
QuantizationConfig
):
"""Config class for FP8."""
"""Config class for FP8."""
...
@@ -62,7 +52,7 @@ class Fp8Config(QuantizationConfig):
...
@@ -62,7 +52,7 @@ class Fp8Config(QuantizationConfig):
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
8
9
return
8
0
@
classmethod
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
@@ -82,7 +72,9 @@ class Fp8Config(QuantizationConfig):
...
@@ -82,7 +72,9 @@ class Fp8Config(QuantizationConfig):
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
Fp8LinearMethod
(
self
)
return
Fp8LinearMethod
(
self
)
if
isinstance
(
layer
,
Attention
):
elif
isinstance
(
layer
,
FusedMoE
):
return
Fp8MoEMethod
(
self
)
elif
isinstance
(
layer
,
Attention
):
return
Fp8KVCacheMethod
(
self
)
return
Fp8KVCacheMethod
(
self
)
return
None
return
None
...
@@ -112,23 +104,11 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -112,23 +104,11 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
def
_create_scale_param
(
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
self
,
# kernel for fast weight-only FP8 quantization
scale_name
:
str
,
capability
=
current_platform
.
get_device_capability
()
layer
:
torch
.
nn
.
Module
,
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
output_partition_sizes
:
List
[
int
],
self
.
use_marlin
=
capability
<
89
**
extra_weight_attrs
,
)
->
None
:
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
scale_name
,
scale
)
set_weight_attrs
(
scale
,
{
**
extra_weight_attrs
,
"fp8_scales_shard_indexer"
:
self
.
scales_shard_indexer
,
})
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -143,9 +123,12 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -143,9 +123,12 @@ class Fp8LinearMethod(LinearMethodBase):
del
input_size
,
output_size
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
process_after_load
=
True
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
# WEIGHT
weight_dtype
=
(
torch
.
float8_e4m3fn
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
...
@@ -165,129 +148,255 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -165,129 +148,255 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
# WEIGHT SCALE
self
.
_create_scale_param
(
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
scale_name
=
"weight_scale"
,
**
extra_weight_attrs
)
layer
=
layer
,
layer
.
register_parameter
(
"weight_scale"
,
scale
)
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
# INPUT ACTIVATION SCALE
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
self
.
_create_scale_param
(
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
scale_name
=
"input_scale"
,
**
extra_weight_attrs
)
layer
=
layer
,
layer
.
register_parameter
(
"input_scale"
,
scale
)
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
def
scales_shard_indexer
(
self
,
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
Union
[
str
,
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
qkv_idxs
=
{
"q"
:
0
,
"k"
:
1
,
"v"
:
2
}
if
isinstance
(
shard_id
,
int
):
pass
elif
isinstance
(
shard_id
,
str
):
if
shard_id
not
in
qkv_idxs
:
raise
ValueError
(
f
"Unknown shard_id:
{
shard_id
}
"
)
shard_id
=
qkv_idxs
[
shard_id
]
else
:
ValueError
(
f
"Shard id must be int or str but got
{
type
(
shard_id
)
}
"
)
return
param
[
shard_id
],
loaded_weight
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
(
not
hasattr
(
layer
,
"process_after_load"
)
# If checkpoint not serialized fp8, quantize the weights.
or
not
layer
.
process_after_load
):
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
scale
=
None
)
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
input_scale
=
None
layer
.
input_scale
=
None
return
# If checkpoint is fp8, requantize the separately quantized logical
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
# weights into a single fp8 weight with a single weight scale.
else
:
else
:
# WEIGHT_SCALE / WEIGHT
# Dequant -> Quant with max scale.
# Loop over logical weights, requantizing with single scale.
max_w_scale
,
weight
=
requantize_with_max_scale
(
max_w_scale
=
layer
.
weight_scale
.
max
()
weight
=
layer
.
weight
,
start
=
0
weight_scale
=
layer
.
weight_scale
,
for
idx
,
logical_width
in
enumerate
(
layer
.
logical_widths
):
logical_widths
=
layer
.
logical_widths
,
end
=
start
+
logical_width
)
weight_dq
=
per_tensor_dequantize
(
layer
.
weight
[
start
:
end
,
:],
layer
.
weight_scale
[
idx
])
layer
.
weight
[
start
:
end
,
:]
=
per_tensor_quantize
(
weight_dq
,
layer
.
weight_scale
.
max
())
start
=
end
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# WEIGHT
# Update layer with new values.
# Transpose weight for passing to torch._scaled_mm
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the input_scales (since they are equal).
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
layer
.
input_scale
=
None
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
all_close_1d
(
layer
.
input_scale
):
raise
ValueError
(
"All the input_scales for the logical weights of a "
f
"layer must be equal. But got
{
layer
.
input_scale
}
"
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
else
:
else
:
raise
ValueError
(
layer
.
input_scale
=
None
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
# Activations not quantized for marlin.
del
layer
.
input_scale
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
if
self
.
use_marlin
:
# If dynamic, layer.input_scale is None and x_scale computed from x.
return
apply_fp8_marlin_linear
(
# If static, layer.input_scale is scalar and x_scale is input_scale.
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
def
__init__
(
self
,
quant_config
:
Fp8Config
)
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
input_scale
)
self
.
quant_config
=
quant_config
# Fused GEMM_DQ
def
create_weights
(
self
,
layer
:
Module
,
num_experts
:
int
,
hidden_size
:
int
,
output
=
ops
.
cutlass_scaled_mm_dq
(
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
qinput
,
**
extra_weight_attrs
):
layer
.
weight
,
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
)
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_scale"
,
w13_scale
)
w2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_scale"
,
w2_scale
)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
a13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a13_scale"
,
a13_scale
)
set_weight_attrs
(
a13_scale
,
extra_weight_attrs
)
a2_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"a2_scale"
,
a2_scale
)
set_weight_attrs
(
a2_scale
,
extra_weight_attrs
)
else
:
else
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
a13_scale
=
None
layer
.
input_scale
,
layer
.
a2_scale
=
None
batch_dim_padding
=
17
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# If checkpoint is fp16, quantize in place.
# batch dimension > 16. Note that this could change
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# in the future.
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
output
,
_
=
torch
.
_scaled_mm
(
dtype
=
torch
.
float8_e4m3fn
)
qinput
,
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
layer
.
weight
,
dtype
=
torch
.
float8_e4m3fn
)
out_dtype
=
x
.
dtype
,
scale_a
=
x_scale
,
# Re-initialize w13_scale because we directly quantize
scale_b
=
layer
.
weight_scale
,
# merged w13 weights and generate a single scaling factor.
bias
=
bias
,
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
)
layer
.
num_experts
,
dtype
=
torch
.
float32
,
device
=
w13_weight
.
device
),
requires_grad
=
False
)
for
expert
in
range
(
layer
.
num_experts
):
w13_weight
[
expert
,
:,
:],
layer
.
w13_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w13_weight
.
data
[
expert
,
:,
:])
w2_weight
[
expert
,
:,
:],
layer
.
w2_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
layer
.
w2_weight
.
data
[
expert
,
:,
:])
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
layer
.
a13_scale
is
None
or
layer
.
a2_scale
is
None
:
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
layer
.
a13_scale
)
or
not
all_close_1d
(
layer
.
a2_scale
)):
print_warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer
.
a13_scale
=
torch
.
nn
.
Parameter
(
layer
.
a13_scale
.
max
(),
requires_grad
=
False
)
layer
.
a2_scale
=
torch
.
nn
.
Parameter
(
layer
.
a2_scale
.
max
(),
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert
layer
.
w13_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
layer
.
w13_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
return
return
torch
.
narrow
(
output
,
0
,
0
,
x
.
shape
[
0
])
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
inplace
=
True
,
use_fp8
=
True
,
w1_scale
=
layer
.
w13_scale
,
w2_scale
=
layer
.
w2_scale
,
a1_scale
=
layer
.
a13_scale
,
a2_scale
=
layer
.
a2_scale
,
use_grouped_topk
=
use_grouped_topk
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
class
Fp8KVCacheMethod
(
QuantizeMethodBase
):
class
Fp8KVCacheMethod
(
QuantizeMethodBase
):
...
@@ -326,23 +435,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
...
@@ -326,23 +435,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
"cause accuracy issues. Please make sure kv-cache scaling "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint."
)
"factor is available in the fp8 checkpoint."
)
del
layer
.
kv_scale
del
layer
.
kv_scale
def
all_close_1d
(
x
:
torch
.
Tensor
)
->
bool
:
assert
len
(
x
.
shape
)
==
1
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
Union
[
float
,
torch
.
Tensor
])
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
qweight
=
(
tensor
/
inv_scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
qweight
.
to
(
torch
.
float8_e4m3fn
)
def
per_tensor_dequantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
Union
[
float
,
torch
.
Tensor
])
->
torch
.
Tensor
:
fake_qweight
=
tensor
.
to
(
torch
.
float16
)
dq_weight
=
fake_qweight
*
inv_scale
return
dq_weight
vllm/model_executor/layers/quantization/gptq.py
View file @
705f6a35
...
@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
...
@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -24,10 +25,12 @@ class GPTQConfig(QuantizationConfig):
...
@@ -24,10 +25,12 @@ class GPTQConfig(QuantizationConfig):
weight_bits
:
int
,
weight_bits
:
int
,
group_size
:
int
,
group_size
:
int
,
desc_act
:
bool
,
desc_act
:
bool
,
lm_head_quantized
:
bool
,
)
->
None
:
)
->
None
:
self
.
weight_bits
=
weight_bits
self
.
weight_bits
=
weight_bits
self
.
group_size
=
group_size
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
desc_act
=
desc_act
self
.
lm_head_quantized
=
lm_head_quantized
self
.
pack_factor
=
Fraction
(
32
,
self
.
weight_bits
)
self
.
pack_factor
=
Fraction
(
32
,
self
.
weight_bits
)
if
self
.
weight_bits
not
in
[
2
,
3
,
4
,
8
]:
if
self
.
weight_bits
not
in
[
2
,
3
,
4
,
8
]:
raise
ValueError
(
raise
ValueError
(
...
@@ -37,7 +40,8 @@ class GPTQConfig(QuantizationConfig):
...
@@ -37,7 +40,8 @@ class GPTQConfig(QuantizationConfig):
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQConfig(weight_bits=
{
self
.
weight_bits
}
, "
return
(
f
"GPTQConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
)"
)
f
"desc_act=
{
self
.
desc_act
}
),"
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
"
)
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -61,11 +65,14 @@ class GPTQConfig(QuantizationConfig):
...
@@ -61,11 +65,14 @@ class GPTQConfig(QuantizationConfig):
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
)
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
lm_head_quantized
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQLinearMethod"
]:
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
GPTQLinearMethod
(
self
)
return
GPTQLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
705f6a35
import
enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
...
@@ -11,90 +9,43 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
...
@@ -11,90 +9,43 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
GPTQ_MARLIN_TILE
=
16
GPTQ_MARLIN_MIN_THREAD_N
=
64
GPTQ_MARLIN_MIN_THREAD_K
=
128
GPTQ_MARLIN_MAX_PARALLEL
=
16
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
# Permutations for Marlin scale shuffling
def
get_scale_perms
(
num_bits
):
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
def
get_pack_factor
(
num_bits
):
assert
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
),
f
"Unsupported num_bits =
{
num_bits
}
"
return
32
//
num_bits
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
num_bits
):
scale_perm
,
scale_perm_single
=
get_scale_perms
(
num_bits
)
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm_single
)))[:,
scale_perm_single
]
s
=
s
.
reshape
((
-
1
,
size_n
)).
contiguous
()
return
s
class
GPTQMarlinConfig
(
QuantizationConfig
):
class
GPTQMarlinConfig
(
QuantizationConfig
):
"""Config class for GPTQ Marlin"""
"""Config class for GPTQ Marlin"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
)
->
None
:
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
# (since we have only one group per output channel)
desc_act
=
False
desc_act
=
False
self
.
weight_bits
=
weight_bits
self
.
weight_bits
=
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
group_size
=
group_size
self
.
desc_act
=
desc_act
self
.
desc_act
=
desc_act
self
.
is_sym
=
is_sym
self
.
is_sym
=
is_sym
self
.
lm_head_quantized
=
lm_head_quantized
# Verify
# Verify supported on platform.
if
self
.
weight_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
verify_marlin_supported
(
num_bits
=
self
.
weight_bits
,
raise
ValueError
(
group_size
=
self
.
group_size
,
f
"Marlin does not support weight_bits =
{
self
.
weight_bits
}
. "
is_sym
=
self
.
is_sym
)
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
"are supported."
)
if
self
.
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
:
raise
ValueError
(
f
"Marlin does not support group_size =
{
self
.
group_size
}
. "
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
"are supported."
)
if
self
.
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
raise
ValueError
(
f
"Marlin does not support is_sym =
{
self
.
is_sym
}
. "
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
# Init
self
.
pack_factor
=
get_pack_factor
(
weight_bits
)
self
.
tile_size
=
GPTQ_MARLIN_TILE
self
.
min_thread_n
=
GPTQ_MARLIN_MIN_THREAD_N
self
.
min_thread_k
=
GPTQ_MARLIN_MIN_THREAD_K
self
.
max_parallel
=
GPTQ_MARLIN_MAX_PARALLEL
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
return
(
f
"GPTQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"desc_act=
{
self
.
desc_act
}
)"
)
f
"desc_act=
{
self
.
desc_act
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -118,7 +69,10 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -118,7 +69,10 @@ class GPTQMarlinConfig(QuantizationConfig):
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
desc_act
=
cls
.
get_from_keys
(
config
,
[
"desc_act"
])
is_sym
=
cls
.
get_from_keys
(
config
,
[
"sym"
])
is_sym
=
cls
.
get_from_keys
(
config
,
[
"sym"
])
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
)
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
desc_act
,
is_sym
,
lm_head_quantized
)
@
classmethod
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
...
@@ -143,7 +97,8 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -143,7 +97,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def
get_quant_method
(
def
get_quant_method
(
self
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
GPTQMarlinLinearMethod
(
self
)
return
GPTQMarlinLinearMethod
(
self
)
return
None
return
None
...
@@ -163,21 +118,10 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -163,21 +118,10 @@ class GPTQMarlinConfig(QuantizationConfig):
or
desc_act
is
None
):
or
desc_act
is
None
):
return
False
return
False
# If the capability of the device is too low, cannot convert.
return
check_marlin_supported
(
num_bits
=
num_bits
,
major
,
minor
=
torch
.
cuda
.
get_device_capability
()
group_size
=
group_size
,
device_capability
=
major
*
10
+
minor
is_sym
=
sym
,
if
device_capability
<
cls
.
get_min_capability
():
min_capability
=
cls
.
get_min_capability
())
return
False
# Otherwise, can convert if model satisfies marlin constraints.
return
(
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and
sym
in
GPTQ_MARLIN_SUPPORTED_SYM
)
class
GPTQMarlinState
(
Enum
):
REPACK
=
enum
.
auto
()
READY
=
enum
.
auto
()
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
class
GPTQMarlinLinearMethod
(
LinearMethodBase
):
...
@@ -201,6 +145,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -201,6 +145,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
**
extra_weight_attrs
,
**
extra_weight_attrs
,
)
->
None
:
)
->
None
:
del
output_size
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
is_row_parallel
=
input_size
!=
input_size_per_partition
# Normalize group_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
if
self
.
quant_config
.
group_size
!=
-
1
:
...
@@ -208,58 +154,25 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -208,58 +154,25 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else
:
else
:
group_size
=
input_size
group_size
=
input_size
# Validate dtype
verify_marlin_supports_shape
(
if
params_dtype
not
in
[
torch
.
float16
,
torch
.
bfloat16
]:
output_size_per_partition
=
output_size_per_partition
,
raise
ValueError
(
f
"The params dtype must be float16 "
input_size_per_partition
=
input_size_per_partition
,
f
"or bfloat16, but got
{
params_dtype
}
"
)
input_size
=
input_size
,
group_size
=
group_size
)
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
# Determine sharding
if
output_size_per_partition
%
self
.
quant_config
.
min_thread_n
!=
0
:
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
raise
ValueError
(
self
.
quant_config
.
group_size
,
f
"Weight output_size_per_partition = "
is_row_parallel
):
f
"
{
output_size_per_partition
}
is not divisible by "
# By setting scale_dim == None, weight_loader will
f
" min_thread_n =
{
self
.
quant_config
.
min_thread_n
}
."
)
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim
=
None
# Validate input_size_per_partition
scales_and_zp_size
=
input_size
//
group_size
if
input_size_per_partition
%
self
.
quant_config
.
min_thread_k
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible "
f
"by min_thread_k =
{
self
.
quant_config
.
min_thread_k
}
."
)
if
(
group_size
<
input_size
and
input_size_per_partition
%
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition =
{
input_size_per_partition
}
"
f
" is not divisible by group_size =
{
group_size
}
."
)
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size
=
input_size
//
group_size
scales_and_zp_input_dim
=
None
if
self
.
quant_config
.
desc_act
:
# Act-order case
assert
self
.
quant_config
.
group_size
!=
-
1
is_k_full
=
input_size_per_partition
==
input_size
else
:
else
:
# No act-order case
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
# K is always full due to full alignment with
scales_and_zp_input_dim
=
0
# group-size and shard of scales/zp
scales_and_zp_size
=
input_size_per_partition
//
group_size
is_k_full
=
True
# If this is a row-parallel case, then shard scales/zp
if
(
input_size
!=
input_size_per_partition
and
self
.
quant_config
.
group_size
!=
-
1
):
scales_and_zp_size
=
input_size_per_partition
//
group_size
scales_and_zp_input_dim
=
0
# Init buffers
# Quantized weights
# Quantized weights
qweight
=
Parameter
(
qweight
=
Parameter
(
...
@@ -298,11 +211,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -298,11 +211,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
},
},
)
)
g_idx_sort_indices
=
torch
.
empty
(
g_idx
.
shape
,
dtype
=
torch
.
int32
,
)
# Scales
# Scales
scales
=
Parameter
(
scales
=
Parameter
(
torch
.
empty
(
torch
.
empty
(
...
@@ -342,25 +250,52 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -342,25 +250,52 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
},
},
)
)
# Allocate marlin workspace
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_thread_n
)
*
self
.
quant_config
.
max_parallel
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
requires_grad
=
False
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
layer
.
workspace
=
workspace
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
input_size
=
input_size
layer
.
is_k_full
=
is_k_full
layer
.
is_k_full
=
marlin_is_k_full
(
self
.
quant_config
.
desc_act
,
layer
.
marlin_state
=
GPTQMarlinState
.
REPACK
is_row_parallel
)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Handle sorting for activation reordering if needed.
if
self
.
quant_config
.
desc_act
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"g_idx"
,
g_idx
)
else
:
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# Repack weights from autogptq format to marlin format.
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -368,90 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -368,90 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
return
apply_marlin_linear
(
input
=
x
,
size_m
=
reshaped_x
.
shape
[
0
]
weight
=
layer
.
qweight
,
part_size_n
=
layer
.
output_size_per_partition
weight_scale
=
layer
.
scales
,
part_size_k
=
layer
.
input_size_per_partition
g_idx
=
layer
.
g_idx
,
full_size_k
=
layer
.
input_size
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
out_shape
=
x
.
shape
[:
-
1
]
+
(
part_size_n
,
)
num_bits
=
self
.
quant_config
.
weight_bits
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
if
layer
.
marlin_state
==
GPTQMarlinState
.
REPACK
:
input_size_per_partition
=
layer
.
input_size_per_partition
,
layer
.
marlin_state
=
GPTQMarlinState
.
READY
is_k_full
=
layer
.
is_k_full
,
bias
=
bias
)
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def
replace_tensor
(
name
,
new_t
):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
del
new_t
cur_device
=
layer
.
qweight
.
device
# Process act_order
if
self
.
quant_config
.
desc_act
:
# Get sorting based on g_idx
g_idx_sort_indices
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
sorted_g_idx
=
layer
.
g_idx
[
g_idx_sort_indices
]
replace_tensor
(
"g_idx"
,
sorted_g_idx
)
replace_tensor
(
"g_idx_sort_indices"
,
g_idx_sort_indices
)
else
:
# Reset g_idx related tensors
layer
.
g_idx
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
layer
.
g_idx_sort_indices
=
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
cur_device
),
requires_grad
=
False
,
)
# Repack weights
marlin_qweight
=
ops
.
gptq_marlin_repack
(
layer
.
qweight
,
layer
.
g_idx_sort_indices
,
part_size_k
,
part_size_n
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"qweight"
,
marlin_qweight
)
# Permute scales
scales_size_k
=
part_size_k
scales_size_n
=
part_size_n
if
self
.
quant_config
.
desc_act
:
scales_size_k
=
full_size_k
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
scales_size_k
,
scales_size_n
,
self
.
quant_config
.
group_size
,
self
.
quant_config
.
weight_bits
,
)
replace_tensor
(
"scales"
,
marlin_scales
)
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
g_idx_sort_indices
,
layer
.
workspace
,
self
.
quant_config
.
weight_bits
,
size_m
,
part_size_n
,
part_size_k
,
layer
.
is_k_full
,
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
# In-place add
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/marlin.py
View file @
705f6a35
...
@@ -8,6 +8,7 @@ from vllm.logger import init_logger
...
@@ -8,6 +8,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
...
@@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
def
__init__
(
def
__init__
(
self
,
self
,
group_size
:
int
,
group_size
:
int
,
lm_head_quantized
:
bool
,
)
->
None
:
)
->
None
:
# Group size for the quantization.
# Group size for the quantization.
self
.
group_size
=
group_size
self
.
group_size
=
group_size
self
.
lm_head_quantized
=
lm_head_quantized
if
self
.
group_size
!=
128
and
self
.
group_size
!=
-
1
:
if
self
.
group_size
!=
128
and
self
.
group_size
!=
-
1
:
raise
ValueError
(
raise
ValueError
(
"Currently, only group size 128 and -1 (channelwise) "
"Currently, only group size 128 and -1 (channelwise) "
...
@@ -51,7 +54,8 @@ class MarlinConfig(QuantizationConfig):
...
@@ -51,7 +54,8 @@ class MarlinConfig(QuantizationConfig):
self
.
perm_len
=
1024
self
.
perm_len
=
1024
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
f
"MarlinConfig(group_size=
{
self
.
group_size
}
)"
return
(
f
"MarlinConfig(group_size=
{
self
.
group_size
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -73,7 +77,9 @@ class MarlinConfig(QuantizationConfig):
...
@@ -73,7 +77,9 @@ class MarlinConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"MarlinConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"MarlinConfig"
:
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
return
cls
(
group_size
)
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
group_size
,
lm_head_quantized
)
@
classmethod
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
...
@@ -96,7 +102,8 @@ class MarlinConfig(QuantizationConfig):
...
@@ -96,7 +102,8 @@ class MarlinConfig(QuantizationConfig):
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"MarlinLinearMethod"
]:
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"MarlinLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
MarlinLinearMethod
(
self
)
return
MarlinLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
705f6a35
...
@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
...
@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
return
[
torch
.
half
]
def
get_min_capability
(
self
)
->
int
:
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
70
return
70
@
staticmethod
@
staticmethod
...
...
vllm/model_executor/layers/quantization/utils/marlin_24_perms.py
deleted
100644 → 0
View file @
af837396
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms_24
(
num_bits
):
perm_list
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
col_o
=
col
//
2
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col_o
*
256
+
8
*
(
col
%
2
)
+
4
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
1
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
ValueError
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
*
8
+
j
for
j
in
[
0
,
4
,
1
,
5
,
2
,
6
,
3
,
7
]])
scale_perm_single
=
[]
for
i
in
range
(
8
):
scale_perm_single
.
extend
([
8
*
i
+
j
for
j
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_24_perm
=
{}
marlin_24_scale_perm
=
{}
marlin_24_scale_perm_single
=
{}
for
num_bits
in
[
4
,
8
]:
perm_24
,
scale_perm_24
,
scale_perm_single_24
=
get_perms_24
(
num_bits
)
marlin_24_perm
[
num_bits
]
=
perm_24
marlin_24_scale_perm
[
num_bits
]
=
scale_perm_24
marlin_24_scale_perm_single
[
num_bits
]
=
scale_perm_single_24
vllm/model_executor/layers/quantization/utils/marlin_perms.py
deleted
100644 → 0
View file @
af837396
"""This file is used for /tests and /benchmarks"""
import
numpy
import
torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def
get_perms
(
num_bits
):
perm_list
=
[]
for
i
in
range
(
32
):
perm1
=
[]
col
=
i
//
4
for
block
in
[
0
,
1
]:
for
row
in
[
2
*
(
i
%
4
),
2
*
(
i
%
4
)
+
1
,
2
*
(
i
%
4
+
4
),
2
*
(
i
%
4
+
4
)
+
1
,
]:
perm1
.
append
(
16
*
row
+
col
+
8
*
block
)
for
j
in
range
(
4
):
perm_list
.
extend
([
p
+
256
*
j
for
p
in
perm1
])
perm
=
numpy
.
array
(
perm_list
)
if
num_bits
==
4
:
interleave
=
numpy
.
array
([
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
])
elif
num_bits
==
8
:
interleave
=
numpy
.
array
([
0
,
2
,
1
,
3
])
else
:
raise
Exception
(
"num_bits must be 4 or 8, got {}"
.
format
(
num_bits
))
perm
=
perm
.
reshape
((
-
1
,
len
(
interleave
)))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
scale_perm
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
perm
,
scale_perm
,
scale_perm_single
marlin_perm
=
{}
marlin_scale_perm
=
{}
marlin_scale_perm_single
=
{}
for
num_bits
in
[
4
,
8
]:
perm
,
scale_perm
,
scale_perm_single
=
get_perms
(
num_bits
)
marlin_perm
[
num_bits
]
=
perm
marlin_scale_perm
[
num_bits
]
=
scale_perm
marlin_scale_perm_single
[
num_bits
]
=
scale_perm_single
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
705f6a35
"""This file is used for /tests and /benchmarks"""
from
typing
import
List
,
Optional
,
Tuple
import
random
import
numpy
import
torch
import
torch
from
vllm.model_executor.layers.quantization.utils.format_24
import
(
from
vllm
import
_custom_ops
as
ops
mask_creator
,
sparse_semi_structured_from_dense_cutlass
)
from
vllm.platforms
import
current_platform
from
vllm.model_executor.layers.quantization.utils.marlin_24_perms
import
(
marlin_24_perm
,
marlin_24_scale_perm
,
marlin_24_scale_perm_single
)
GPTQ_MARLIN_TILE
=
16
from
vllm.model_executor.layers.quantization.utils.marlin_perms
import
(
GPTQ_MARLIN_MIN_THREAD_N
=
64
marlin_perm
,
marlin_scale_perm
,
marlin_scale_perm_single
)
GPTQ_MARLIN_MIN_THREAD_K
=
128
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
=
16
get_pack_factor
,
quantize_weights
,
sort_weights
)
GPTQ_MARLIN_SUPPORTED_NUM_BITS
=
[
4
,
8
]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
GPTQ_MARLIN_SUPPORTED_SYM
=
[
True
]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER
=
[
-
1
]
__cuda_arch
=
torch
.
cuda
.
get_device_capability
()
MARLIN_TILE
=
16
def
check_marlin_supported
(
num_bits
:
int
,
group_size
:
int
,
is_sym
:
bool
,
min_capability
:
int
)
->
bool
:
def
is_marlin_supported
():
# If the capability of the device is too low, cannot convert.
return
__cuda_arch
[
0
]
>=
8
major
,
minor
=
current_platform
.
get_device_capability
()
device_capability
=
major
*
10
+
minor
if
device_capability
<
min_capability
:
def
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
,
tile
=
MARLIN_TILE
):
return
False
assert
q_w
.
shape
==
(
size_k
,
size_n
)
assert
size_k
%
tile
==
0
,
f
"size_k =
{
size_k
}
, tile =
{
tile
}
"
return
(
device_capability
>=
min_capability
assert
size_n
%
tile
==
0
,
f
"size_k =
{
size_n
}
, tile =
{
tile
}
"
and
num_bits
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
and
group_size
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
# Permute weights to 16x64 marlin tiles
and
is_sym
in
GPTQ_MARLIN_SUPPORTED_SYM
)
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
tile
,
size_n
//
tile
,
tile
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
tile
,
size_n
*
tile
))
def
verify_marlin_supported
(
num_bits
:
int
,
group_size
:
Optional
[
int
],
is_sym
:
bool
)
->
None
:
q_w
=
q_w
.
reshape
((
-
1
,
perm
.
numel
()))[:,
perm
].
reshape
(
q_w
.
shape
)
if
num_bits
not
in
GPTQ_MARLIN_SUPPORTED_NUM_BITS
:
return
q_w
raise
ValueError
(
f
"Marlin does not support weight_bits =
{
num_bits
}
. "
f
"Only weight_bits =
{
GPTQ_MARLIN_SUPPORTED_NUM_BITS
}
"
def
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
perm
):
"are supported."
)
# Permute
if
(
group_size
is
None
q_w
=
marlin_permute_weights
(
q_w
,
size_k
,
size_n
,
perm
)
or
group_size
not
in
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
):
raise
ValueError
(
# Pack
f
"Marlin does not support group_size =
{
group_size
}
. "
pack_factor
=
get_pack_factor
(
num_bits
)
f
"Only group_sizes =
{
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
}
"
orig_device
=
q_w
.
device
"are supported."
)
if
is_sym
not
in
GPTQ_MARLIN_SUPPORTED_SYM
:
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
numpy
.
uint32
)
raise
ValueError
(
f
"Marlin does not support is_sym = is_sym. "
q_packed
=
numpy
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
f
"Only sym =
{
GPTQ_MARLIN_SUPPORTED_SYM
}
are supported."
)
dtype
=
numpy
.
uint32
)
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
num_bits
*
i
def
verify_marlin_supports_shape
(
output_size_per_partition
:
int
,
input_size_per_partition
:
int
,
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
numpy
.
int32
)).
to
(
orig_device
)
input_size
:
int
,
group_size
:
int
)
->
None
:
return
q_packed
# Validate output_size_per_partition
if
output_size_per_partition
%
GPTQ_MARLIN_MIN_THREAD_N
!=
0
:
raise
ValueError
(
f
"Weight output_size_per_partition = "
def
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
scale_perm
,
f
"
{
output_size_per_partition
}
is not divisible by "
scale_perm_single
):
f
" min_thread_n =
{
GPTQ_MARLIN_MIN_THREAD_N
}
. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
# Validate input_size_per_partition
if
input_size_per_partition
%
GPTQ_MARLIN_MIN_THREAD_K
!=
0
:
raise
ValueError
(
f
"Weight input_size_per_partition = "
f
"
{
input_size_per_partition
}
is not divisible "
f
"by min_thread_k =
{
GPTQ_MARLIN_MIN_THREAD_K
}
. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
if
(
group_size
<
input_size
and
input_size_per_partition
%
group_size
!=
0
):
raise
ValueError
(
f
"Weight input_size_per_partition =
{
input_size_per_partition
}
"
f
" is not divisible by group_size =
{
group_size
}
."
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq."
)
def
marlin_make_workspace
(
output_size_per_partition
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
max_workspace_size
=
(
output_size_per_partition
//
GPTQ_MARLIN_MIN_THREAD_N
)
*
GPTQ_MARLIN_MAX_PARALLEL
return
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
def
marlin_is_k_full
(
act_order
:
bool
,
is_row_parallel
:
bool
)
->
bool
:
return
(
not
act_order
)
or
(
act_order
and
not
is_row_parallel
)
def
marlin_repeat_scales_on_all_ranks
(
act_order
:
bool
,
group_size
:
int
,
is_row_parallel
:
bool
)
->
bool
:
# Need to repeat scales on every rank if act_ordering or
# channelwise and RowParallelLinear
is_channelwise
=
group_size
==
-
1
return
act_order
or
(
is_channelwise
and
is_row_parallel
)
def
marlin_make_empty_g_idx
(
device
:
torch
.
device
)
->
torch
.
Tensor
:
return
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
device
),
requires_grad
=
False
)
def
marlin_sort_g_idx
(
g_idx
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
g_idx_sort_indices
=
torch
.
argsort
(
g_idx
).
to
(
torch
.
int
)
return
g_idx
[
g_idx_sort_indices
],
g_idx_sort_indices
def
get_scale_perms
():
scale_perm
:
List
[
int
]
=
[]
for
i
in
range
(
8
):
scale_perm
.
extend
([
i
+
8
*
j
for
j
in
range
(
8
)])
scale_perm_single
:
List
[
int
]
=
[]
for
i
in
range
(
4
):
scale_perm_single
.
extend
(
[
2
*
i
+
j
for
j
in
[
0
,
1
,
8
,
9
,
16
,
17
,
24
,
25
]])
return
scale_perm
,
scale_perm_single
def
marlin_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
)
->
torch
.
Tensor
:
scale_perm
,
scale_perm_single
=
get_scale_perms
()
if
group_size
<
size_k
and
group_size
!=
-
1
:
if
group_size
<
size_k
and
group_size
!=
-
1
:
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
s
=
s
.
reshape
((
-
1
,
len
(
scale_perm
)))[:,
scale_perm
]
else
:
else
:
...
@@ -68,157 +138,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
...
@@ -68,157 +138,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
return
s
return
s
def
marlin_quantize
(
# Newly generated tensors need to replace existing tensors that are
w
:
torch
.
Tensor
,
# already registered as parameters by vLLM (and won't be freed)
num_bits
:
int
,
def
replace_tensor
(
layer
:
torch
.
nn
.
Module
,
name
:
str
,
group_size
:
int
,
new_t
:
torch
.
Tensor
)
->
None
:
act_order
:
bool
,
# It is important to use resize_() here since it ensures
):
# the same buffer is reused
size_k
,
size_n
=
w
.
shape
getattr
(
layer
,
name
).
resize_
(
new_t
.
shape
)
getattr
(
layer
,
name
).
copy_
(
new_t
)
# Normalize group_size
del
new_t
if
group_size
==
-
1
:
group_size
=
size_k
assert
group_size
<=
size_k
def
apply_marlin_linear
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
# Quantize (and apply act_order if provided)
weight_scale
:
torch
.
Tensor
,
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
quantize_weights
(
w
,
num_bits
,
group_size
,
g_idx
:
torch
.
Tensor
,
act_order
)
g_idx_sort_indices
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
# For act_order, sort the "weights" and "g_idx" so that group ids are
num_bits
:
int
,
# increasing
output_size_per_partition
:
int
,
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w
.
device
)
input_size_per_partition
:
int
,
if
act_order
:
is_k_full
:
bool
,
q_w
,
g_idx
,
sort_indices
=
sort_weights
(
q_w
,
g_idx
)
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
reshaped_x
=
input
.
reshape
(
-
1
,
input
.
shape
[
-
1
])
# Reformat to marlin
out_shape
=
input
.
shape
[:
-
1
]
+
(
output_size_per_partition
,
)
marlin_q_w
=
marlin_weights
(
q_w
,
size_k
,
size_n
,
num_bits
,
marlin_perm
[
num_bits
])
output
=
ops
.
gptq_marlin_gemm
(
reshaped_x
,
marlin_s
=
marlin_permute_scales
(
s
,
size_k
,
size_n
,
group_size
,
weight
,
marlin_scale_perm
[
num_bits
],
weight_scale
,
marlin_scale_perm_single
[
num_bits
])
g_idx
,
g_idx_sort_indices
,
# Create result
workspace
,
res_list
=
[
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
rand_perm
]
num_bits
,
for
i
in
range
(
len
(
res_list
)):
size_m
=
reshaped_x
.
shape
[
0
],
res_list
[
i
]
=
res_list
[
i
].
to
(
w
.
device
)
size_n
=
output_size_per_partition
,
size_k
=
input_size_per_partition
,
return
res_list
is_k_full
=
is_k_full
)
if
bias
is
not
None
:
def
inject_24
(
w
,
size_k
,
size_n
):
output
.
add_
(
bias
)
# In-place add
assert
w
.
shape
==
(
size_k
,
size_n
)
return
output
.
reshape
(
out_shape
)
mask
=
mask_creator
(
w
.
t
()).
t
().
cuda
().
bool
()
return
(
mask
*
w
).
contiguous
(),
mask
.
contiguous
()
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Print how many sampled 4-wide segments of ``w`` violate 2:4 sparsity.

    ``w`` is transposed first, then ``num_rows_to_sample`` rows of the
    transposed matrix are drawn at random (with replacement). Each sampled
    row is scanned in consecutive blocks of four elements; a block with more
    than two non-zero entries is reported and counted.

    Args:
        w: 2-D tensor to inspect.
        num_rows_to_sample: Number of random row draws.
        _verbose: When True, also print the sampled row indices.

    Returns:
        None. The result is printed.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # "+ 1" makes the stop bound inclusive of the last full block; the
        # previous bound silently skipped the final 4 columns of every row.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    """Compress a quantized 2:4-sparse weight and produce its metadata.

    The zero point is removed before compression and restored afterwards.

    Args:
        q_24: Quantized weight; must have shape ``(size_k, size_n)``.
        size_k: Expected first dimension of ``q_24``.
        size_n: Expected second dimension of ``q_24``.
        num_bits: Quantization bit width; the zero point is
            ``(1 << num_bits) // 2``.

    Returns:
        ``(q_24_comp, meta)``: the compressed weight and the resized metadata
        tensor.
    """
    assert q_24.shape == (size_k, size_n)

    # Zero point sits in the middle of the quantized range.
    zp = (1 << num_bits) // 2

    # Centre on zero and hand the transposed matrix to the compressor.
    centered_t = (q_24 - zp).t().contiguous()
    compressed_t, meta = sparse_semi_structured_from_dense_cutlass(centered_t)

    # Undo the transpose and put the zero point back.
    q_24_comp = compressed_t.t().contiguous() + zp

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Inject 2:4 sparsity into ``w``, quantize it and repack it for Marlin.

    Args:
        w: 2-D weight of shape ``(size_k, size_n)``.
        num_bits: Quantization bit width.
        group_size: Quantization group size; ``-1`` means one group spanning
            all of ``size_k``.

    Returns:
        A list ``[w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]`` with
        every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity (the mask itself is not needed afterwards)
    w_24, _ = inject_24(w, size_k, size_n)

    # Quantize; only the reference weight, quantized weight and scales are used
    w_24_ref, q_w_24, s, _, _ = quantize_weights(w_24,
                                                 num_bits,
                                                 group_size,
                                                 act_order=False)

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, marlin_24_perm[num_bits])
    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                        marlin_24_scale_perm[num_bits],
                                        marlin_24_scale_perm_single[num_bits])

    # Create result
    results = (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    return [t.to(w.device) for t in results]
def compute_max_diff(output, output_ref):
    """Return mean(|output - output_ref|) divided by mean(|output_ref|)."""
    abs_error = torch.abs(output - output_ref)
    ref_magnitude = torch.abs(output_ref)
    return torch.mean(abs_error) / torch.mean(ref_magnitude)
class MarlinWorkspace:
    """Holds the zero-initialized integer scratch tensor used by Marlin tests."""

    def __init__(self, out_features, min_thread_n, max_parallel):
        """Allocate ``(out_features // min_thread_n) * max_parallel`` ints.

        ``out_features`` must be a multiple of ``min_thread_n``. The buffer is
        created on the "cuda" device and stored in ``self.scratch``.
        """
        remainder = out_features % min_thread_n
        assert remainder == 0, (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        n_slices = out_features // min_thread_n
        self.scratch = torch.zeros(n_slices * max_parallel,
                                   dtype=torch.int,
                                   device="cuda")
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
0 → 100644
View file @
705f6a35
from
typing
import
Optional
import
torch
import
vllm._custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
.marlin_utils
import
marlin_make_workspace
,
marlin_permute_scales
def is_fp8_marlin_supported():
    """Return True when the first device-capability component is at least 8."""
    major = current_platform.get_device_capability()[0]
    return major >= 8
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run a linear layer through ``ops.fp8_marlin_gemm``.

    For GPUs that lack FP8 hardware support, we can leverage the Marlin
    kernel for fast weight-only FP8 quantization.

    Args:
        input: Activations; all leading dimensions are flattened into one
            before the GEMM and restored on the result.
        weight: Weight passed as ``b_q_weight``.
        weight_scale: Scales passed as ``b_scales``.
        workspace: Scratch buffer forwarded to the kernel.
        size_n: Output feature count; becomes the last output dimension.
        size_k: Input feature count forwarded to the kernel.
        bias: Optional bias, added in place to the GEMM result.

    Returns:
        Tensor of shape ``input.shape[:-1] + (size_n,)``.
    """
    x_2d = input.reshape(-1, input.shape[-1])
    result_shape = input.shape[:-1] + (size_n, )

    gemm_out = ops.fp8_marlin_gemm(
        a=x_2d,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=x_2d.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        gemm_out.add_(bias)  # In-place add

    return gemm_out.reshape(result_shape)
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Rewrite an FP8 layer in place so it can run on the Marlin kernel.

    Attaches ``layer.workspace``, replaces ``layer.weight`` with the
    Marlin-repacked weight, and replaces ``layer.weight_scale`` with
    permuted scales expanded to one value per output channel.

    Reads ``output_size_per_partition``, ``input_size_per_partition``,
    ``weight``, ``weight_scale`` and ``orig_dtype`` from ``layer``.
    """
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    target_device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(size_n, target_device)

    # WEIGHT
    # Repack weights to marlin format (no activation-order permutation)
    packed_weight = pack_fp8_to_int32(layer.weight)
    empty_perm = torch.empty(0, dtype=torch.int, device=target_device)
    repacked_weight = ops.gptq_marlin_repack(b_q_weight=packed_weight,
                                             perm=empty_perm,
                                             size_k=size_k,
                                             size_n=size_n,
                                             num_bits=8)
    layer.weight = torch.nn.Parameter(repacked_weight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    channel_scales = layer.weight_scale.repeat(1, size_n)
    channel_scales = channel_scales.to(layer.orig_dtype).to(target_device)

    # Permute scales
    permuted_scales = marlin_permute_scales(s=channel_scales,
                                            size_k=size_k,
                                            size_n=size_n,
                                            group_size=-1)
    layer.weight_scale = torch.nn.Parameter(permuted_scales,
                                            requires_grad=False)
def
pack_fp8_to_int32
(
fp8_tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert
fp8_tensor
.
dtype
==
torch
.
float8_e4m3fn
assert
fp8_tensor
.
shape
[
0
]
%
4
==
0
# Reshape to prepare for packing
reshaped
=
fp8_tensor
.
reshape
(
-
1
,
4
,
*
fp8_tensor
.
shape
[
1
:])
# Convert fp8 to uint8 (byte) representation
byte_tensor
=
reshaped
.
view
(
torch
.
uint8
)
# Pack 4 uint8 values into one int32
packed
=
(
byte_tensor
[:,
0
].
to
(
torch
.
int32
)
|
(
byte_tensor
[:,
1
].
to
(
torch
.
int32
)
<<
8
)
|
(
byte_tensor
[:,
2
].
to
(
torch
.
int32
)
<<
16
)
|
(
byte_tensor
[:,
3
].
to
(
torch
.
int32
)
<<
24
))
return
packed
.
view
(
fp8_tensor
.
shape
[
0
]
//
4
,
*
fp8_tensor
.
shape
[
1
:]).
contiguous
()
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
0 → 100644
View file @
705f6a35
"""Utility functions used for tests and benchmarks"""
from
typing
import
List
import
numpy
import
torch
from
.marlin_utils
import
GPTQ_MARLIN_TILE
,
marlin_permute_scales
from
.quant_utils
import
get_pack_factor
,
quantize_weights
,
sort_weights
class MarlinWorkspace:
    """Holds the zero-initialized integer scratch tensor used by Marlin tests."""

    def __init__(self, out_features, min_thread_n, max_parallel):
        """Allocate ``(out_features // min_thread_n) * max_parallel`` ints.

        ``out_features`` must be a multiple of ``min_thread_n``. The buffer is
        created on the "cuda" device and stored in ``self.scratch``.
        """
        remainder = out_features % min_thread_n
        assert remainder == 0, (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        n_slices = out_features // min_thread_n
        self.scratch = torch.zeros(n_slices * max_parallel,
                                   dtype=torch.int,
                                   device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Rearrange a ``(size_k, size_n)`` weight into Marlin tile order.

    The matrix is cut into ``tile x tile`` squares which are laid out
    contiguously, then every run of ``perm.numel()`` elements is reordered
    by ``perm``.

    Args:
        q_w: Weight of shape ``(size_k, size_n)``.
        size_k: First dimension; must be a multiple of ``tile``.
        size_n: Second dimension; must be a multiple of ``tile``.
        perm: 1-D index tensor applied to each chunk of ``perm.numel()``
            elements.
        tile: Tile edge length.

    Returns:
        Tensor of shape ``(size_k // tile, size_n * tile)``.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # The message previously labelled this value "size_k", which pointed
    # debugging at the wrong dimension.
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    # Apply the intra-chunk permutation, then restore the tiled shape.
    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute a quantized weight into Marlin order and pack it into int32.

    Args:
        q_w: Quantized weight of shape ``(size_k, size_n)``.
        size_k: First dimension of ``q_w``.
        size_n: Second dimension of ``q_w``.
        num_bits: Bits per quantized value; also determines the pack factor.
        perm: Permutation forwarded to ``marlin_permute_weights``.

    Returns:
        int32 tensor on the permuted weight's device, with the last
        dimension shrunk by the pack factor.
    """
    # Permute
    permuted = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    target_device = permuted.device

    unpacked = permuted.cpu().numpy().astype(numpy.uint32)
    n_rows, n_cols = unpacked.shape[0], unpacked.shape[1]

    packed = numpy.zeros((n_rows, n_cols // pack_factor), dtype=numpy.uint32)
    # Column slot ``k`` of every group lands in bits [num_bits*k, ...).
    for slot in range(pack_factor):
        packed |= unpacked[:, slot::pack_factor] << num_bits * slot

    return torch.from_numpy(packed.astype(numpy.int32)).to(target_device)
def get_weight_perm(num_bits: int):
    """Build the 1024-entry Marlin weight permutation for 4- or 8-bit weights.

    Args:
        num_bits: Quantization bit width; must be 4 or 8.

    Returns:
        1-D tensor containing a permutation of ``range(1024)``.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    # Validate before doing any work. ValueError is a subclass of Exception,
    # so callers that caught the previous bare Exception keep working.
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        # Replicate the 8-entry pattern across the four 256-element quarters.
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)
    # Interleave within each group so packed values are in kernel order.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
                    act_order: bool):
    """Quantize ``w`` and repack the result into Marlin format.

    Args:
        w: 2-D weight of shape ``(size_k, size_n)``.
        num_bits: Quantization bit width.
        group_size: Quantization group size; ``-1`` means one group spanning
            all of ``size_k``.
        act_order: Whether to apply activation reordering.

    Returns:
        A list ``[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices,
        rand_perm]`` with every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits,
                                                       group_size, act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing; otherwise the sort indices stay empty.
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
                                get_weight_perm(num_bits))
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    results = (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    return [t.to(w.device) for t in results]
vllm/model_executor/layers/quantization/utils/
forma
t_24.py
→
vllm/model_executor/layers/quantization/utils/
marlin_utils_tes
t_24.py
View file @
705f6a35
#
"""Utility functions used for tests and benchmarks"""
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import
random
from
typing
import
List
import
numpy
import
torch
import
torch
from
.marlin_utils_test
import
marlin_weights
from
.quant_utils
import
quantize_weights
# This is PyTorch implementation of main part of reorder_meta()
# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# function, from tools/util/include/cutlass/util/host_reorder.h file
...
@@ -306,3 +311,155 @@ def mask_creator(tensor):
...
@@ -306,3 +311,155 @@ def mask_creator(tensor):
mask
=
w_b
.
scatter_
(
dim
=
1
,
index
=
index
,
value
=
0
).
reshape
(
tensor
.
shape
)
mask
=
w_b
.
scatter_
(
dim
=
1
,
index
=
index
,
value
=
0
).
reshape
(
tensor
.
shape
)
return
mask
return
mask
def inject_24(w, size_k, size_n):
    """Zero out entries of ``w`` according to the mask from ``mask_creator``.

    Args:
        w: Weight of shape ``(size_k, size_n)``.
        size_k: Expected first dimension of ``w``.
        size_n: Expected second dimension of ``w``.

    Returns:
        ``(masked_w, mask)``: both contiguous, with ``mask`` a boolean tensor
        on the same device as ``w``.
    """
    assert w.shape == (size_k, size_n)

    # Keep the mask on w's own device. The previous unconditional .cuda()
    # moved it to the current CUDA device, so the multiply below failed for
    # CPU inputs and for weights living on a non-current GPU.
    mask = mask_creator(w.t()).t().to(w.device).bool()

    return (mask * w).contiguous(), mask.contiguous()
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Print how many sampled 4-wide segments of ``w`` violate 2:4 sparsity.

    ``w`` is transposed first, then ``num_rows_to_sample`` rows of the
    transposed matrix are drawn at random (with replacement). Each sampled
    row is scanned in consecutive blocks of four elements; a block with more
    than two non-zero entries is reported and counted.

    Args:
        w: 2-D tensor to inspect.
        num_rows_to_sample: Number of random row draws.
        _verbose: When True, also print the sampled row indices.

    Returns:
        None. The result is printed.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # "+ 1" makes the stop bound inclusive of the last full block; the
        # previous bound silently skipped the final 4 columns of every row.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    """Compress a quantized 2:4-sparse weight and produce its metadata.

    The zero point is removed before compression and restored afterwards.

    Args:
        q_24: Quantized weight; must have shape ``(size_k, size_n)``.
        size_k: Expected first dimension of ``q_24``.
        size_n: Expected second dimension of ``q_24``.
        num_bits: Quantization bit width; the zero point is
            ``(1 << num_bits) // 2``.

    Returns:
        ``(q_24_comp, meta)``: the compressed weight and the resized metadata
        tensor.
    """
    assert q_24.shape == (size_k, size_n)

    # Zero point sits in the middle of the quantized range.
    zp = (1 << num_bits) // 2

    # Centre on zero and hand the transposed matrix to the compressor.
    centered_t = (q_24 - zp).t().contiguous()
    compressed_t, meta = sparse_semi_structured_from_dense_cutlass(centered_t)

    # Undo the transpose and put the zero point back.
    q_24_comp = compressed_t.t().contiguous() + zp

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
def get_scale_perms_24():
    """Build the two index permutations applied to 2:4 Marlin scales.

    Returns:
        ``(scale_perm, scale_perm_single)``: two 64-entry lists of column
        indices. The second one is the identity permutation.
    """
    lane_order = (0, 4, 1, 5, 2, 6, 3, 7)
    scale_perm: List[int] = [i * 8 + j for i in range(8) for j in lane_order]

    # 8 * i + j for j in 0..7 enumerates 0..63 in order.
    scale_perm_single: List[int] = list(range(64))

    return scale_perm, scale_perm_single
def get_weight_perm_24(num_bits: int):
    """Build the 1024-entry 2:4 Marlin weight permutation.

    Args:
        num_bits: Quantization bit width; must be 4 or 8.

    Returns:
        1-D tensor containing a permutation of ``range(1024)``.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    perm_list: List[int] = []
    for i in range(32):
        col = i // 4
        col_o = col // 2
        base_row = 2 * (i % 4)
        rows = (base_row, base_row + 1, base_row + 8, base_row + 9)

        perm1: List[int] = [
            16 * row + col_o * 256 + 8 * (col % 2) + 4 * block
            for block in (0, 1) for row in rows
        ]
        # Four consecutive copies, each shifted by one element.
        for shift in range(4):
            perm_list.extend([p + shift for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(perm)
def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
                             group_size: int) -> torch.Tensor:
    """Reorder quantization scales into the 2:4 Marlin layout.

    Args:
        s: Scale tensor.
        size_k: Input dimension of the weight the scales belong to.
        size_n: Output dimension; the result has ``size_n`` columns.
        group_size: Quantization group size. Grouped scales
            (``group_size < size_k`` and not ``-1``) use the grouped
            permutation, everything else the single permutation.

    Returns:
        Contiguous tensor of shape ``(-1, size_n)``.
    """
    scale_perm, scale_perm_single = get_scale_perms_24()

    is_grouped = group_size < size_k and group_size != -1
    chosen_perm = scale_perm if is_grouped else scale_perm_single

    s = s.reshape((-1, len(chosen_perm)))[:, chosen_perm]
    return s.reshape((-1, size_n)).contiguous()
def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Inject 2:4 sparsity into ``w``, quantize it and repack it for Marlin.

    Args:
        w: 2-D weight of shape ``(size_k, size_n)``.
        num_bits: Quantization bit width.
        group_size: Quantization group size; ``-1`` means one group spanning
            all of ``size_k``.

    Returns:
        A list ``[w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]`` with
        every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity (the mask itself is not needed afterwards)
    w_24, _ = inject_24(w, size_k, size_n)

    # Quantize; only the reference weight, quantized weight and scales are used
    w_24_ref, q_w_24, s, _, _ = quantize_weights(w_24,
                                                 num_bits,
                                                 group_size,
                                                 act_order=False)

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, get_weight_perm_24(num_bits))
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result
    results = (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    return [t.to(w.device) for t in results]
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
0 → 100644
View file @
705f6a35
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
def cutlass_fp8_supported() -> bool:
    """Ask the custom ops whether CUTLASS FP8 works on the current device.

    The two-part device capability is folded into one integer
    (``major * 10 + minor``) before it is passed to the op.
    """
    device_cap = current_platform.get_device_capability()
    cap_code = device_cap[0] * 10 + device_cap[1]
    return ops.cutlass_scaled_mm_supports_fp8(cap_code)
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    """Cast ``tensor`` to float16 and multiply it by ``inv_scale``."""
    return tensor.to(torch.float16) * inv_scale
def all_close_1d(x: torch.Tensor) -> bool:
    """Return True when every element of the 1-D tensor ``x`` is close to ``x[0]``.

    Closeness is decided by ``torch.allclose`` with its default tolerances.
    An empty tensor yields True.
    """
    assert len(x.shape) == 1
    for i in range(x.shape[0]):
        if not torch.allclose(x[0], x[i]):
            return False
    return True
def create_per_tensor_scale_param(
    output_partition_sizes: List[int],
    **extra_weight_attrs,
) -> Parameter:
    """Create a non-trainable float32 scale with one entry per partition.

    Every entry starts at the most negative float32 value, and the parameter
    is tagged with ``needs_scalar_to_array=True`` plus any extra attributes.
    """
    n_partitions = len(output_partition_sizes)
    scale = Parameter(torch.empty(n_partitions, dtype=torch.float32),
                      requires_grad=False)
    # Sentinel fill value: the most negative finite float32.
    scale[:] = torch.finfo(torch.float32).min

    attrs = {"needs_scalar_to_array": True, **extra_weight_attrs}
    set_weight_attrs(scale, attrs)
    return scale
def create_per_channel_scale_param(output_partition_sizes: List[int],
                                   **extra_weight_attrs) -> Parameter:
    """Create a non-trainable float32 scale of shape ``(sum(sizes), 1)``.

    Every entry starts at the most negative float32 value, and the parameter
    is tagged with ``output_dim=0`` plus any extra attributes.
    """
    total_channels = sum(output_partition_sizes)
    scale = Parameter(torch.empty((total_channels, 1), dtype=torch.float32),
                      requires_grad=False)
    # Sentinel fill value: the most negative finite float32.
    scale[:] = torch.finfo(torch.float32).min

    attrs = {"output_dim": 0, **extra_weight_attrs}
    set_weight_attrs(scale, attrs)
    return scale
def convert_to_channelwise(weight_scale: torch.Tensor,
                           logical_widths: List[int]) -> torch.Tensor:
    """Expand one scale per logical matrix into one scale per output channel.

    Args:
        weight_scale: Indexable tensor holding one scale per entry of
            ``logical_widths``.
        logical_widths: Number of output channels of each logical matrix.

    Returns:
        float32 tensor of shape ``(sum(logical_widths), 1)`` on
        ``weight_scale.device`` in which rows ``[start, end)`` of logical
        matrix ``idx`` all hold ``weight_scale[idx]``.
    """
    # NOTE: the return annotation used to claim a 2-tuple, but a single
    # tensor has always been returned.

    # Create channelwise buffer
    weight_scale_channel = torch.empty((sum(logical_widths), 1),
                                       dtype=torch.float32,
                                       device=weight_scale.device)

    # Expand each scale to match the size of each logical matrix.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel
def requantize_with_max_scale(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Requantize every logical shard of ``weight`` with the largest scale.

    Args:
        weight: Quantized weight; rows are grouped into shards whose heights
            are given by ``logical_widths``. Modified in place.
        weight_scale: One scale per shard.
        logical_widths: Row count of each shard.

    Returns:
        ``(max_w_scale, weight)``.
    """
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on disk checkpoint if any of the
    # weight scales are still set to the default since we initialize
    # N weight scales for N shards but we only load 1 weight scale
    # from disk in this case. Skip requantization in this case (since
    # we already are quantized with the single scale).
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    unfused_module_in_checkpoint = (weight_scale[-1] > fp8_min)

    # If unfused checkpoint, need to requantize with the single scale.
    if unfused_module_in_checkpoint:
        row_begin = 0
        for shard_idx, shard_rows in enumerate(logical_widths):
            row_end = row_begin + shard_rows
            shard_dq = per_tensor_dequantize(weight[row_begin:row_end, :],
                                             weight_scale[shard_idx])
            weight[row_begin:row_end, :], _ = ops.scaled_fp8_quant(
                shard_dq, max_w_scale)
            row_begin = row_end

    return max_w_scale, weight
def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
) -> torch.Tensor:
    """Quantize ``input`` to FP8 and run the scaled matmul against ``weight``.

    Args:
        input: Activations to quantize.
        weight: FP8 weight passed to the matmul.
        weight_scale: Scale for ``weight`` (``scale_b``).
        input_scale: Static activation scale, or None for dynamic
            quantization.
        bias: Optional bias forwarded to the matmul.
        cutlass_fp8_supported: Use ``ops.cutlass_scaled_mm`` when True,
            otherwise ``torch._scaled_mm`` on a padded input.

    Returns:
        The matmul output narrowed to the first ``input.shape[0]`` rows.
    """
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    #   If dynamic, layer.input_scale is None and x_scale computed from x.
    #   If static, layer.input_scale is scalar and x_scale is input_scale.
    result_dtype = input.dtype

    if cutlass_fp8_supported:
        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)

        # Fused GEMM_DQ
        output = ops.cutlass_scaled_mm(qinput,
                                       weight,
                                       out_dtype=result_dtype,
                                       scale_a=x_scale,
                                       scale_b=weight_scale,
                                       bias=bias)
    else:
        # Pad the batch dimension because torch._scaled_mm is more
        # performant for matrices with batch dimension > 16. Note that
        # this could change in the future.
        qinput, x_scale = ops.scaled_fp8_quant(input,
                                               input_scale,
                                               batch_dim_padding=17)

        # Fused GEMM_DQ
        output, _ = torch._scaled_mm(qinput,
                                     weight,
                                     out_dtype=result_dtype,
                                     scale_a=x_scale,
                                     scale_b=weight_scale,
                                     bias=bias)

    # Drop any rows that were only added as padding.
    return torch.narrow(output, 0, 0, input.shape[0])
def apply_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    """Quantize ``input`` to int8 and run ``ops.cutlass_scaled_mm``.

    Args:
        input: Activations to quantize.
        weight: Weight passed to the matmul.
        weight_scale: Scale for ``weight`` (``scale_b``).
        input_scale: Static activation scale, or None for dynamic
            quantization.
        bias: Optional bias forwarded to the matmul.

    Returns:
        The matmul result in ``input.dtype``.
    """
    # ops.scaled_int8_quant supports both dynamic and static quant.
    # * dynamic, layer.input_scale is None and x_scale computed from x.
    # * static, layer.input_scale is scalar and x_scale is input_scale.
    quantized_input, activation_scale = ops.scaled_int8_quant(
        input, input_scale)

    return ops.cutlass_scaled_mm(quantized_input,
                                 weight,
                                 scale_a=activation_scale,
                                 scale_b=weight_scale,
                                 out_dtype=input.dtype,
                                 bias=bias)
vllm/model_executor/layers/rejection_sampler.py
View file @
705f6a35
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Optional
,
Tuple
from
typing
import
Tuple
import
torch
import
torch
import
torch.jit
import
torch.jit
import
torch.nn
as
nn
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
class
RejectionSampler
(
nn
.
Module
):
class
RejectionSampler
(
SpecDecodeBaseSampler
):
"""Apply modified rejection sampling as described in "Accelerating Large
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
https://arxiv.org/pdf/2302.01318.pdf.
...
@@ -22,39 +24,11 @@ class RejectionSampler(nn.Module):
...
@@ -22,39 +24,11 @@ class RejectionSampler(nn.Module):
Require when bonus tokens will cause corrupt KV cache for
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
during sampling. This catches correctness issues but adds
nontrivial latency.
nontrivial latency.
"""
"""
super
().
__init__
()
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
self
.
_disable_bonus_tokens
=
disable_bonus_tokens
strict_mode
=
strict_mode
)
self
.
_strict_mode
=
strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self
.
_num_bonus_tokens
=
1
self
.
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
self
.
num_emitted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
self
.
num_draft_tokens
:
int
=
0
def
init_gpu_tensors
(
self
,
rank
:
int
)
->
None
:
assert
self
.
num_accepted_tokens
is
None
device
=
f
"cuda:
{
rank
}
"
self
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
@
property
def
probs_dtype
(
self
):
return
torch
.
float32
@
property
def
token_id_dtype
(
self
):
return
torch
.
int64
def
forward
(
def
forward
(
self
,
self
,
...
@@ -100,21 +74,15 @@ class RejectionSampler(nn.Module):
...
@@ -100,21 +74,15 @@ class RejectionSampler(nn.Module):
# Only perform shape/dtype/device checking in strict mode, as it adds
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
# overhead.
if
self
.
_strict_mode
:
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_shape
(
target_probs
,
bonus_token_ids
,
self
.
_raise_if_incorrect_input
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
)
self
.
_raise_if_incorrect_dtype
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
)
draft_probs
,
draft_token_ids
)
self
.
_raise_if_inconsistent_device
(
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
)
accepted
,
recovered_token_ids
=
(
self
.
_raise_if_out_of_bounds_vocab
(
target_probs
.
shape
[
-
1
],
self
.
_batch_modified_rejection_sampling
(
bonus_token_ids
,
target_probs
,
draft_token_ids
)
draft_probs
,
draft_token_ids
,
accepted
,
recovered_token_ids
=
self
.
_batch_modified_rejection_sampling
(
))
target_probs
,
draft_probs
,
draft_token_ids
,
)
output_token_ids
=
self
.
_create_output
(
output_token_ids
=
self
.
_create_output
(
accepted
,
accepted
,
...
@@ -272,128 +240,6 @@ class RejectionSampler(nn.Module):
...
@@ -272,128 +240,6 @@ class RejectionSampler(nn.Module):
"""
"""
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
return
torch
.
finfo
(
self
.
probs_dtype
).
tiny
def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        recovered_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size]
) -> torch.Tensor:
    """Format output. Returns a matrix of token ids. When
    a token is rejected via rejection sampling, all subsequent
    token ids are set to -1 for the sequence.

    shape = [batch_size, k + num_bonus_tokens]

    Besides building the output, this method updates the running counters
    ``num_accepted_tokens``, ``num_emitted_tokens`` and ``num_draft_tokens``
    on ``self``.
    """
    bonus_token_ids = bonus_token_ids.squeeze()
    batch_size, k = recovered_token_ids.shape

    # Determine the index of the first False value for each row.
    limits = (accepted == 0).max(1).indices
    # Rows without any rejection get the sentinel limit k (nothing cut off).
    limits[~(accepted == 0).any(1)] = k

    # Create masks using the indices.
    indices = torch.arange(k, device=accepted.device).unsqueeze(0)
    # Positions strictly before the first rejection.
    accepted_mask = indices < limits.unsqueeze(1)
    # Exactly the position of the first rejection.
    after_false_mask = indices == limits.unsqueeze(1)

    # Create an extended output tensor
    output_with_bonus_tokens = -torch.ones(
        (batch_size, k + self._num_bonus_tokens),
        dtype=self.token_id_dtype,
        device=accepted.device)
    # ``output`` is a slice of the extended tensor, so the in-place writes
    # below also land in ``output_with_bonus_tokens``.
    output = output_with_bonus_tokens[:, :k]

    # Fill in the first k columns of the output tensor using masks and data
    # tensors.
    torch.where(accepted_mask,
                draft_token_ids,
                -torch.ones_like(draft_token_ids),
                out=output)

    # Fill the last column.
    # We check output directly as accepted may have True values inconsistent
    # with causal acceptance.
    output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                  bonus_token_ids, -1)

    # We disable bonus tokens because it causes corrupt KV cache for
    # proposal methods that require KV cache. We can fix it by "prefilling"
    # the bonus token in the proposer. The following issue tracks the fix.
    # https://github.com/vllm-project/vllm/issues/4212
    if self._disable_bonus_tokens:
        output_with_bonus_tokens[:, -1] = -1

    # Fill the recovered token ids.
    # Zero the slot at the first rejection, then add the recovered token
    # there; all other slots are multiplied by 1 and receive + 0.
    output.mul_(~after_false_mask).add_(
        recovered_token_ids.mul(after_false_mask))

    # Update the running statistics.
    self.num_accepted_tokens += accepted.sum()
    self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
    self.num_draft_tokens += batch_size * k

    return output_with_bonus_tokens
def
_raise_if_incorrect_shape
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
None
:
(
target_batch_size
,
num_target_probs
,
target_vocab_size
)
=
target_probs
.
shape
bonus_batch_size
,
num_bonus_tokens
=
bonus_token_ids
.
shape
draft_batch_size
,
num_draft_probs
,
draft_vocab_size
=
draft_probs
.
shape
draft_token_ids_batch_size
,
num_draft_token_ids
=
draft_token_ids
.
shape
assert
draft_batch_size
==
target_batch_size
assert
num_draft_probs
==
num_target_probs
assert
(
draft_vocab_size
==
target_vocab_size
),
f
"
{
draft_vocab_size
=
}
{
target_vocab_size
=
}
"
assert
draft_token_ids_batch_size
==
draft_batch_size
assert
num_draft_token_ids
==
num_draft_probs
assert
bonus_batch_size
==
target_batch_size
assert
num_bonus_tokens
==
self
.
_num_bonus_tokens
def
_raise_if_incorrect_dtype
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
None
:
assert
all
(
probs
.
dtype
==
self
.
probs_dtype
for
probs
in
[
target_probs
,
draft_probs
])
assert
all
(
token_ids
.
dtype
==
self
.
token_id_dtype
for
token_ids
in
[
bonus_token_ids
,
draft_token_ids
])
def
_raise_if_inconsistent_device
(
self
,
target_probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
None
:
devices
=
[
t
.
device
for
t
in
[
target_probs
,
bonus_token_ids
,
draft_probs
,
draft_token_ids
]
]
assert
all
([
devices
[
0
]
==
device
for
device
in
devices
])
def
_raise_if_out_of_bounds_vocab
(
self
,
vocab_size
:
int
,
bonus_token_ids
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
)
->
None
:
assert
torch
.
all
(
bonus_token_ids
<
vocab_size
)
assert
torch
.
all
(
bonus_token_ids
>=
0
)
assert
torch
.
all
(
draft_token_ids
<
vocab_size
)
assert
torch
.
all
(
draft_token_ids
>=
0
)
# torch.multinomial forces a GPU<->CPU sync.
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Therefore, we use an optimized implementation instead that skips the sync.
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
705f6a35
...
@@ -28,6 +28,7 @@ import torch
...
@@ -28,6 +28,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.utils
import
is_tpu
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
...
@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return
x
.
flatten
(
-
2
)
return
x
.
flatten
(
-
2
)
def _apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """Rotate ``x`` by multiplying it with ``freqs_cis`` in the complex domain.

    The last dimension of ``x`` is split into two halves which are treated
    as the real and imaginary parts of complex numbers. Dimensions 1 and 2
    are swapped before the multiplication and swapped back afterwards.

    Args:
        x: Tensor to rotate. ``forward_native2`` passes it viewed as
            (batch_size, seq_len, num_heads, rotary_dim).
        freqs_cis: Complex factors multiplied element-wise with the
            complex view of ``x``. ``forward_native2`` passes it viewed as
            (batch_size, 1, seq_len, -1).

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    # Work in float32 with dims 1 and 2 swapped.
    swapped = x.transpose(1, 2).float()
    # First half -> real part, second half -> imaginary part.
    halves = torch.chunk(swapped, 2, dim=-1)
    as_complex = torch.view_as_complex(torch.stack(halves, dim=-1))

    rotated = torch.view_as_real(as_complex * freqs_cis).type_as(x)

    # Undo the (real, imag) interleaving: all real parts first, then all
    # imaginary parts, flattened back into a single trailing dimension.
    regrouped = torch.cat(torch.chunk(rotated, 2, dim=-1), dim=-2)
    dim0, dim1, dim2 = regrouped.shape[0], regrouped.shape[1], regrouped.shape[2]
    return regrouped.reshape(dim0, dim1, dim2, -1).transpose(1, 2)
class
RotaryEmbedding
(
CustomOp
):
class
RotaryEmbedding
(
CustomOp
):
"""Original rotary positional embedding."""
"""Original rotary positional embedding."""
...
@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
...
@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
self
.
dtype
=
dtype
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
dtype
)
self
.
use_native2
=
is_tpu
()
and
is_neox_style
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
if
not
self
.
use_native2
:
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
else
:
cos
,
sin
=
cache
.
chunk
(
2
,
dim
=-
1
)
freqs_cis
=
cos
+
1j
*
sin
self
.
register_buffer
(
"freqs_cis"
,
freqs_cis
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
"""Compute the inverse frequency."""
...
@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
...
@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
"""A PyTorch-native implementation equivalent to forward().
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
...
@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
...
@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
key
=
key
.
flatten
(
-
2
)
key
=
key
.
flatten
(
-
2
)
return
query
,
key
return
query
,
key
def forward_native2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Another PyTorch-native implementation of forward().

    This method might perform better than `forward_native()` when compiled.

    Args:
        positions: Indices into ``self.freqs_cis``; either 1-D (treated as
            a single batch) or 2-D (batch_size, seq_len).
        query: Query tensor; returned with its original shape.
        key: Key tensor; returned with its original shape.
        offsets: Optional tensor added to ``positions`` before the lookup.

    Returns:
        The rotated ``(query, key)`` pair.
    """
    if positions.dim() == 1:
        batch_size, seq_len = 1, positions.shape[0]
    else:
        batch_size, seq_len = positions.shape
    if offsets is not None:
        positions = positions + offsets

    # Gather one row of complex rotation factors per position.
    freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
    freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        # Rotate the first `rotary_dim` features of every head and pass the
        # remaining features through untouched.
        original_shape = tensor.shape
        per_head = tensor.view(batch_size, seq_len, -1, self.head_size)
        rotated = _apply_rotary_emb(per_head[..., :self.rotary_dim],
                                    freqs_cis)
        passthrough = per_head[..., self.rotary_dim:]
        return torch.cat((rotated, passthrough),
                         dim=-1).reshape(original_shape)

    query = _rotate(query)
    key = _rotate(key)
    return query, key
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -161,6 +221,40 @@ class RotaryEmbedding(CustomOp):
...
@@ -161,6 +221,40 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary embedding through the IPEX ops.

    ``self.cos_sin_cache`` is moved to the device of ``positions`` and cast
    to the dtype of ``query`` on every call, and the converted tensor is
    stored back on the module.

    Returns:
        The same ``query`` and ``key`` objects that were passed in.
    """
    from vllm._ipex_ops import ipex_ops as ops

    self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                               dtype=query.dtype)
    # ops.rotary_embedding()/batched_rotary_embedding()
    # are in-place operations that update the query and key tensors.
    if offsets is None:
        ops.rotary_embedding(positions, query, key, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
    else:
        ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache, self.is_neox_style,
                                     self.rotary_dim, offsets)
    return query, key
def forward_tpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch to one of the PyTorch-native implementations.

    ``forward_native2`` is used when ``self.use_native2`` is truthy,
    otherwise ``forward_native``; all arguments are forwarded unchanged.
    """
    if self.use_native2:
        return self.forward_native2(positions, query, key, offsets)
    return self.forward_native(positions, query, key, offsets)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
...
@@ -396,7 +490,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -396,7 +490,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return
cache
return
cache
class
Phi3
Su
ScaledRotaryEmbedding
(
nn
.
Module
):
class
Phi3
LongRoPE
ScaledRotaryEmbedding
(
nn
.
Module
):
"""Phi3 family of models scaled rotary embedding.
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
Based on the original RotaryEmbedding implementation.
...
@@ -413,18 +507,19 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
...
@@ -413,18 +507,19 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
short_factor
:
List
[
float
],
short_factor
:
List
[
float
],
long_factor
:
List
[
float
],
long_factor
:
List
[
float
],
short_mscale
:
float
=
1.
1
,
short_mscale
:
float
=
1.
0
,
long_mscale
:
float
=
1.
225
,
long_mscale
:
float
=
1.
0
,
):
):
super
().
__init__
()
super
().
__init__
()
if
rotary_dim
!=
head_size
:
if
rotary_dim
!=
head_size
:
raise
ValueError
(
raise
ValueError
(
f
"`Phi3
Su
ScaledRotaryEmbedding` does not support
rotary_dim !=
\
f
"`Phi3
LongRoPE
ScaledRotaryEmbedding` does not support
\
head_size (
{
rotary_dim
}
!=
{
head_size
}
)."
)
rotary_dim !=
head_size (
{
rotary_dim
}
!=
{
head_size
}
)."
)
if
is_neox_style
is
False
:
if
is_neox_style
is
False
:
raise
ValueError
(
raise
ValueError
(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style."
)
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
...
@@ -435,6 +530,16 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
...
@@ -435,6 +530,16 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
self
.
short_mscale
=
short_mscale
self
.
short_mscale
=
short_mscale
self
.
long_mscale
=
long_mscale
self
.
long_mscale
=
long_mscale
scale
=
(
self
.
max_position_embeddings
/
self
.
original_max_position_embeddings
)
if
scale
<=
1.0
:
self
.
scaling_factor
=
1.0
else
:
self
.
scaling_factor
=
math
.
sqrt
(
1
+
math
.
log
(
scale
)
/
math
.
log
(
self
.
original_max_position_embeddings
))
short_cache
=
self
.
_compute_cos_sin_cache
(
short_cache
=
self
.
_compute_cos_sin_cache
(
original_max_position_embeddings
,
short_factor
,
short_mscale
)
original_max_position_embeddings
,
short_factor
,
short_mscale
)
short_cache
=
short_cache
.
to
(
dtype
)
short_cache
=
short_cache
.
to
(
dtype
)
...
@@ -470,8 +575,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
...
@@ -470,8 +575,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
inv_freq
=
self
.
_compute_inv_freq
(
rescale_factors
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
t
=
torch
.
arange
(
max_position_embeddings
,
dtype
=
torch
.
float
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
*
mscale
cos
=
freqs
.
cos
()
*
mscale
*
self
.
scaling_factor
sin
=
freqs
.
sin
()
*
mscale
sin
=
freqs
.
sin
()
*
mscale
*
self
.
scaling_factor
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
return
cache
...
@@ -660,7 +765,9 @@ def get_rope(
...
@@ -660,7 +765,9 @@ def get_rope(
is_neox_style
,
dtype
)
is_neox_style
,
dtype
)
else
:
else
:
scaling_type
=
rope_scaling
[
"type"
]
scaling_type
=
rope_scaling
[
"type"
]
if
scaling_type
!=
"su"
:
# The correct type name is "longrope", but "su" is still accepted here
# for backward compatibility.
if
scaling_type
!=
"su"
and
scaling_type
!=
"longrope"
:
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
if
scaling_type
==
"linear"
:
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
rotary_emb
=
LinearScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
...
@@ -710,7 +817,7 @@ def get_rope(
...
@@ -710,7 +817,7 @@ def get_rope(
for
k
,
v
in
rope_scaling
.
items
()
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"short_mscale"
,
"long_mscale"
)
if
k
in
(
"short_mscale"
,
"long_mscale"
)
}
}
rotary_emb
=
Phi3
Su
ScaledRotaryEmbedding
(
rotary_emb
=
Phi3
LongRoPE
ScaledRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
)
**
extra_kwargs
)
...
...
vllm/model_executor/layers/sampler.py
View file @
705f6a35
...
@@ -174,7 +174,7 @@ def _apply_min_tokens_penalty(
...
@@ -174,7 +174,7 @@ def _apply_min_tokens_penalty(
min_tokens
=
sampling_params
.
min_tokens
min_tokens
=
sampling_params
.
min_tokens
token_ids_to_penalize
=
sampling_params
.
all_stop_token_ids
token_ids_to_penalize
=
sampling_params
.
all_stop_token_ids
if
min_tokens
>
0
and
token_ids_to_penalize
:
if
min_tokens
>
0
and
token_ids_to_penalize
:
seqs_to_penalize
=
[]
seqs_to_penalize
:
List
[
int
]
=
[]
for
j
,
seq_id
in
enumerate
(
seq_ids
):
for
j
,
seq_id
in
enumerate
(
seq_ids
):
seq_data
=
seq_group
.
seq_data
[
seq_id
]
seq_data
=
seq_group
.
seq_data
[
seq_id
]
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
if
len
(
seq_data
.
output_token_ids
)
<
min_tokens
:
...
@@ -285,7 +285,7 @@ def _greedy_sample(
...
@@ -285,7 +285,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
seq_group has do_sample=False, tuple contains ([], [])
"""
"""
samples
=
samples
.
tolist
()
samples
_lst
=
samples
.
tolist
()
sample_idx
=
0
sample_idx
=
0
results
:
SampleResultType
=
[]
results
:
SampleResultType
=
[]
for
seq_group
in
selected_seq_groups
:
for
seq_group
in
selected_seq_groups
:
...
@@ -298,7 +298,7 @@ def _greedy_sample(
...
@@ -298,7 +298,7 @@ def _greedy_sample(
assert
num_parent_seqs
==
1
,
(
assert
num_parent_seqs
==
1
,
(
"Greedy sampling should have only one seq."
)
"Greedy sampling should have only one seq."
)
parent_ids
=
list
(
range
(
num_parent_seqs
))
parent_ids
=
list
(
range
(
num_parent_seqs
))
next_token_ids
=
[
samples
[
sample_idx
]]
next_token_ids
=
[
samples
_lst
[
sample_idx
]]
results
.
append
((
next_token_ids
,
parent_ids
))
results
.
append
((
next_token_ids
,
parent_ids
))
sample_idx
+=
num_parent_seqs
sample_idx
+=
num_parent_seqs
return
results
return
results
...
@@ -394,7 +394,7 @@ def _beam_search_sample(
...
@@ -394,7 +394,7 @@ def _beam_search_sample(
next_token_ids
=
next_token_ids
.
tolist
()
next_token_ids
=
next_token_ids
.
tolist
()
else
:
else
:
# Generation phase.
# Generation phase.
cumulative_logprobs
:
List
[
in
t
]
=
[
cumulative_logprobs
:
List
[
floa
t
]
=
[
seq_group
.
seq_data
[
seq_id
].
cumulative_logprob
seq_group
.
seq_data
[
seq_id
].
cumulative_logprob
for
seq_id
in
seq_ids
for
seq_id
in
seq_ids
]
]
...
@@ -466,8 +466,9 @@ def _sample_with_torch(
...
@@ -466,8 +466,9 @@ def _sample_with_torch(
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
sample_results_dict
:
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
=
{}
sample_results_dict
:
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
=
{}
sample_metadata
=
{}
sample_metadata
:
Dict
[
SamplingType
,
multinomial_samples
=
{}
Tuple
[
List
[
int
],
List
[
SequenceGroupToSample
]]]
=
{}
multinomial_samples
:
Dict
[
SamplingType
,
torch
.
Tensor
]
=
{}
# Create output tensor for sampled token ids.
# Create output tensor for sampled token ids.
if
include_gpu_probs_tensor
:
if
include_gpu_probs_tensor
:
...
@@ -494,7 +495,7 @@ def _sample_with_torch(
...
@@ -494,7 +495,7 @@ def _sample_with_torch(
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
greedy_samples
=
torch
.
argmax
(
logprobs
[
long_sample_indices
],
dim
=-
1
)
dim
=-
1
)
if
include_gpu_probs_tensor
:
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
sampled_token_ids_tensor
[
long_sample_indices
]
=
greedy_samples
.
unsqueeze
(
-
1
)
long_sample_indices
]
=
greedy_samples
.
unsqueeze
(
-
1
)
...
@@ -522,7 +523,7 @@ def _sample_with_torch(
...
@@ -522,7 +523,7 @@ def _sample_with_torch(
probs
[
long_sample_indices
],
max_best_of_in_batch
,
probs
[
long_sample_indices
],
max_best_of_in_batch
,
**
seeded_args
)
**
seeded_args
)
if
include_gpu_probs_tensor
:
if
sampled_token_ids_tensor
is
not
None
:
# Store sampled tokens in output tensor.
# Store sampled tokens in output tensor.
sampled_token_ids_tensor
[
sampled_token_ids_tensor
[
long_sample_indices
]
=
multinomial_samples
[
sampling_type
]
long_sample_indices
]
=
multinomial_samples
[
sampling_type
]
...
@@ -571,7 +572,9 @@ def _sample_with_triton_kernel(
...
@@ -571,7 +572,9 @@ def _sample_with_triton_kernel(
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
categorized_seq_group_ids
[
sampling_type
].
append
(
i
)
sample_results_dict
:
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
=
{}
sample_results_dict
:
Dict
[
int
,
Tuple
[
List
[
int
],
List
[
int
]]]
=
{}
sample_metadata
=
{}
sample_metadata
:
Dict
[
SamplingType
,
Tuple
[
List
[
int
],
List
[
SequenceGroupToSample
],
torch
.
Tensor
,
torch
.
Tensor
]]
=
{}
max_best_of_in_batch
=
1
max_best_of_in_batch
=
1
# Counterintuitively, having two loops here is actually faster.
# Counterintuitively, having two loops here is actually faster.
...
@@ -676,7 +679,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
...
@@ -676,7 +679,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
Returns:
Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank
Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor.
of the chosen token in the input logprob tensor.
"""
"""
vals
=
x
[
torch
.
arange
(
0
,
len
(
x
),
device
=
x
.
device
,
dtype
=
indices
.
dtype
),
vals
=
x
[
torch
.
arange
(
0
,
len
(
x
),
device
=
x
.
device
,
dtype
=
indices
.
dtype
),
...
@@ -962,7 +965,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
...
@@ -962,7 +965,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
distribution.
distribution.
- Greedy sampling performs `argmax` to obtain the token with the
- Greedy sampling performs `argmax` to obtain the token with the
highest likelihood.
highest likelihood.
Ignoring greedy sampling for a moment, we find that the computed probability
Ignoring greedy sampling for a moment, we find that the computed probability
distribution has the following property: we can sample from it independently
distribution has the following property: we can sample from it independently
and find that the token sampled by the Sampler has a frequency corresponding
and find that the token sampled by the Sampler has a frequency corresponding
...
@@ -1008,14 +1011,14 @@ def _build_sampler_output(
...
@@ -1008,14 +1011,14 @@ def _build_sampler_output(
speculative decoding rejection sampling.
speculative decoding rejection sampling.
"""
"""
sampler_output
=
[]
sampler_output
:
List
[
CompletionSequenceGroupOutput
]
=
[]
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
for
(
seq_group
,
sample_result
,
group_prompt_logprobs
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
group_sample_logprobs
)
in
zip
(
sampling_metadata
.
seq_groups
,
sample_results
,
prompt_logprobs
,
sample_results
,
prompt_logprobs
,
sample_logprobs
):
sample_logprobs
):
seq_ids
=
seq_group
.
seq_ids
seq_ids
=
seq_group
.
seq_ids
next_token_ids
,
parent_ids
=
sample_result
next_token_ids
,
parent_ids
=
sample_result
seq_outputs
=
[]
seq_outputs
:
List
[
SequenceOutput
]
=
[]
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
for
parent_id
,
next_token_id
,
logprobs
in
zip
(
parent_ids
,
next_token_ids
,
next_token_ids
,
group_sample_logprobs
):
group_sample_logprobs
):
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
0 → 100644
View file @
705f6a35
from
abc
import
abstractmethod
from
typing
import
Optional
import
torch
import
torch.jit
import
torch.nn
as
nn
class SpecDecodeBaseSampler(nn.Module):
    """Shared base for the samplers that verify draft tokens during the
    Speculative Decoding verification step.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Set up the flags and the acceptance counters.

        Args:
            disable_bonus_tokens: When True the bonus-token column of the
                output is always -1. Required when bonus tokens would
                corrupt the KV cache of proposal methods that rely on one.
            strict_mode: Whether subclasses should run the
                shape/device/dtype checks during sampling. This catches
                correctness issues but adds nontrivial latency.
        """
        super().__init__()
        self._disable_bonus_tokens = disable_bonus_tokens
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token; the value
        # is kept in a named attribute purely for readability.
        self._num_bonus_tokens = 1

        # Counters; the tensor ones stay None until init_gpu_tensors() runs.
        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: int) -> None:
        """Create the zero-initialised counter tensors on ``cuda:{rank}``.

        Must be called at most once (asserted).
        """
        assert self.num_accepted_tokens is None
        cuda_device = f"cuda:{rank}"
        for counter_name in ("num_accepted_tokens", "num_emitted_tokens"):
            setattr(
                self, counter_name,
                torch.tensor(0, dtype=torch.long, device=cuda_device))

    @property
    def probs_dtype(self):
        """Dtype that probability tensors are expected to have."""
        return torch.float32

    @property
    def token_id_dtype(self):
        """Dtype used for token-id tensors, including the output."""
        return torch.int64

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the verification step; implemented by subclasses."""
        raise NotImplementedError

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            substitute_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Assemble the matrix of emitted token ids.

        Once a token of a sequence is rejected, every later position of
        that sequence is set to -1. The counters ``num_accepted_tokens``,
        ``num_emitted_tokens`` and ``num_draft_tokens`` are updated as a
        side effect.

        Args:
            accepted: Boolean tensor telling, per draft token, whether it
                should be accepted.
            substitute_token_ids: Token ids that may replace a draft token
                at the first rejected position.
            draft_token_ids: Token ids speculated by the draft model.
            bonus_token_ids: Token ids used as the bonus token when all the
                draft tokens of a sequence are accepted.

        Returns:
            Tensor of shape [batch_size, k + num_bonus_tokens] holding the
            accepted token ids.
        """
        batch_size, k = substitute_token_ids.shape
        bonus_token_ids = bonus_token_ids.squeeze()

        # Index of the first rejected position per row; rows without any
        # rejection get k (one past the last draft position).
        rejected = accepted == 0
        first_rejection = rejected.max(1).indices
        first_rejection[~rejected.any(1)] = k

        # Position masks derived from that index.
        positions = torch.arange(k, device=accepted.device).unsqueeze(0)
        cutoff = first_rejection.unsqueeze(1)
        keep_draft_mask = positions < cutoff
        substitute_mask = positions == cutoff

        # Output buffer (draft columns + bonus column), pre-filled with -1.
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        # `output` is a view on the first k columns of the buffer.
        output = output_with_bonus_tokens[:, :k]

        # Keep draft tokens up to (not including) the first rejection.
        output[:, :k] = torch.where(keep_draft_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the bonus column. The check is made on `output` itself
        # because `accepted` may contain True values that are inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # We disable bonus tokens because it causes corrupt KV cache for
        # proposal methods that require KV cache. We can fix it by
        # "prefilling" the bonus token in the proposer. The following issue
        # tracks the fix.
        # https://github.com/vllm-project/vllm/issues/4212
        if self._disable_bonus_tokens:
            output_with_bonus_tokens[:, -1] = -1

        # Write the substitute token at the first rejected position.
        output.mul_(~substitute_mask).add_(
            substitute_token_ids.mul(substitute_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_input(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Run the shape, dtype, device and vocabulary-range checks in turn.

        Raises:
            AssertionError: If any of the checks fails.
        """
        tensor_args = (target_probs, draft_token_ids, bonus_token_ids,
                       draft_probs)
        self._raise_if_incorrect_shape(*tensor_args)
        self._raise_if_incorrect_dtype(*tensor_args)
        self._raise_if_inconsistent_device(*tensor_args)
        self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that the tensors agree on batch size, number of
        speculative tokens and (when ``draft_probs`` is given) vocab size.
        """
        batch, num_spec, vocab = target_probs.shape

        # Draft token ids: [batch, num_spec].
        draft_ids_batch, draft_ids_count = draft_token_ids.shape
        assert draft_ids_batch == batch
        assert draft_ids_count == num_spec

        # Bonus token ids: [batch, num_bonus_tokens].
        bonus_batch, bonus_count = bonus_token_ids.shape
        assert bonus_batch == batch
        assert bonus_count == self._num_bonus_tokens

        # Draft probs are optional; validate them only when provided.
        if draft_probs is None:
            return
        (draft_batch_size, num_draft_probs,
         draft_vocab_size) = draft_probs.shape
        target_vocab_size = vocab
        assert draft_batch_size == batch
        assert num_draft_probs == num_spec
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that probabilities use ``probs_dtype`` and token ids use
        ``token_id_dtype``; ``draft_probs`` is checked only when given.
        """
        assert target_probs.dtype == self.probs_dtype
        for token_ids in (draft_token_ids, bonus_token_ids):
            assert token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that every provided tensor lives on the same device."""
        candidates = (target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)
        devices = [t.device for t in candidates if t is not None]
        reference_device = devices[0]
        assert all(reference_device == device for device in devices)

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
    ) -> None:
        """Assert that all token ids lie in ``[0, vocab_size)``."""
        for token_ids in (bonus_token_ids, draft_token_ids):
            assert torch.all(token_ids < vocab_size)
            assert torch.all(token_ids >= 0)
vllm/model_executor/layers/typical_acceptance_sampler.py
0 → 100644
View file @
705f6a35
import
torch
import
torch.jit
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
    """Apply typical acceptance sampling as described in section 3.3.1 in
    "MEDUSA: Simple LLM Inference Acceleration Framework with
    Multiple Decoding Heads"
    https://arxiv.org/pdf/2401.10774
    """

    def __init__(
        self,
        posterior_threshold: float,
        posterior_alpha: float,
        disable_bonus_tokens: bool = False,
        strict_mode: bool = False,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            posterior_threshold: A threshold value that sets a lower bound
                on the posterior probability of a token in target model for
                it to be accepted.
            posterior_alpha: A scaling factor for the entropy-based
                threshold in typical acceptance sampling.
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens will cause corrupt KV cache for
                proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        # nn.Module requires the parent constructor to run before any
        # attribute is assigned on the child, so initialise the base first.
        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                         strict_mode=strict_mode)
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling. This accepts
        or rejects tokens proposed by the draft model using the probability
        of each token according to the target model.

        In the worst case where all draft tokens are rejected, it is
        guaranteed one token will be emitted.

        In the case where all draft tokens are accepted, the bonus token
        will be accepted conditioned on self._disable_bonus_tokens being
        false.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
            shape = [batch_size, num_speculative_tokens, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
            shape = [batch_size, num_bonus_tokens]

            draft_probs: This parameter is unused by the acceptance sampler.

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
            shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via typical acceptance
                sampling, or -1 if unable to sample a token because the
                previous token was rejected.
            shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead. draft_probs is deliberately not validated: it is unused.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, draft_token_ids,
                                           bonus_token_ids)
        accepted = self._evaluate_accepted_tokens(target_probs,
                                                  draft_token_ids)
        recovered_token_ids = self._replacement_token_ids(target_probs)
        output_token_ids = self._create_output(accepted, recovered_token_ids,
                                               draft_token_ids,
                                               bonus_token_ids)
        return output_token_ids

    def _evaluate_accepted_tokens(self, target_probs, draft_token_ids):
        r"""
        Evaluates and returns a mask of accepted tokens based on the
        posterior probabilities.

        Parameters:
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) representing
            the probabilities of each token in the vocabulary for each
            position in the proposed sequence. This is the distribution
            generated by the target model.
        draft_token_ids : torch.Tensor
            A tensor of shape (batch_size, k) representing the proposed
            token ids.

        A draft token_id x_{n+k} is accepted if it satisfies the
        following condition

        .. math::
            p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
            \min \left( \epsilon, \delta * \exp \left(
                -H(p_{\text{original}}(
                    \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

        where :math:`p_{\text{original}}` corresponds to target_probs
        and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
        specified using self._posterior_threshold and self._posterior_alpha

        This method computes the posterior probabilities for the given
        draft token ids based on the provided target probabilities. It
        calculates the entropy of the posterior distribution and determines
        a dynamic threshold for each token position using the provided
        posterior_threshold and posterior_alpha values. The method then
        returns a boolean mask indicating which tokens can be accepted.

        Returns:
        -------
        torch.Tensor
            A boolean tensor of shape (batch_size, k) where each element
            indicates whether the corresponding draft token has been accepted
            or rejected. True indicates acceptance and false indicates
            rejection.
        """
        device = target_probs.device
        # Target-model probability of each proposed draft token.
        candidates_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # A small constant added to prevent computing the logarithm of zero,
        # which can lead to undefined values.
        epsilon = 1e-5
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + epsilon), dim=-1)
        # Per-position threshold: min(posterior_threshold,
        # posterior_alpha * exp(-entropy)).
        threshold = torch.minimum(
            torch.ones_like(posterior_entropy, device=device) *
            self._posterior_threshold,
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
        accepted_mask = candidates_prob > threshold
        return accepted_mask

    def _replacement_token_ids(self, target_probs):
        """
        Generate one replacement token ID for each sequence based on target
        probabilities. The replacement token is used as the fallback option
        if typical acceptance sampling does not accept any draft tokens for
        that particular sequence.

        This method computes the token IDs to be replaced by selecting the
        token with the highest probability for each sequence in the first
        position. The rest of the output is filled with -1.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) containing
            the target probability distribution

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, k) with the replacement
            token IDs. Only the first column is set, and the rest of the
            columns are filled with -1.
        """
        # Most likely token at the first speculative position.
        max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
        output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
                             dtype=self.token_id_dtype,
                             device=target_probs.device)
        output[:, 0] = max_indices
        return output
Prev
1
…
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment