Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4851c202
Commit
4851c202
authored
Sep 13, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1' into v0.6.1-dev
parents
9b902f9e
3fd2b0d2
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1194 additions
and
322 deletions
+1194
-322
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+2
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-2
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+23
-11
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+15
-10
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+28
-20
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+34
-13
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+28
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+301
-19
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+163
-0
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+0
-138
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+17
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+7
-4
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+13
-6
vllm/model_executor/layers/resampler.py
vllm/model_executor/layers/resampler.py
+273
-0
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+203
-77
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+41
-6
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+7
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+8
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+18
-10
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+11
-3
No files found.
vllm/model_executor/layers/linear.py
View file @
4851c202
...
...
@@ -29,7 +29,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
]
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
4851c202
...
...
@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.modelopt
import
ModelOptFp8Config
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
...
...
@@ -35,6 +35,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"tpu_int8"
:
Int8TpuConfig
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
...
...
@@ -43,7 +44,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
...
...
vllm/model_executor/layers/quantization/awq_triton.py
View file @
4851c202
...
...
@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
# Compute offsets and masks for qweight_ptr.
offsets_y
=
pid_y
*
BLOCK_SIZE_Y
+
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
)
//
8
offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
)
offsets
=
num_cols
*
offsets_y
[:,
None
]
+
offsets_x
[
None
,
:]
masks_y
=
offsets_y
<
num_rows
...
...
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
# Load the weights.
iweights
=
tl
.
load
(
qweight_ptr
+
offsets
,
masks
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
...
...
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
iweights
=
(
iweights
>>
shifts
)
&
0xF
# Compute zero offsets and masks.
zero_offsets_y
=
(
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
//
group_size
)
zero_offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
)
//
8
zero_offsets_y
=
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
1
)
zero_offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
)
zero_offsets
=
num_cols
*
zero_offsets_y
[:,
None
]
+
zero_offsets_x
[
None
,
:]
zero_masks_y
=
zero_offsets_y
<
num_rows
//
group_size
...
...
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
# Load the zeros.
zeros
=
tl
.
load
(
zeros_ptr
+
zero_offsets
,
zero_masks
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
broadcast_to
(
zeros
,
(
BLOCK_SIZE_Y
,
BLOCK_SIZE_X
*
8
))
# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros
=
(
zeros
>>
shifts
)
&
0xF
# Compute scale offsets and masks.
scale_offsets_y
=
(
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
//
group_size
)
scale_offsets_y
=
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
1
)
scale_offsets_x
=
(
pid_x
*
BLOCK_SIZE_X
*
8
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
))
scale_offsets
=
(
num_cols
*
8
*
scale_offsets_y
[:,
None
]
+
...
...
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
# Load the scales.
scales
=
tl
.
load
(
scales_ptr
+
scale_offsets
,
scale_masks
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_Y
,
BLOCK_SIZE_X
*
8
))
# Dequantize.
iweights
=
(
iweights
-
zeros
)
*
scales
...
...
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
offsets_am
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
masks_am
=
offsets_am
<
M
offsets_bn
=
(
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
//
8
)
offsets_bn
=
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
//
8
)
masks_bn
=
offsets_bn
<
N
//
8
offsets_zn
=
(
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
//
8
)
offsets_zn
=
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
//
8
)
masks_zn
=
offsets_zn
<
N
//
8
offsets_sn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
...
...
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_b
=
masks_k
[:,
None
]
&
masks_bn
[
None
,
:]
b
=
tl
.
load
(
b_ptrs
,
mask
=
masks_b
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
# Dequantize b.
offsets_szk
=
(
(
BLOCK_SIZE_K
*
SPLIT_K
*
k
+
pid_z
*
BLOCK_SIZE_K
)
//
group_size
+
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
//
group_size
)
tl
.
arange
(
0
,
1
)
)
offsets_z
=
(
N
//
8
)
*
offsets_szk
[:,
None
]
+
offsets_zn
[
None
,
:]
masks_zk
=
offsets_szk
<
K
//
group_size
masks_z
=
masks_zk
[:,
None
]
&
masks_zn
[
None
,
:]
zeros_ptrs
=
zeros_ptr
+
offsets_z
zeros
=
tl
.
load
(
zeros_ptrs
,
mask
=
masks_z
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
broadcast_to
(
zeros
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
offsets_s
=
N
*
offsets_szk
[:,
None
]
+
offsets_sn
[
None
,
:]
masks_sk
=
offsets_szk
<
K
//
group_size
masks_s
=
masks_sk
[:,
None
]
&
masks_sn
[
None
,
:]
scales_ptrs
=
scales_ptr
+
offsets_s
scales
=
tl
.
load
(
scales_ptrs
,
mask
=
masks_s
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
b
=
(
b
>>
shifts
)
&
0xF
zeros
=
(
zeros
>>
shifts
)
&
0xF
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
4851c202
...
...
@@ -116,15 +116,19 @@ class CompressedTensorsConfig(QuantizationConfig):
def _check_scheme_supported(self,
                            min_capability: int,
                            error: bool = True) -> bool:
    """Check whether the current device meets ``min_capability``.

    The capability reported by ``current_platform`` is folded into a single
    integer as ``capability[0] * 10 + capability[1]`` and compared against
    ``min_capability``.

    Args:
        min_capability: Minimum required capability, in the folded integer
            form described above.
        error: When True, raise instead of returning False if the device
            reports a capability that is too low.

    Returns:
        True if the folded capability is at least ``min_capability``.
        False if it is lower (and ``error`` is False), or if the platform
        reports no capability at all (``None``); the latter never raises,
        regardless of ``error``.

    Raises:
        RuntimeError: If ``error`` is True and the reported capability is
            below ``min_capability``.
    """
    capability = current_platform.get_device_capability()  # type: ignore

    if capability is not None:
        capability = capability[0] * 10 + capability[1]
        supported = capability >= min_capability
        if error and not supported:
            # Build ONE message string. Separating the fragments with
            # commas would pass three positional args to RuntimeError and
            # render the message as a tuple.
            raise RuntimeError(
                "Quantization scheme is not supported for "
                f"the current GPU. Min capability: {min_capability}. "
                f"Current capability: {capability}.")
        return supported
    else:
        return False
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
...
...
@@ -232,7 +236,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
CompressedTensorsWNA16
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
4851c202
...
...
@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoEMethodBase
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
from
vllm.model_executor.utils
import
set_weight_attrs
...
...
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if
not
(
self
.
quant_config
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
and
self
.
num_bits
in
WNA16_SUPPORTED_BITS
):
and
self
.
num_bits
==
4
):
raise
ValueError
(
"For Fused MoE layers, only "
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
"is supported for the following bits: "
,
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
"is supported for 4 bits"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
...
...
@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_
marlin_
moe
import
(
fused_marlin_moe
)
return
fused_marlin_moe
(
x
,
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
top_k
,
custom_routing_function
,
renormalize
=
renormalize
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_marlin_moe
(
x
,
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
topk_weights
,
topk_ids
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
4851c202
...
...
@@ -5,20 +5,24 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
ActivationOrdering
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
__all__
=
[
"CompressedTensorsWNA16"
]
WNA16_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
,
8
:
scalar_types
.
uint8b128
}
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
...
...
@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def
__init__
(
self
,
strategy
:
str
,
num_bits
:
int
,
group_size
:
Optional
[
int
]
=
None
):
group_size
:
Optional
[
int
]
=
None
,
actorder
:
Optional
[
ActivationOrdering
]
=
None
):
self
.
pack_factor
=
32
//
num_bits
self
.
strategy
=
strategy
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
has_g_idx
=
actorder
==
ActivationOrdering
.
GROUP
if
self
.
group_size
==
-
1
and
self
.
strategy
!=
"channel"
:
raise
ValueError
(
"Marlin kernels require group quantization or "
...
...
@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
output_size_per_partition
=
sum
(
output_partition_sizes
)
# If group_size is -1, we are in channelwise case.
channelwise
=
(
self
.
group_size
==
-
1
)
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input_size
row_parallel
=
(
input_size
!=
input_size_per_partition
)
# In the case of channelwise quantization, we need to replicate the
# scales across all gpus.
partition_scales
=
(
row_parallel
and
not
channelwise
)
partition_scales
=
not
marlin_repeat_scales_on_all_ranks
(
self
.
has_g_idx
,
self
.
group_size
,
row_parallel
)
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
...
...
@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
# group index (for activation reordering)
if
self
.
has_g_idx
:
weight_g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_g_idx"
,
weight_g_idx
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
...
...
@@ -137,9 +151,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Act-order not supported in compressed-tensors yet, so set to empty.
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# Handle sorting for activation reordering if needed.
if
self
.
has_g_idx
:
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
weight_g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"weight_g_idx"
,
g_idx
)
else
:
layer
.
weight_g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
...
...
@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales
=
marlin_permute_scales
(
layer
.
weight_scale
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
(
layer
.
input_size
if
self
.
has_g_idx
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
layer
.
group_size
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
...
...
@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
weight_zp
=
layer
.
weight_zp
,
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
weight_
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
wtype
=
self
.
quant_type
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
4851c202
import
re
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Union
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
,
field_validator
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
...
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
TOKEN
=
"token"
class ActivationOrdering(str, Enum):
    """Enum storing strategies for activation ordering.

    GROUP: reorder groups and weight.
    WEIGHT: only reorder weight, not groups. Slightly lower latency and
    accuracy compared to group actorder.
    """

    GROUP = "group"
    WEIGHT = "weight"
class
QuantizationArgs
(
BaseModel
):
"""
User facing arguments used to define a quantization config
...
...
@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one
:param actorder: whether to apply group quantization in decreasing order of
activation. Defaults to None for arbitrary ordering
"""
num_bits
:
int
=
8
...
...
@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
strategy
:
Optional
[
QuantizationStrategy
]
=
None
block_structure
:
Optional
[
str
]
=
None
dynamic
:
bool
=
False
actorder
:
Union
[
ActivationOrdering
,
bool
,
None
]
=
None
observer
:
str
=
Field
(
default
=
"minmax"
,
description
=
(
"The class to use to compute the quantization param - "
...
...
@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
"Observers constructor excluding quantization range or symmetry"
),
)
@field_validator("actorder", mode="before")
@classmethod
def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
    """Normalize the raw ``actorder`` value before field validation.

    - ``bool``: True maps to ``ActivationOrdering.GROUP``, False to None.
    - ``str``: lower-cased and converted to ``ActivationOrdering``; an
      unknown name makes the enum constructor raise ``ValueError``.
    - anything else (including None or an ``ActivationOrdering``) is
      returned unchanged.

    The validator receives the class as its first argument, so it is
    declared as a classmethod beneath ``field_validator``.
    """
    if isinstance(value, bool):
        return ActivationOrdering.GROUP if value else None

    if isinstance(value, str):
        return ActivationOrdering(value.lower())

    return value
def
is_activation_quantization_format
(
format
:
str
)
->
bool
:
_ACTIVATION_QUANTIZATION_FORMATS
=
[
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
4851c202
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
...
...
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
(
8
,
True
):
scalar_types
.
uint8b128
,
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
...
...
@@ -51,10 +61,6 @@ class GPTQMarlinConfig(QuantizationConfig):
self
.
quant_type
=
self
.
TYPE_MAP
[(
weight_bits
,
is_sym
)]
# Verify supported on platform.
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
)
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
...
...
@@ -109,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference"
)
return
None
def get_quant_method(
    self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
    """Pick the GPTQ-Marlin method object for ``layer``.

    Returns a ``GPTQMarlinLinearMethod`` for ``LinearBase`` layers and for
    ``ParallelLMHead`` layers when ``self.lm_head_quantized`` is set, a
    ``GPTQMarlinMoEMethod`` for ``FusedMoE`` layers, and None otherwise.
    ``prefix`` is accepted but not used here.
    """
    # Linear check first, exactly as before: a plain linear layer, or the
    # LM head when the config marks it as quantized.
    use_linear = isinstance(layer, LinearBase)
    if not use_linear:
        use_linear = (isinstance(layer, ParallelLMHead)
                      and self.lm_head_quantized)

    if use_linear:
        return GPTQMarlinLinearMethod(self)

    if isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)

    return None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
...
@@ -153,6 +162,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
    """Store ``quant_config`` and run the Marlin support check on it.

    The check is made here, when the linear method is constructed, using
    the config's ``quant_type`` and ``group_size``.
    """
    self.quant_config = quant_config

    # Verify supported on platform.
    verify_marlin_supported(quant_type=self.quant_config.quant_type,
                            group_size=self.quant_config.group_size)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
)
group_size
=
group_size
,
)
# Determine sharding
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
...
...
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
...
...
@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
group_size
=
self
.
quant_config
.
group_size
,
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
def
apply
(
...
...
@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
,
bias
=
bias
)
bias
=
bias
,
)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization.

    Registers the packed GPTQ tensors of a fused-MoE layer, repacks them
    after loading via the Marlin MoE repack/permute helpers, and runs the
    forward pass through ``fused_marlin_moe``.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        """Keep a reference to the GPTQ-Marlin quantization config."""
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the quantized MoE parameters on ``layer``.

        Registers, in this order: ``w13_qweight``, ``w2_qweight``,
        ``w13_scales``, ``w2_scales``, ``w13_qzeros``, ``w2_qzeros``,
        ``w13_g_idx``, ``w2_g_idx``, ``w13_g_idx_sort_indices`` and
        ``w2_g_idx_sort_indices``. Every parameter is created with
        ``requires_grad=False`` and tagged with ``extra_weight_attrs``
        (extended in place with ``quant_method`` and ``is_transposed``).

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden size used for the w13 input / w2 output dims.
            intermediate_size: Intermediate size used for the w13 output
                (doubled) / w2 input dims.
            params_dtype: dtype used for the ``*_qzeros`` tensors only;
                scales are always ``torch.half`` and the remaining tensors
                ``torch.int32``.
            **extra_weight_attrs: Attributes attached to every parameter.
        """
        # Currently assuming is_k_full is always True
        # (input size per partition is the same as full input size)
        # Supports only sym for now (no zp)
        if self.quant_config.group_size != -1:
            scales_size13 = hidden_size // self.quant_config.group_size
            scales_size2 = intermediate_size // self.quant_config.group_size
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
            # group_size == -1: a single scale row per expert.
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value

        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": True
        })
        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.quant_config.pack_factor,
                2 * intermediate_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size // self.quant_config.pack_factor,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        # up_proj scales
        w13_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size,
                        dtype=torch.half),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)
        # down_proj scales
        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size,
                        dtype=torch.half),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)
        # up_proj zero points (packed along the last dim)
        w13_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size //
                        self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)
        # down_proj zero points (packed along the last dim)
        w2_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size // self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
        # Group indices and their sort permutations (used for act_order).
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)
        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)
        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded GPTQ tensors on ``layer`` for the Marlin kernel.

        With ``desc_act`` set, each expert's ``g_idx`` is argsorted and both
        the sorted indices and the permutation are stored back on the layer.
        Otherwise the four g_idx tensors are replaced with empty
        ``(num_experts, 0)`` int32 parameters. The packed weights and the
        scales are then repacked/permuted and written back in place.
        """
        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            num_experts = layer.w13_g_idx.shape[0]
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(
                    layer.w13_g_idx[e]).to(torch.int32)
                w2_g_idx_sort_indices[e] = torch.argsort(
                    layer.w2_g_idx[e]).to(torch.int32)
                w13_sorted_g_idx[e] = layer.w13_g_idx[e][
                    w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                    w2_g_idx_sort_indices[e]]
            replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_tensor(layer, "w13_g_idx_sort_indices",
                           w13_g_idx_sort_indices)
            replace_tensor(layer, "w2_g_idx_sort_indices",
                           w2_g_idx_sort_indices)
        else:
            # Reset g_idx related tensors
            num_experts = layer.w13_g_idx.shape[0]
            device = layer.w13_g_idx.device
            layer.w13_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0),
                            dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0),
                            dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0),
                            dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0),
                            dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
        # Repack weights
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            # Unpacked size of dim 1: packed rows times the pack factor.
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
        # NOTE(review): size_k for w13 is taken from
        # layer.intermediate_size_per_partition while w2 derives it from the
        # scales' shape; kept as-is, confirm against the permute helper.
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_tensor(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_tensor(layer, "w2_scales", marlin_w2_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Run the fused Marlin MoE forward pass for ``layer``.

        ``x`` is cast to float16 for the kernel and the result is cast back
        to the input's original dtype. Expert selection is delegated to
        ``FusedMoE.select_experts`` with the routing arguments given here,
        including ``custom_routing_function``.

        Returns:
            The output of ``fused_marlin_moe`` in ``x``'s original dtype.
        """
        # Imported lazily, as in the original implementation.
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            fused_marlin_moe)

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        # Forward the caller's routing function instead of hard-coding None,
        # which silently ignored it.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function)

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            router_logits,
            layer.w13_g_idx,
            layer.w2_g_idx,
            layer.w13_g_idx_sort_indices,
            layer.w2_g_idx_sort_indices,
            topk_weights,
            topk_ids,
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
        ).to(orig_dtype)
vllm/model_executor/layers/quantization/modelopt.py
0 → 100644
View file @
4851c202
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
logger
=
init_logger
(
__name__
)
ACTIVATION_SCHEMES
=
[
"static"
]
class ModelOptFp8Config(QuantizationConfig):
    """Config class for ModelOpt FP8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = False,
    ) -> None:
        """Record whether the checkpoint is FP8-serialized.

        Logs a warning when it is, since the format is experimental.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                           " the format is experimental and could change.")

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this quantization method."""
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this config supports."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (89)."""
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names searched for this method."""
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
        """Build the config from a parsed ``hf_quant_config.json`` dict.

        Reads ``config["quantization"]["quant_algo"]`` and requires it to
        contain ``"FP8"``.

        Raises:
            ValueError: If the quant algorithm does not contain ``"FP8"``.
        """
        quant_config = cls.get_from_keys(config, ["quantization"])
        quant_method = quant_config["quant_algo"]
        is_checkpoint_fp8_serialized = ("FP8" in quant_method)
        if not is_checkpoint_fp8_serialized:
            # Trailing space after "FP8": adjacent literals are concatenated
            # verbatim, and the message previously read "FP8quantization".
            raise ValueError("ModelOpt currently only supports static FP8 "
                             "quantization in vLLM. Please check the "
                             "`hf_quant_config.json` file for your model's "
                             "quant configuration.")
        return cls(is_checkpoint_fp8_serialized)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method for ``layer``.

        ``LinearBase`` layers get ``ModelOptFp8LinearMethod``, ``Attention``
        layers get ``ModelOptFp8KVCacheMethod``, anything else gets None.
        ``prefix`` is accepted but not used here.
        """
        from vllm.attention.layer import Attention  # Avoid circular import
        if isinstance(layer, LinearBase):
            return ModelOptFp8LinearMethod(self)
        elif isinstance(layer, Attention):
            return ModelOptFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for this config)."""
        return []
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """KV-cache method used for ModelOpt FP8 checkpoints.

    Supports loading kv-cache scaling factors from FP8 checkpoints.  All of
    the behavior comes from ``BaseKVCacheMethod``; this subclass only
    forwards the quantization config to it.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        # Nothing ModelOpt-specific to set up: hand the config to the base.
        super().__init__(quant_config)
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear method for Model Optimizer static FP8 quantization.

    Loads FP8 checkpoints that carry a static weight scale and a static
    activation scale.  Dynamic scales are not handled here.

    Limitations:
    1. Only per-tensor quantization is supported (noted upstream as being
       due to torch._scaled_mm support).
    2. Only the float8_e4m3fn datatype is supported.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        # Evaluated once here; the value is forwarded to apply_fp8_linear
        # on every call to apply().
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the weight (and, for FP8 checkpoints, scale) parameters.

        ``input_size`` and ``output_size`` are accepted for interface
        compatibility and are not used.
        """
        del input_size, output_size

        loader = extra_weight_attrs.get("weight_loader")
        serialized = self.quant_config.is_checkpoint_fp8_serialized
        num_logical_shards = len(output_partition_sizes)
        total_output_size = sum(output_partition_sizes)

        # Bookkeeping read back in process_weights_after_loading().
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size

        # WEIGHT
        if serialized:
            storage_dtype = torch.float8_e4m3fn
        else:
            storage_dtype = params_dtype
        weight_data = torch.empty(total_output_size,
                                  input_size_per_partition,
                                  dtype=storage_dtype)
        layer.register_parameter(
            "weight",
            ModelWeightParameter(data=weight_data,
                                 input_dim=1,
                                 output_dim=0,
                                 weight_loader=loader))

        if not serialized:
            return

        def _make_scale() -> PerTensorScaleParameter:
            # One float32 slot per logical shard, pre-filled with the most
            # negative float32 value.
            param = PerTensorScaleParameter(
                data=torch.empty(num_logical_shards, dtype=torch.float32),
                weight_loader=loader)
            param[:] = torch.finfo(torch.float32).min
            return param

        # WEIGHT SCALE first, then INPUT SCALE (same order as registration
        # in the original code).
        for scale_name in ("weight_scale", "input_scale"):
            layer.register_parameter(scale_name, _make_scale())

    def process_weights_after_loading(self, layer: Module) -> None:
        """Collapse per-shard scales into single scales after loading.

        The weight is requantized with the maximum weight scale and stored
        transposed; the input scale is reduced to its maximum.
        """
        shared_scale, requantized = requantize_with_max_scale(
            layer.weight, layer.weight_scale, layer.logical_widths)
        layer.weight = Parameter(requantized.t(), requires_grad=False)
        layer.weight_scale = Parameter(shared_scale, requires_grad=False)
        largest_input_scale = layer.input_scale.max()
        layer.input_scale = Parameter(largest_input_scale,
                                      requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op on ``x`` using the layer's weight/scales."""
        output = apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported)
        return output
vllm/model_executor/layers/quantization/squeezellm.py
deleted
100644 → 0
View file @
9b902f9e
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
class SqueezeLLMConfig(QuantizationConfig):
    """Quantization config for SqueezeLLM.

    Reference: https://arxiv.org/pdf/2306.07629
    """

    def __init__(
        self,
        weight_bits: int,
    ) -> None:
        """Validate the bit width and derive the int32 packing factor.

        Raises:
            ValueError: If ``weight_bits`` is not 4.
        """
        self.weight_bits = weight_bits
        supported = self.weight_bits == 4
        if not supported:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"SqueezeLLM, but got {self.weight_bits} bits.")

        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"

    def get_name(self) -> str:
        """Return the identifier of this quantization method."""
        return "squeezellm"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value (70)."""
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
        """Build the config from a parsed dict containing ``"wbits"``."""
        bits = cls.get_from_keys(config, ["wbits"])
        return cls(bits)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        """Return the linear method for ``LinearBase`` layers, else None."""
        if not isinstance(layer, LinearBase):
            return None
        return SqueezeLLMLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the scaled activation names (none for this method)."""
        return []
class SqueezeLLMLinearMethod(QuantizeMethodBase):
    """Linear method for SqueezeLLM.

    Args:
        quant_config: The SqueezeLLM quantization config.
    """

    def __init__(self, quant_config: SqueezeLLMConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the packed weight and lookup-table parameters on ``layer``.

        Raises:
            ValueError: If ``input_size_per_partition`` is not a multiple of
                the config's pack factor.
        """
        if input_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        # Quantized weights are packed along the input dimension into int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        # A b-bit code can address 2**b table entries.  The original wrote
        # weight_bits**2, which matches only at the enforced weight_bits == 4
        # (both give 16); 2**weight_bits is the correct general form.
        lookup_table = Parameter(
            torch.empty(
                output_size,
                2**self.quant_config.weight_bits,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(lookup_table, {
            "output_dim": 0,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("lookup_table", lookup_table)
        set_weight_attrs(lookup_table, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the quantized matmul of ``x`` with the layer's weights.

        The output keeps the leading dimensions of ``x`` and takes its last
        dimension from the last dimension of ``qweight``.
        """
        qweight = layer.qweight
        lookup_table = layer.lookup_table
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        if is_hip():
            # On HIP the kernel writes into a float32 buffer, which is then
            # converted to float16.
            out_f = torch.zeros(out_shape, dtype=torch.float)
            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
            out = out_f.to(dtype=torch.float16)
        else:
            # NOTE: The output tensor should be zero-initialized.
            out = torch.zeros(out_shape, dtype=torch.float16)
            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
4851c202
...
...
@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return
s
def marlin_moe_permute_scales(
    s: torch.Tensor,
    size_k: int,
    size_n: int,
    group_size: int,
):
    """Apply ``marlin_permute_scales`` to every expert's scales.

    Args:
        s: Scales indexed by expert along dimension 0; dimensions 1 and 2
            give the per-expert shape of the result buffer.
        size_k: Forwarded to ``marlin_permute_scales``.
        size_n: Forwarded to ``marlin_permute_scales``.
        group_size: Forwarded to ``marlin_permute_scales``.

    Returns:
        A new tensor with the same device and dtype as ``s`` holding the
        permuted scales of each expert.
    """
    num_experts, rows, cols = s.shape[0], s.shape[1], s.shape[2]
    permuted = torch.empty((num_experts, rows, cols),
                           device=s.device,
                           dtype=s.dtype)
    # Each expert is permuted independently and written into its own slot.
    for expert_id in range(num_experts):
        expert_scales = s[expert_id]
        permuted[expert_id] = marlin_permute_scales(expert_scales, size_k,
                                                    size_n, group_size)
    return permuted
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
# Permute zero-points in a similar way to scales, but do not use the
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
4851c202
"""Utility functions used for tests and benchmarks"""
from
typing
import
List
from
typing
import
List
,
Optional
import
numpy
as
np
import
torch
...
...
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
return
perm
def
marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
def
marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
size_k
,
size_n
=
w
.
shape
num_bits
=
quant_type
.
size_bits
...
...
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
w
,
quant_type
,
group_size
,
act_order
)
w
,
quant_type
,
group_size
,
act_order
,
test_perm
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
4851c202
"""This file is used for /tests and /benchmarks"""
from
typing
import
List
from
typing
import
List
,
Optional
import
numpy
import
torch
...
...
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
return
32
//
num_bits
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
):
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
q_w
.
shape
==
w_ref
.
shape
orig_device
=
q_w
.
device
...
...
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx
[
i
]
=
i
//
group_size
# Simulate act_order by doing a random permutation on K
rand_perm
=
torch
.
randperm
(
k_size
)
rand_perm
=
test_perm
if
test_perm
is
not
None
else
torch
.
randperm
(
k_size
)
g_idx
=
g_idx
[
rand_perm
].
contiguous
()
q_w
=
q_w
[
rand_perm
,
:].
contiguous
()
...
...
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
)
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
):
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
size_k
,
_
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
...
...
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
,
test_perm
)
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
...
...
vllm/model_executor/layers/resampler.py
0 → 100644
View file @
4851c202
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
#
# Copyright 2023 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Shared resampler perceiver network used in multimodal models and
related helpers for sincos positional embeddings.
Example models: Qwen (Qwen-VL), Minicpmv2.0
"""
import
math
from
functools
import
partial
from
typing
import
Callable
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
                                                       int]) -> torch.Tensor:
    """Resize a square grid of absolute position embeddings.

    Args:
        abs_pos: (L, C) embeddings; L is treated as ``src_size * src_size``.
        tgt_size: Target (H, W), or a single int used for both H and W.

    Returns:
        ``abs_pos`` itself when the source side already equals both H and W;
        otherwise an (H * W, C) tensor produced by bicubic interpolation
        (computed in float32) and cast back to the dtype of ``abs_pos``.
    """
    source_side = int(math.sqrt(abs_pos.size(0)))
    original_dtype = abs_pos.dtype

    if isinstance(tgt_size, int):
        tgt_size = (tgt_size, tgt_size)
    target_h, target_w = tgt_size[0], tgt_size[1]

    # Nothing to resize when the grid already has the requested size.
    if source_side == target_h and source_side == target_w:
        return abs_pos

    # (L, C) -> (1, C, side, side), the layout F.interpolate expects.
    as_image = abs_pos.float().reshape(1, source_side, source_side, -1)
    as_image = as_image.permute(0, 3, 1, 2)
    resized = F.interpolate(
        as_image,
        size=(target_h, target_w),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, H, W) -> (H * W, C), back in the caller's dtype.
    flattened = resized.permute(0, 2, 3, 1).flatten(0, 2)
    return flattened.to(dtype=original_dtype)
# sin/cos positional embedding helpers are adapted from:
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
        embed_dim: int, pos: np.ndarray,
        version: Tuple[int, int] = (2, 0)) -> np.ndarray:
    """Encode positions with sin/cos features.

    Args:
        embed_dim: Output dimension for each position; must be even.
        pos: Positions to encode.  For ``version == (2, 0)`` any shape is
            accepted and flattened to (M,); otherwise ``pos`` must be a
            2-D (H, W) array.
        version: ``(2, 0)`` selects the flattened layout, any other value
            the (H, W, D) layout.

    Returns:
        A NumPy array (not a torch tensor) of shape (M, D) for
        ``version == (2, 0)`` and (H, W, D) otherwise.  The first D/2
        features are sines and the last D/2 are cosines.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    if version == (2, 0):
        pos = pos.reshape(-1)  # (M,)
        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    else:
        out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product

    # The feature axis is the last one in both layouts.
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=-1)
    return emb
def get_2d_sincos_pos_embed_from_grid(
        embed_dim: int, grid: np.ndarray,
        version: Tuple[int, int] = (2, 0)) -> np.ndarray:
    """Build 2-D sin/cos embeddings from a stacked pair of coordinate grids.

    Half of the feature dimensions encode ``grid[0]`` and the other half
    encode ``grid[1]``.

    Args:
        embed_dim: Output dimension for each position; must be even.
        grid: Array whose first axis selects the two coordinate grids.
        version: Passed through to ``get_1d_sincos_pos_embed_from_grid``.

    Returns:
        A NumPy array (not a torch tensor) of shape (H*W, D) for
        ``version == (2, 0)`` and (H, W, D) otherwise.
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode each of the two grids
    half_dim = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(
        half_dim, grid[0], version)  # (H*W, D/2) or (H, W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        half_dim, grid[1], version)  # (H*W, D/2) or (H, W, D/2)

    if version == (2, 0):
        emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    else:
        emb = np.concatenate([emb_h, emb_w], axis=-1)  # (H, W, D)
    return emb
def get_2d_sincos_pos_embed(
        embed_dim: int,
        grid_size: Union[int, Tuple[int, int]],
        cls_token: bool = False,
        version: Tuple[int, int] = (2, 0),
) -> np.ndarray:
    """Create 2-D sin/cos positional embeddings for a grid.

    Args:
        embed_dim: Embedding dimension of each position.
        grid_size: Either one int used for both height and width, or a
            (height, width) pair.
        cls_token: For ``version == (2, 0)`` only, prepend one all-zero row.
            Ignored for every other version.
        version: ``(2, 0)`` yields a flattened table; any other value yields
            an (H, W, D) array.

    Returns:
        A NumPy array (not a torch tensor).  For ``version == (2, 0)``:
        [grid_h*grid_w, embed_dim] without ``cls_token`` or
        [1+grid_h*grid_w, embed_dim] with it.  Otherwise:
        [grid_h, grid_w, embed_dim].
    """
    if isinstance(grid_size, int):
        grid_h_size, grid_w_size = grid_size, grid_size
    else:
        grid_h_size, grid_w_size = grid_size[0], grid_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    assert isinstance(grid, np.ndarray) and \
        grid.shape == (2, grid_h_size, grid_w_size)

    flattened_layout = version == (2, 0)
    if flattened_layout:
        grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)

    # A class-token row can only be prepended to the 2-D (H*W, D) table.
    if flattened_layout and cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
                                   axis=0)
    return pos_embed
class BaseResampler(nn.Module):
    """Base of a 2D perceiver-resampler network.

    Holds ``num_queries`` learnable queries, one multi-head attention layer,
    the layer norms around it and an optional post projection.  Subclasses
    provide the positional embeddings and the forward pass.

    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        num_queries: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        do_post_projection: bool = True,
    ) -> None:
        super().__init__()

        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Learnable queries, truncated-normal initialized in place.
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=0.02)

        if kv_dim is None or kv_dim == embed_dim:
            # No projection is needed.  Return the same (output, None)
            # pair shape as ReplicatedLinear.forward so callers can unpack
            # both cases identically.
            def _identity_kv_proj(*args, **kwargs):
                return (nn.Identity()(*args, **kwargs), None)

            self.kv_proj = _identity_kv_proj  # type: ignore # noqa
        else:
            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.do_post_projection = do_post_projection
        if do_post_projection:
            self.ln_post = norm_layer(embed_dim)
            scale = embed_dim**-0.5
            self.proj = nn.Parameter(scale *
                                     torch.randn(embed_dim, embed_dim))
        else:
            self.ln_post = None
            self.proj = None

    def _init_weights(self, m: nn.Module) -> None:
        """Initialize ``m`` if it is a Linear or LayerNorm module."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _repeat(self, query, N: int):
        """Insert a new dimension 1 and repeat ``query`` N times along it."""
        expanded = query.unsqueeze(1)
        return expanded.repeat(1, N, 1)
class Resampler2(BaseResampler):
    """Resampler-perceiver network to be used for a variety of model types,
    e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
    do_post_projection arg, which indicates whether or not there should be
    a post layer normalization and projector after the attention. This is
    present in minicpmv2.0, but not qwen-vl.
    """

    def __init__(
        self,
        grid_size: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        adaptive: bool = False,
        do_post_projection: bool = True,
    ) -> None:
        """Create a resampler with ``grid_size**2`` queries.

        Args:
            grid_size: Side of the query grid.
            embed_dim: Embedding dimension of queries and outputs.
            num_heads: Number of attention heads.
            kv_dim: Input feature dimension; a projection to ``embed_dim``
                is created when it is given and differs from ``embed_dim``.
            norm_layer: Factory for the layer norms.
            adaptive: If True, key positional embeddings are recomputed for
                the target size on every forward call instead of being
                interpolated from the stored table.
            do_post_projection: Whether to apply the post layer norm and
                projection after attention.
        """
        super().__init__(grid_size**2,
                         embed_dim,
                         num_heads,
                         kv_dim,
                         norm_layer,
                         do_post_projection=do_post_projection)

        self.adaptive = adaptive
        pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
                                                grid_size,
                                                version=(2, 0))

        # The sin/cos table is fixed, not learned.  requires_grad must be
        # passed to nn.Parameter itself: calling requires_grad_(False) on the
        # wrapped tensor is overridden by Parameter's default of True.
        self.pos_embed = nn.Parameter(torch.from_numpy(pos_embed_arr),
                                      requires_grad=False)

        self.apply(self._init_weights)

    def forward(
        self,
        x: torch.Tensor,
        tgt_sizes: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Cross-attend the learnable queries to the input features.

        Args:
            x: Input features; dimension 1 is the sequence length used to
                derive the default target size.
            tgt_sizes: Target grid size for the key positional embeddings.
                Defaults to ``int(sqrt(x.size(1)))``.
            attn_mask: Optional mask forwarded to the attention layer.

        Returns:
            The attended queries, post-normalized and projected when
            ``do_post_projection`` is set.
        """
        if tgt_sizes is None:
            tgt_sizes = int(math.sqrt(x.size(1)))

        if self.adaptive:
            # Recompute the table for the requested size.
            pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
                                                    tgt_sizes,
                                                    version=(2, 0))
            pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device,
                                                           dtype=x.dtype)
        else:
            # Interpolate the stored table to the requested size.
            pos_embed = get_abs_pos(self.pos_embed,
                                    tgt_sizes).to(device=x.device,
                                                  dtype=x.dtype)

        x, _ = self.kv_proj(x)
        # Swap the first two dimensions for nn.MultiheadAttention.
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask,
        )[0]
        x = out.permute(1, 0, 2)
        if self.do_post_projection:
            x = self.ln_post(x)
            x = x @ self.proj
        return x
vllm/model_executor/layers/rotary_embedding.py
View file @
4851c202
...
...
@@ -28,7 +28,6 @@ import torch
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -48,21 +47,29 @@ def _apply_rotary_emb(
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
is_neox_style
:
bool
,
)
->
torch
.
Tensor
:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
x1
,
x2
=
torch
.
chunk
(
x
,
2
,
dim
=-
1
)
cos
=
cos
.
unsqueeze
(
-
2
)
sin
=
sin
.
unsqueeze
(
-
2
)
cos
=
cos
.
unsqueeze
(
-
2
).
to
(
x
.
dtype
)
sin
=
sin
.
unsqueeze
(
-
2
).
to
(
x
.
dtype
)
if
is_neox_style
:
x1
,
x2
=
torch
.
chunk
(
x
,
2
,
dim
=-
1
)
else
:
x1
=
x
[...,
::
2
]
x2
=
x
[...,
1
::
2
]
o1
=
x1
*
cos
-
x2
*
sin
o2
=
x2
*
cos
+
x1
*
sin
return
torch
.
cat
((
o1
,
o2
),
dim
=-
1
).
to
(
orig_dtype
)
if
is_neox_style
:
return
torch
.
cat
((
o1
,
o2
),
dim
=-
1
)
else
:
return
torch
.
stack
((
o1
,
o2
),
dim
=-
1
).
flatten
(
-
2
)
class
RotaryEmbedding
(
CustomOp
):
...
...
@@ -87,10 +94,9 @@ class RotaryEmbedding(CustomOp):
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
dtype
)
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
use_native2
=
current_platform
.
is_tpu
()
and
is_neox_style
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
...
...
@@ -119,59 +125,7 @@ class RotaryEmbedding(CustomOp):
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""A PyTorch-native implementation equivalent to forward().
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
query
.
dtype
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
self
.
is_neox_style
:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
rotate_fn
=
_rotate_neox
if
self
.
is_neox_style
else
_rotate_gptj
query_rot
=
query_rot
*
cos
+
rotate_fn
(
query_rot
)
*
sin
key_rot
=
key_rot
*
cos
+
rotate_fn
(
key_rot
)
*
sin
if
self
.
rotary_dim
<
self
.
head_size
:
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
else
:
query
=
query_rot
key
=
key_rot
query
=
query
.
flatten
(
-
2
)
key
=
key
.
flatten
(
-
2
)
return
query
,
key
def
forward_native2
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Another PyTorch-native implementation of forward().
This method might perform better than `forward_native()` when compiled.
"""
"""A PyTorch-native implementation of forward()."""
if
offsets
is
not
None
:
positions
=
positions
+
offsets
positions
=
positions
.
flatten
()
...
...
@@ -183,14 +137,14 @@ class RotaryEmbedding(CustomOp):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
)
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
)
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
@@ -203,7 +157,7 @@ class RotaryEmbedding(CustomOp):
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm
import
_custom_ops
as
ops
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
...
...
@@ -240,17 +194,6 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
def
forward_tpu
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_fn
=
(
self
.
forward_native2
if
self
.
use_native2
else
self
.
forward_native
)
return
forward_fn
(
positions
,
query
,
key
,
offsets
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
...
...
@@ -769,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return
new_freqs
class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[List[int]] = None,
    ) -> None:
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

        self.mrope_section = mrope_section
        if self.mrope_section:
            # The sections must together cover half of the rotary dimension.
            assert sum(self.mrope_section) == rotary_dim // 2

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to ``query`` and ``key`` in PyTorch.

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]

        Returns:
            The rotated ``(query, key)`` pair, each in its input shape.
        """
        assert positions.ndim in (1, 2)

        num_tokens = positions.shape[-1]
        cos, sin = self.cos_sin_cache[positions].chunk(2, dim=-1)

        if positions.ndim == 2:
            assert self.mrope_section

            def _pick_sections(table: torch.Tensor) -> torch.Tensor:
                # Section i of the feature axis is taken from row i of the
                # leading (T/H/W) axis, then the sections are re-joined.
                pieces = table.split(self.mrope_section, dim=-1)
                return torch.cat(
                    [piece[axis] for axis, piece in enumerate(pieces)],
                    dim=-1)

            cos = _pick_sections(cos)
            sin = _pick_sections(sin)

        def _rotate(states: torch.Tensor) -> torch.Tensor:
            # Rotate the first rotary_dim features of every head and pass
            # the remaining features through unchanged.
            original_shape = states.shape
            per_head = states.view(num_tokens, -1, self.head_size)
            to_rotate = per_head[..., :self.rotary_dim]
            passthrough = per_head[..., self.rotary_dim:]
            rotated = _apply_rotary_emb(to_rotate, cos, sin,
                                        self.is_neox_style)
            joined = torch.cat((rotated, passthrough), dim=-1)
            return joined.reshape(original_shape)

        query = _rotate(query)
        key = _rotate(key)
        return query, key

    @staticmethod
    def get_input_positions(
        input_tokens: List[int],
        image_grid_thw: Union[List[List[int]], torch.Tensor],
        video_grid_thw: Union[List[List[int]], torch.Tensor],
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
        context_len: int = 0,
    ) -> Tuple[List[List[int]], int]:
        """Get mrope input positions and delta value.

        Text tokens receive the same increasing index on all three rows;
        the tokens of each image/video receive per-axis (t, h, w) indices
        offset by the position reached so far.

        Returns:
            The three rows of positions (with the first ``context_len``
            columns dropped) and the difference between the next position
            and ``len(input_tokens)``.
        """
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()
        if isinstance(video_grid_thw, torch.Tensor):
            video_grid_thw = video_grid_thw.tolist()

        total_tokens = len(input_tokens)
        tokens = torch.tensor(input_tokens)

        # The token right after each vision-start marker tells whether the
        # item is an image or a video.
        start_markers = torch.argwhere(
            tokens == vision_start_token_id).squeeze(1)
        item_kinds = tokens[start_markers + 1]
        images_left = int((item_kinds == image_token_id).sum())
        videos_left = int((item_kinds == video_token_id).sum())
        num_items = images_left + videos_left

        chunks: list = []

        def _next_start():
            # Positions continue right after the largest one emitted so far.
            if len(chunks) > 0:
                return chunks[-1].max() + 1
            return 0

        def _text_positions(length, offset):
            return torch.arange(length).view(1, -1).expand(3, -1) + offset

        cursor = 0
        image_index = 0
        video_index = 0
        not_found = total_tokens + 1
        for _ in range(num_items):
            # Locate the next image and the next video token from the cursor.
            if image_token_id in input_tokens and images_left > 0:
                next_image = input_tokens.index(image_token_id, cursor)
            else:
                next_image = not_found
            if video_token_id in input_tokens and videos_left > 0:
                next_video = input_tokens.index(video_token_id, cursor)
            else:
                next_video = not_found

            # Whichever comes first is the item handled in this iteration.
            if next_image < next_video:
                thw = image_grid_thw[image_index]
                t, h, w = thw[0], thw[1], thw[2]
                image_index += 1
                images_left -= 1
                item_start = next_image
            else:
                thw = video_grid_thw[video_index]
                t, h, w = thw[0], thw[1], thw[2]
                video_index += 1
                videos_left -= 1
                item_start = next_video

            grid_t = t
            grid_h = h // spatial_merge_size
            grid_w = w // spatial_merge_size

            # Text between the cursor and the item.
            text_len = item_start - cursor
            offset = _next_start()
            chunks.append(_text_positions(text_len, offset))

            # Per-axis indices for every token of the item.
            t_index = torch.arange(grid_t).view(-1, 1).expand(
                -1, grid_h * grid_w).flatten()
            h_index = torch.arange(grid_h).view(1, -1, 1).expand(
                grid_t, -1, grid_w).flatten()
            w_index = torch.arange(grid_w).view(1, 1, -1).expand(
                grid_t, grid_h, -1).flatten()
            chunks.append(
                torch.stack([t_index, h_index, w_index]) + text_len + offset)

            cursor = item_start + grid_t * grid_h * grid_w

        # Trailing text after the last item (or the whole prompt if there
        # were no items).
        if cursor < total_tokens:
            offset = _next_start()
            chunks.append(_text_positions(total_tokens - cursor, offset))

        llm_positions = torch.cat(chunks, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:]
        mrope_position_delta = (llm_positions.max() + 1 -
                                total_tokens).item()

        return llm_positions.tolist(), mrope_position_delta

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> List[List[int]]:
        """Return three identical rows of positions for the next tokens.

        Each row is ``range(context_len, seq_len)`` shifted by
        ``mrope_position_delta``.
        """
        first = context_len + mrope_position_delta
        stop = seq_len + mrope_position_delta
        row = range(first, stop)
        # Three independent lists, one per T/H/W row.
        return [list(row), list(row), list(row)]
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
...
@@ -809,7 +925,7 @@ def get_rope(
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if
scaling_type
not
in
{
"su"
,
"longrope"
}:
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
.
get
(
"factor"
,
1.0
)
if
scaling_type
==
"llama3"
:
low_freq_factor
=
rope_scaling
[
"low_freq_factor"
]
high_freq_factor
=
rope_scaling
[
"high_freq_factor"
]
...
...
@@ -873,6 +989,16 @@ def get_rope(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
)
elif
scaling_type
==
"mrope"
:
return
MRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
mrope_section
=
rope_scaling
[
"mrope_section"
],
)
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
...
...
vllm/model_executor/model_loader/loader.py
View file @
4851c202
...
...
@@ -17,6 +17,7 @@ import torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
...
...
@@ -94,8 +95,9 @@ def _get_quantization_config(
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
if
not
current_platform
.
is_tpu
():
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
# type: ignore
if
capability
is
not
None
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
...
...
@@ -187,6 +189,11 @@ class BaseModelLoader(ABC):
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
@
abstractmethod
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download a model so that it can be immediately loaded."""
raise
NotImplementedError
@
abstractmethod
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
...
...
@@ -195,7 +202,7 @@ class BaseModelLoader(ABC):
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
...
raise
NotImplementedError
class
DefaultModelLoader
(
BaseModelLoader
):
...
...
@@ -244,12 +251,17 @@ class DefaultModelLoader(BaseModelLoader):
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
...
...
@@ -287,10 +299,10 @@ class DefaultModelLoader(BaseModelLoader):
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
revision
)
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
)
hf_weights_files
,
hf_folder
,
index_file
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
...
...
@@ -332,6 +344,11 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
return
weights_iterator
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
,
fall_back_to_pt
=
True
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -374,6 +391,9 @@ class DummyModelLoader(BaseModelLoader):
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -464,6 +484,12 @@ class TensorizerLoader(BaseModelLoader):
model
=
load_with_tensorizer
(
tensorizer_config
,
**
extra_kwargs
)
return
model
.
eval
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
with
self
.
tensorizer_config
.
open_stream
():
pass
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -565,6 +591,9 @@ class ShardedStateLoader(BaseModelLoader):
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -992,6 +1021,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -1067,6 +1099,9 @@ class GGUFModelLoader(BaseModelLoader):
return
gguf_quant_weights_iterator
(
model_name_or_path
,
gguf_to_hf_name_map
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
4851c202
...
...
@@ -99,6 +99,13 @@ class TensorizerConfig:
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors."
)
def
open_stream
(
self
,
tensorizer_args
:
Optional
[
"TensorizerArgs"
]
=
None
):
if
tensorizer_args
is
None
:
tensorizer_args
=
self
.
_construct_tensorizer_args
()
return
open_stream
(
self
.
tensorizer_uri
,
**
tensorizer_args
.
stream_params
)
def
load_with_tensorizer
(
tensorizer_config
:
TensorizerConfig
,
**
extra_kwargs
)
->
nn
.
Module
:
...
...
vllm/model_executor/model_loader/utils.py
View file @
4851c202
...
...
@@ -43,10 +43,18 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
# for gptq_marlin, only run fused MoE for int4
if
model_config
.
quantization
==
"gptq_marlin"
:
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
and
hf_quant_config
.
get
(
"bits"
)
==
4
:
mixtral_supported
.
append
(
"gptq_marlin"
)
if
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
not
in
mixtral_supported
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
4851c202
...
...
@@ -16,7 +16,6 @@ import torch
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
...
...
@@ -193,6 +192,13 @@ def get_quant_config(model_config: ModelConfig,
if
model_config
.
quantization
==
"bitsandbytes"
:
config
[
"adapter_name_or_path"
]
=
model_name_or_path
elif
model_config
.
quantization
==
"modelopt"
:
if
config
[
"producer"
][
"name"
]
==
"modelopt"
:
return
quant_cls
.
from_config
(
config
)
else
:
raise
ValueError
(
f
"Unsupported quantization config"
f
" found for
{
model_config
.
quantization
}
in
{
f
}
."
)
return
quant_cls
.
from_config
(
config
)
...
...
@@ -251,6 +257,7 @@ def download_weights_from_hf(
def
download_safetensors_index_file_from_hf
(
model_name_or_path
:
str
,
index_file
:
str
,
cache_dir
:
Optional
[
str
],
revision
:
Optional
[
str
]
=
None
,
)
->
None
:
...
...
@@ -269,36 +276,37 @@ def download_safetensors_index_file_from_hf(
# Download the safetensors index file.
hf_hub_download
(
repo_id
=
model_name_or_path
,
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
filename
=
index_file
,
cache_dir
=
cache_dir
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
# If file not found on remote or locally, we should not fail since
# only some models will have
SAFE_WEIGHTS_INDEX_NAME
.
# only some models will have
index_file
.
except
huggingface_hub
.
utils
.
EntryNotFoundError
:
logger
.
info
(
"No %s found in remote."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in remote."
,
index_file
)
except
huggingface_hub
.
utils
.
LocalEntryNotFoundError
:
logger
.
info
(
"No %s found in local cache."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in local cache."
,
index_file
)
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the
SAFE_WEIGHTS_INDEX_NAME
to
# So, we use the
index_file
to
# look up which safetensors files should be used.
def
filter_duplicate_safetensors_files
(
hf_weights_files
:
List
[
str
],
hf_folder
:
str
)
->
List
[
str
]:
hf_folder
:
str
,
index_file
:
str
)
->
List
[
str
]:
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
index_file_name
=
os
.
path
.
join
(
hf_folder
,
SAFE_WEIGHTS_INDEX_NAME
)
index_file_name
=
os
.
path
.
join
(
hf_folder
,
index_file
)
if
not
os
.
path
.
isfile
(
index_file_name
):
return
hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
with
open
(
index_file_name
)
as
index_file
:
weight_map
=
json
.
load
(
index_file
)[
"weight_map"
]
with
open
(
index_file_name
,
"r"
)
as
f
:
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
weight_files_in_index
=
set
()
for
weight_name
in
weight_map
:
weight_files_in_index
.
add
(
...
...
vllm/model_executor/models/__init__.py
View file @
4851c202
...
...
@@ -51,9 +51,10 @@ _GENERATION_MODELS = {
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"PhiMoEForCausalLM"
:
(
"phimoe"
,
"PhiMoEForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
"RWForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"StableLMEpochForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLmForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
...
...
@@ -81,13 +82,20 @@ _MULTIMODAL_MODELS = {
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
"LlavaNextVideoForConditionalGeneration"
:
(
"llava_next_video"
,
"LlavaNextVideoForConditionalGeneration"
),
"MiniCPMV"
:
(
"minicpmv"
,
"MiniCPMV"
),
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"PixtralForConditionalGeneration"
:
(
"pixtral"
,
"PixtralForConditionalGeneration"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
}
_CONDITIONAL_GENERATION_MODELS
=
{
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
...
...
Prev
1
…
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment