Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4851c202
"vllm/vscode:/vscode.git/clone" did not exist on "67fc16cd8cf778a30ad0f7619fe77bd85f1d1633"
Commit
4851c202
authored
Sep 13, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1' into v0.6.1-dev
parents
9b902f9e
3fd2b0d2
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1194 additions
and
322 deletions
+1194
-322
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+2
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-2
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+23
-11
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+15
-10
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+28
-20
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+34
-13
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+28
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+301
-19
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+163
-0
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+0
-138
vllm/model_executor/layers/quantization/utils/marlin_utils.py
.../model_executor/layers/quantization/utils/marlin_utils.py
+17
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
...l_executor/layers/quantization/utils/marlin_utils_test.py
+7
-4
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+13
-6
vllm/model_executor/layers/resampler.py
vllm/model_executor/layers/resampler.py
+273
-0
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+203
-77
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+41
-6
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+7
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+8
-0
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+18
-10
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+11
-3
No files found.
vllm/model_executor/layers/linear.py
View file @
4851c202
...
@@ -29,7 +29,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
...
@@ -29,7 +29,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
,
"ModelOptFp8LinearMethod"
]
]
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
4851c202
...
@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
...
@@ -22,10 +22,10 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.modelopt
import
ModelOptFp8Config
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
NeuronQuantConfig
)
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
...
@@ -35,6 +35,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -35,6 +35,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"tpu_int8"
:
Int8TpuConfig
,
"tpu_int8"
:
Int8TpuConfig
,
"fp8"
:
Fp8Config
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
"modelopt"
:
ModelOptFp8Config
,
# The order of gptq methods is important for config.py iteration over
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
...
@@ -43,7 +44,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -43,7 +44,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
"qqq"
:
QQQConfig
,
...
...
vllm/model_executor/layers/quantization/awq_triton.py
View file @
4851c202
...
@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
...
@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
# Compute offsets and masks for qweight_ptr.
# Compute offsets and masks for qweight_ptr.
offsets_y
=
pid_y
*
BLOCK_SIZE_Y
+
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
offsets_y
=
pid_y
*
BLOCK_SIZE_Y
+
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
)
//
8
offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
)
offsets
=
num_cols
*
offsets_y
[:,
None
]
+
offsets_x
[
None
,
:]
offsets
=
num_cols
*
offsets_y
[:,
None
]
+
offsets_x
[
None
,
:]
masks_y
=
offsets_y
<
num_rows
masks_y
=
offsets_y
<
num_rows
...
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
...
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
# Load the weights.
# Load the weights.
iweights
=
tl
.
load
(
qweight_ptr
+
offsets
,
masks
)
iweights
=
tl
.
load
(
qweight_ptr
+
offsets
,
masks
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
# that will map given indices to the correct order.
# that will map given indices to the correct order.
...
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
...
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
iweights
=
(
iweights
>>
shifts
)
&
0xF
iweights
=
(
iweights
>>
shifts
)
&
0xF
# Compute zero offsets and masks.
# Compute zero offsets and masks.
zero_offsets_y
=
(
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
zero_offsets_y
=
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
1
)
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
//
group_size
)
zero_offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
)
zero_offsets_x
=
pid_x
*
BLOCK_SIZE_X
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
)
//
8
zero_offsets
=
num_cols
*
zero_offsets_y
[:,
None
]
+
zero_offsets_x
[
None
,
:]
zero_offsets
=
num_cols
*
zero_offsets_y
[:,
None
]
+
zero_offsets_x
[
None
,
:]
zero_masks_y
=
zero_offsets_y
<
num_rows
//
group_size
zero_masks_y
=
zero_offsets_y
<
num_rows
//
group_size
...
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
...
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
# Load the zeros.
# Load the zeros.
zeros
=
tl
.
load
(
zeros_ptr
+
zero_offsets
,
zero_masks
)
zeros
=
tl
.
load
(
zeros_ptr
+
zero_offsets
,
zero_masks
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
broadcast_to
(
zeros
,
(
BLOCK_SIZE_Y
,
BLOCK_SIZE_X
*
8
))
# Unpack and reorder: shift out the correct 4-bit value and mask.
# Unpack and reorder: shift out the correct 4-bit value and mask.
zeros
=
(
zeros
>>
shifts
)
&
0xF
zeros
=
(
zeros
>>
shifts
)
&
0xF
# Compute scale offsets and masks.
# Compute scale offsets and masks.
scale_offsets_y
=
(
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
scale_offsets_y
=
pid_y
*
BLOCK_SIZE_Y
//
group_size
+
tl
.
arange
(
0
,
1
)
tl
.
arange
(
0
,
BLOCK_SIZE_Y
)
//
group_size
)
scale_offsets_x
=
(
pid_x
*
BLOCK_SIZE_X
*
8
+
scale_offsets_x
=
(
pid_x
*
BLOCK_SIZE_X
*
8
+
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
))
tl
.
arange
(
0
,
BLOCK_SIZE_X
*
8
))
scale_offsets
=
(
num_cols
*
8
*
scale_offsets_y
[:,
None
]
+
scale_offsets
=
(
num_cols
*
8
*
scale_offsets_y
[:,
None
]
+
...
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
...
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
# Load the scales.
# Load the scales.
scales
=
tl
.
load
(
scales_ptr
+
scale_offsets
,
scale_masks
)
scales
=
tl
.
load
(
scales_ptr
+
scale_offsets
,
scale_masks
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_Y
,
BLOCK_SIZE_X
*
8
))
# Dequantize.
# Dequantize.
iweights
=
(
iweights
-
zeros
)
*
scales
iweights
=
(
iweights
-
zeros
)
*
scales
...
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
...
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
offsets_am
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offsets_am
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
masks_am
=
offsets_am
<
M
masks_am
=
offsets_am
<
M
offsets_bn
=
(
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
offsets_bn
=
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
//
8
)
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
//
8
)
masks_bn
=
offsets_bn
<
N
//
8
masks_bn
=
offsets_bn
<
N
//
8
offsets_zn
=
(
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
offsets_zn
=
pid_n
*
(
BLOCK_SIZE_N
//
8
)
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
//
8
)
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
//
8
)
masks_zn
=
offsets_zn
<
N
//
8
masks_zn
=
offsets_zn
<
N
//
8
offsets_sn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
offsets_sn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
...
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
...
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_b
=
masks_k
[:,
None
]
&
masks_bn
[
None
,
:]
masks_b
=
masks_k
[:,
None
]
&
masks_bn
[
None
,
:]
b
=
tl
.
load
(
b_ptrs
,
mask
=
masks_b
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
masks_b
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
# Dequantize b.
# Dequantize b.
offsets_szk
=
(
offsets_szk
=
(
(
BLOCK_SIZE_K
*
SPLIT_K
*
k
+
pid_z
*
BLOCK_SIZE_K
)
//
group_size
+
(
BLOCK_SIZE_K
*
SPLIT_K
*
k
+
pid_z
*
BLOCK_SIZE_K
)
//
group_size
+
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
//
group_size
)
tl
.
arange
(
0
,
1
)
)
offsets_z
=
(
N
//
8
)
*
offsets_szk
[:,
None
]
+
offsets_zn
[
None
,
:]
offsets_z
=
(
N
//
8
)
*
offsets_szk
[:,
None
]
+
offsets_zn
[
None
,
:]
masks_zk
=
offsets_szk
<
K
//
group_size
masks_zk
=
offsets_szk
<
K
//
group_size
masks_z
=
masks_zk
[:,
None
]
&
masks_zn
[
None
,
:]
masks_z
=
masks_zk
[:,
None
]
&
masks_zn
[
None
,
:]
zeros_ptrs
=
zeros_ptr
+
offsets_z
zeros_ptrs
=
zeros_ptr
+
offsets_z
zeros
=
tl
.
load
(
zeros_ptrs
,
mask
=
masks_z
)
zeros
=
tl
.
load
(
zeros_ptrs
,
mask
=
masks_z
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
broadcast_to
(
zeros
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
offsets_s
=
N
*
offsets_szk
[:,
None
]
+
offsets_sn
[
None
,
:]
offsets_s
=
N
*
offsets_szk
[:,
None
]
+
offsets_sn
[
None
,
:]
masks_sk
=
offsets_szk
<
K
//
group_size
masks_sk
=
offsets_szk
<
K
//
group_size
masks_s
=
masks_sk
[:,
None
]
&
masks_sn
[
None
,
:]
masks_s
=
masks_sk
[:,
None
]
&
masks_sn
[
None
,
:]
scales_ptrs
=
scales_ptr
+
offsets_s
scales_ptrs
=
scales_ptr
+
offsets_s
scales
=
tl
.
load
(
scales_ptrs
,
mask
=
masks_s
)
scales
=
tl
.
load
(
scales_ptrs
,
mask
=
masks_s
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
b
=
(
b
>>
shifts
)
&
0xF
b
=
(
b
>>
shifts
)
&
0xF
zeros
=
(
zeros
>>
shifts
)
&
0xF
zeros
=
(
zeros
>>
shifts
)
&
0xF
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
4851c202
...
@@ -116,15 +116,19 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -116,15 +116,19 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_check_scheme_supported
(
self
,
def
_check_scheme_supported
(
self
,
min_capability
:
int
,
min_capability
:
int
,
error
:
bool
=
True
)
->
bool
:
error
:
bool
=
True
)
->
bool
:
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
# type: ignore
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
supported
=
capability
>=
min_capability
if
capability
is
not
None
:
if
error
and
not
supported
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
raise
RuntimeError
(
supported
=
capability
>=
min_capability
"Quantization scheme is not supported for "
,
if
error
and
not
supported
:
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
raise
RuntimeError
(
f
"Current capability:
{
capability
}
."
)
"Quantization scheme is not supported for "
,
return
supported
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
f
"Current capability:
{
capability
}
."
)
return
supported
else
:
return
False
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
...
@@ -232,7 +236,8 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -232,7 +236,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
CompressedTensorsWNA16
(
return
CompressedTensorsWNA16
(
num_bits
=
weight_quant
.
num_bits
,
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
group_size
=
weight_quant
.
group_size
,
actorder
=
weight_quant
.
actorder
)
# Detect If Activation Quantization.
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
# TODO @dsikka: clean-up conditions
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
4851c202
...
@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
...
@@ -5,9 +5,7 @@ from typing import Callable, List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
CompressionFormat
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -40,11 +38,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
if
not
(
self
.
quant_config
.
quant_format
if
not
(
self
.
quant_config
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
==
CompressionFormat
.
pack_quantized
.
value
and
self
.
num_bits
in
WNA16_SUPPORTED_BITS
):
and
self
.
num_bits
==
4
):
raise
ValueError
(
"For Fused MoE layers, only "
,
raise
ValueError
(
"For Fused MoE layers, only "
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
f
"
{
CompressionFormat
.
pack_quantized
.
value
}
"
,
"is supported for the following bits: "
,
"is supported for 4 bits"
)
f
"
{
WNA16_SUPPORTED_BITS
}
"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
...
@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -269,19 +266,30 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
custom_routing_function
:
Optional
[
Callable
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_
marlin_
moe
import
(
fused_marlin_moe
)
fused_marlin_moe
)
return
fused_marlin_moe
(
x
,
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
layer
.
w13_weight_packed
,
hidden_states
=
x
,
layer
.
w2_weight_packed
,
router_logits
=
router_logits
,
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
layer
.
w13_g_idx
,
top_k
=
top_k
,
layer
.
w2_g_idx
,
renormalize
=
renormalize
,
layer
.
w13_g_idx_sort_indices
,
topk_group
=
topk_group
,
layer
.
w2_g_idx_sort_indices
,
num_expert_group
=
num_expert_group
,
top_k
,
custom_routing_function
=
custom_routing_function
)
custom_routing_function
,
renormalize
=
renormalize
,
return
fused_marlin_moe
(
w1_scale
=
layer
.
w13_weight_scale
,
x
,
w2_scale
=
layer
.
w2_weight_scale
)
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
topk_weights
,
topk_ids
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
4851c202
...
@@ -5,20 +5,24 @@ import torch
...
@@ -5,20 +5,24 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
ActivationOrdering
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
apply_gptq_marlin_linear
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_marlin_supported
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supports_shape
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
__all__
=
[
"CompressedTensorsWNA16"
]
__all__
=
[
"CompressedTensorsWNA16"
]
WNA16_SUPPORTED_TYPES_MAP
=
{
WNA16_SUPPORTED_TYPES_MAP
=
{
4
:
scalar_types
.
uint4b8
,
4
:
scalar_types
.
uint4b8
,
8
:
scalar_types
.
uint8b128
,
8
:
scalar_types
.
uint8b128
}
}
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
WNA16_SUPPORTED_BITS
=
list
(
WNA16_SUPPORTED_TYPES_MAP
.
keys
())
...
@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def
__init__
(
self
,
def
__init__
(
self
,
strategy
:
str
,
strategy
:
str
,
num_bits
:
int
,
num_bits
:
int
,
group_size
:
Optional
[
int
]
=
None
):
group_size
:
Optional
[
int
]
=
None
,
actorder
:
Optional
[
ActivationOrdering
]
=
None
):
self
.
pack_factor
=
32
//
num_bits
self
.
pack_factor
=
32
//
num_bits
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
group_size
=
-
1
if
group_size
is
None
else
group_size
self
.
has_g_idx
=
actorder
==
ActivationOrdering
.
GROUP
if
self
.
group_size
==
-
1
and
self
.
strategy
!=
"channel"
:
if
self
.
group_size
==
-
1
and
self
.
strategy
!=
"channel"
:
raise
ValueError
(
"Marlin kernels require group quantization or "
raise
ValueError
(
"Marlin kernels require group quantization or "
...
@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
# If group_size is -1, we are in channelwise case.
# If group_size is -1, we are in channelwise case.
channelwise
=
(
self
.
group_size
==
-
1
)
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input_size
group_size
=
self
.
group_size
if
self
.
group_size
!=
-
1
else
input_size
row_parallel
=
(
input_size
!=
input_size_per_partition
)
row_parallel
=
(
input_size
!=
input_size_per_partition
)
# In the case of channelwise quantization, we need to replicate the
partition_scales
=
not
marlin_repeat_scales_on_all_ranks
(
# scales across all gpus.
self
.
has_g_idx
,
self
.
group_size
,
row_parallel
)
partition_scales
=
(
row_parallel
and
not
channelwise
)
verify_marlin_supports_shape
(
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
output_size_per_partition
=
output_size_per_partition
,
...
@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
layer
.
register_parameter
(
"weight_shape"
,
weight_shape
)
# group index (for activation reordering)
if
self
.
has_g_idx
:
weight_g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_g_idx"
,
weight_g_idx
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
input_size
=
input_size
layer
.
input_size
=
input_size
...
@@ -137,9 +151,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -137,9 +151,14 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer
.
workspace
=
marlin_make_workspace
(
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
layer
.
output_size_per_partition
,
device
)
# Act-order not supported in compressed-tensors yet, so set to empty.
# Handle sorting for activation reordering if needed.
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
if
self
.
has_g_idx
:
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
g_idx
,
g_idx_sort_indices
=
marlin_sort_g_idx
(
layer
.
weight_g_idx
)
layer
.
g_idx_sort_indices
=
g_idx_sort_indices
replace_tensor
(
layer
,
"weight_g_idx"
,
g_idx
)
else
:
layer
.
weight_g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
# No zero-point
# No zero-point
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
layer
.
weight_zp
=
marlin_make_empty_g_idx
(
device
)
...
@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
replace_tensor
(
layer
,
"weight_packed"
,
marlin_qweight
)
# Permute scales from compressed-tensors format to marlin format.
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales
=
marlin_permute_scales
(
marlin_scales
=
marlin_permute_scales
(
layer
.
weight_scale
,
layer
.
weight_scale
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
(
layer
.
input_size
if
self
.
has_g_idx
else
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
layer
.
group_size
)
group_size
=
layer
.
group_size
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
replace_tensor
(
layer
,
"weight_scale"
,
marlin_scales
)
...
@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight
=
layer
.
weight_packed
,
weight
=
layer
.
weight_packed
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
weight_zp
=
layer
.
weight_zp
,
weight_zp
=
layer
.
weight_zp
,
g_idx
=
layer
.
g_idx
,
g_idx
=
layer
.
weight_
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
workspace
=
layer
.
workspace
,
wtype
=
self
.
quant_type
,
wtype
=
self
.
quant_type
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
4851c202
import
re
import
re
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
from
typing
import
Any
,
Dict
,
Iterable
,
Optional
,
Union
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
,
field_validator
from
torch.nn
import
Module
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
...
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
TOKEN
=
"token"
TOKEN
=
"token"
class
ActivationOrdering
(
str
,
Enum
):
"""
Enum storing strategies for activation ordering
Group: reorder groups and weight
\n
Weight: only reorder weight, not groups. Slightly lower latency and
accuracy compared to group actorder
\n
"""
GROUP
=
"group"
WEIGHT
=
"weight"
class
QuantizationArgs
(
BaseModel
):
class
QuantizationArgs
(
BaseModel
):
"""
"""
User facing arguments used to define a quantization config
User facing arguments used to define a quantization config
...
@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
...
@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
observed with every sample. Defaults to False for static
observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization
quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one
will change the default observer to a memoryless one
:param actorder: whether to apply group quantization in decreasing order of
activation. Defaults to None for arbitrary ordering
"""
"""
num_bits
:
int
=
8
num_bits
:
int
=
8
...
@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
...
@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
strategy
:
Optional
[
QuantizationStrategy
]
=
None
strategy
:
Optional
[
QuantizationStrategy
]
=
None
block_structure
:
Optional
[
str
]
=
None
block_structure
:
Optional
[
str
]
=
None
dynamic
:
bool
=
False
dynamic
:
bool
=
False
actorder
:
Union
[
ActivationOrdering
,
bool
,
None
]
=
None
observer
:
str
=
Field
(
observer
:
str
=
Field
(
default
=
"minmax"
,
default
=
"minmax"
,
description
=
(
"The class to use to compute the quantization param - "
description
=
(
"The class to use to compute the quantization param - "
...
@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
...
@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
"Observers constructor excluding quantization range or symmetry"
),
"Observers constructor excluding quantization range or symmetry"
),
)
)
@
field_validator
(
"actorder"
,
mode
=
"before"
)
def
validate_actorder
(
cls
,
value
)
->
Optional
[
ActivationOrdering
]:
if
isinstance
(
value
,
bool
):
return
ActivationOrdering
.
GROUP
if
value
else
None
if
isinstance
(
value
,
str
):
return
ActivationOrdering
(
value
.
lower
())
return
value
def
is_activation_quantization_format
(
format
:
str
)
->
bool
:
def
is_activation_quantization_format
(
format
:
str
)
->
bool
:
_ACTIVATION_QUANTIZATION_FORMATS
=
[
_ACTIVATION_QUANTIZATION_FORMATS
=
[
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
4851c202
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
apply_gptq_marlin_linear
,
check_marlin_supported
,
marlin_is_k_full
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_moe_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_permute_scales
,
marlin_repeat_scales_on_all_ranks
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
GroupQuantScaleParameter
,
...
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig):
(
8
,
True
):
scalar_types
.
uint8b128
,
(
8
,
True
):
scalar_types
.
uint8b128
,
}
}
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
def
__init__
(
is_sym
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
self
,
weight_bits
:
int
,
group_size
:
int
,
desc_act
:
bool
,
is_sym
:
bool
,
lm_head_quantized
:
bool
,
)
->
None
:
if
desc_act
and
group_size
==
-
1
:
if
desc_act
and
group_size
==
-
1
:
# In this case, act_order == True is the same as act_order == False
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
# (since we have only one group per output channel)
...
@@ -51,10 +61,6 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -51,10 +61,6 @@ class GPTQMarlinConfig(QuantizationConfig):
self
.
quant_type
=
self
.
TYPE_MAP
[(
weight_bits
,
is_sym
)]
self
.
quant_type
=
self
.
TYPE_MAP
[(
weight_bits
,
is_sym
)]
# Verify supported on platform.
verify_marlin_supported
(
quant_type
=
self
.
quant_type
,
group_size
=
self
.
group_size
)
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"GPTQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
return
(
f
"GPTQMarlinConfig(quant_type=
{
self
.
quant_type
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"group_size=
{
self
.
group_size
}
, "
...
@@ -109,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -109,11 +115,14 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference"
)
" faster inference"
)
return
None
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
prefix
:
str
)
->
Optional
[
"GPTQMarlinLinearMethod"
]:
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
if
(
isinstance
(
layer
,
LinearBase
)
or
)
->
Optional
[
Union
[
"GPTQMarlinLinearMethod"
,
"GPTQMarlinMoEMethod"
]]:
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
if
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
):
return
GPTQMarlinLinearMethod
(
self
)
return
GPTQMarlinLinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
GPTQMarlinMoEMethod
(
self
)
return
None
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
...
@@ -153,6 +162,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -153,6 +162,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
# Verify supported on platform.
verify_marlin_supported
(
quant_type
=
self
.
quant_config
.
quant_type
,
group_size
=
self
.
quant_config
.
group_size
)
def
create_weights
(
def
create_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -179,7 +192,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition
=
output_size_per_partition
,
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
input_size
=
input_size
,
group_size
=
group_size
)
group_size
=
group_size
,
)
# Determine sharding
# Determine sharding
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
if
marlin_repeat_scales_on_all_ranks
(
self
.
quant_config
.
desc_act
,
...
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -299,7 +313,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
perm
=
layer
.
g_idx_sort_indices
,
perm
=
layer
.
g_idx_sort_indices
,
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
)
num_bits
=
self
.
quant_config
.
quant_type
.
size_bits
,
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from autogptq format to marlin format.
# Permute scales from autogptq format to marlin format.
...
@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -308,7 +323,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
size_k
=
(
layer
.
input_size
if
self
.
quant_config
.
desc_act
else
layer
.
input_size_per_partition
),
layer
.
input_size_per_partition
),
size_n
=
layer
.
output_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
group_size
=
self
.
quant_config
.
group_size
,
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
def
apply
(
def
apply
(
...
@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -329,4 +345,270 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
output_size_per_partition
=
layer
.
output_size_per_partition
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
is_k_full
=
layer
.
is_k_full
,
is_k_full
=
layer
.
is_k_full
,
bias
=
bias
)
bias
=
bias
,
)
class
GPTQMarlinMoEMethod
(
FusedMoEMethodBase
):
"""MoE Marlin method with quantization."""
def
__init__
(
self
,
quant_config
:
GPTQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
# Currently assuming is_k_full is always True
# (input size per partition is the same as full input size)
# Supports only sym for now (no zp)
if
self
.
quant_config
.
group_size
!=
-
1
:
scales_size13
=
hidden_size
//
self
.
quant_config
.
group_size
scales_size2
=
intermediate_size
//
self
.
quant_config
.
group_size
strategy
=
FusedMoeWeightScaleSupported
.
GROUP
.
value
else
:
scales_size13
=
1
scales_size2
=
1
strategy
=
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
extra_weight_attrs
.
update
({
"quant_method"
:
strategy
,
"is_transposed"
:
True
})
# Fused gate_up_proj (column parallel)
w13_qweight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
//
self
.
quant_config
.
pack_factor
,
2
*
intermediate_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_qweight"
,
w13_qweight
)
set_weight_attrs
(
w13_qweight
,
extra_weight_attrs
)
# down_proj (row parallel)
w2_qweight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
intermediate_size
//
self
.
quant_config
.
pack_factor
,
hidden_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_qweight"
,
w2_qweight
)
set_weight_attrs
(
w2_qweight
,
extra_weight_attrs
)
# up_proj scales
w13_scales
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
scales_size13
,
2
*
intermediate_size
,
dtype
=
torch
.
half
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_scales"
,
w13_scales
)
set_weight_attrs
(
w13_scales
,
extra_weight_attrs
)
# down_proj scales
w2_scales
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
scales_size2
,
hidden_size
,
dtype
=
torch
.
half
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_scales"
,
w2_scales
)
set_weight_attrs
(
w2_scales
,
extra_weight_attrs
)
# up_proj scales
w13_qzeros
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
scales_size13
,
2
*
intermediate_size
//
self
.
quant_config
.
pack_factor
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_qzeros"
,
w13_qzeros
)
set_weight_attrs
(
w13_qzeros
,
extra_weight_attrs
)
# down_proj scales
w2_qzeros
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
scales_size2
,
hidden_size
//
self
.
quant_config
.
pack_factor
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_qzeros"
,
w2_qzeros
)
set_weight_attrs
(
w2_qzeros
,
extra_weight_attrs
)
w13_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_g_idx"
,
w13_g_idx
)
set_weight_attrs
(
w13_g_idx
,
extra_weight_attrs
)
w2_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
intermediate_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_g_idx"
,
w2_g_idx
)
set_weight_attrs
(
w2_g_idx
,
extra_weight_attrs
)
w13_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_g_idx_sort_indices"
,
w13_g_idx_sort_indices
)
set_weight_attrs
(
w13_g_idx_sort_indices
,
extra_weight_attrs
)
w2_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
intermediate_size
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_g_idx_sort_indices"
,
w2_g_idx_sort_indices
)
set_weight_attrs
(
w2_g_idx_sort_indices
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Process act_order
if
self
.
quant_config
.
desc_act
:
# Get sorting based on g_idx
num_experts
=
layer
.
w13_g_idx
.
shape
[
0
]
w13_g_idx_sort_indices
=
torch
.
empty_like
(
layer
.
w13_g_idx
)
w2_g_idx_sort_indices
=
torch
.
empty_like
(
layer
.
w2_g_idx
)
w13_sorted_g_idx
=
torch
.
empty_like
(
layer
.
w13_g_idx
)
w2_sorted_g_idx
=
torch
.
empty_like
(
layer
.
w2_g_idx
)
for
e
in
range
(
num_experts
):
w13_g_idx_sort_indices
[
e
]
=
torch
.
argsort
(
layer
.
w13_g_idx
[
e
]).
to
(
torch
.
int32
)
w2_g_idx_sort_indices
[
e
]
=
torch
.
argsort
(
layer
.
w2_g_idx
[
e
]).
to
(
torch
.
int32
)
w13_sorted_g_idx
[
e
]
=
layer
.
w13_g_idx
[
e
][
w13_g_idx_sort_indices
[
e
]]
w2_sorted_g_idx
[
e
]
=
layer
.
w2_g_idx
[
e
][
w2_g_idx_sort_indices
[
e
]]
replace_tensor
(
layer
,
"w13_g_idx"
,
w13_sorted_g_idx
)
replace_tensor
(
layer
,
"w2_g_idx"
,
w2_sorted_g_idx
)
replace_tensor
(
layer
,
"w13_g_idx_sort_indices"
,
w13_g_idx_sort_indices
)
replace_tensor
(
layer
,
"w2_g_idx_sort_indices"
,
w2_g_idx_sort_indices
)
else
:
# Reset g_idx related tensors
num_experts
=
layer
.
w13_g_idx
.
shape
[
0
]
device
=
layer
.
w13_g_idx
.
device
layer
.
w13_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
layer
.
w2_g_idx
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
layer
.
w13_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
layer
.
w2_g_idx_sort_indices
=
torch
.
nn
.
Parameter
(
torch
.
empty
((
num_experts
,
0
),
dtype
=
torch
.
int32
,
device
=
device
),
requires_grad
=
False
,
)
# Repack weights
marlin_w13_qweight
=
ops
.
gptq_marlin_moe_repack
(
layer
.
w13_qweight
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w13_qweight
.
shape
[
1
]
*
self
.
quant_config
.
pack_factor
,
layer
.
w13_qweight
.
shape
[
2
],
self
.
quant_config
.
quant_type
.
size_bits
,
)
replace_tensor
(
layer
,
"w13_qweight"
,
marlin_w13_qweight
)
marlin_w2_qweight
=
ops
.
gptq_marlin_moe_repack
(
layer
.
w2_qweight
,
layer
.
w2_g_idx_sort_indices
,
layer
.
w2_qweight
.
shape
[
1
]
*
self
.
quant_config
.
pack_factor
,
layer
.
w2_qweight
.
shape
[
2
],
self
.
quant_config
.
quant_type
.
size_bits
,
)
replace_tensor
(
layer
,
"w2_qweight"
,
marlin_w2_qweight
)
# Repack scales
marlin_w13_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w13_scales
,
size_k
=
layer
.
intermediate_size_per_partition
,
size_n
=
layer
.
w13_scales
.
shape
[
2
],
group_size
=
self
.
quant_config
.
group_size
,
)
replace_tensor
(
layer
,
"w13_scales"
,
marlin_w13_scales
)
marlin_w2_scales
=
marlin_moe_permute_scales
(
s
=
layer
.
w2_scales
,
size_k
=
layer
.
w2_scales
.
shape
[
1
]
*
self
.
quant_config
.
pack_factor
,
size_n
=
layer
.
w2_scales
.
shape
[
2
],
group_size
=
self
.
quant_config
.
group_size
,
)
replace_tensor
(
layer
,
"w2_scales"
,
marlin_w2_scales
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_marlin_moe
import
(
fused_marlin_moe
)
# The input must currently be float16
orig_dtype
=
x
.
dtype
x
=
x
.
half
()
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
None
)
return
fused_marlin_moe
(
x
,
layer
.
w13_qweight
,
layer
.
w2_qweight
,
router_logits
,
layer
.
w13_g_idx
,
layer
.
w2_g_idx
,
layer
.
w13_g_idx_sort_indices
,
layer
.
w2_g_idx_sort_indices
,
topk_weights
,
topk_ids
,
w1_scale
=
layer
.
w13_scales
,
w2_scale
=
layer
.
w2_scales
,
).
to
(
orig_dtype
)
vllm/model_executor/layers/quantization/modelopt.py
0 → 100644
View file @
4851c202
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
logger
=
init_logger
(
__name__
)
ACTIVATION_SCHEMES
=
[
"static"
]
class
ModelOptFp8Config
(
QuantizationConfig
):
"""Config class for ModelOpt FP8."""
def
__init__
(
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
logger
.
warning
(
"Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change."
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"modelopt"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
89
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"hf_quant_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_method
=
quant_config
[
"quant_algo"
]
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
if
not
is_checkpoint_fp8_serialized
:
raise
ValueError
(
"ModelOpt currently only supports static FP8"
"quantization in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
return
cls
(
is_checkpoint_fp8_serialized
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
return
ModelOptFp8LinearMethod
(
self
)
elif
isinstance
(
layer
,
Attention
):
return
ModelOptFp8KVCacheMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
ModelOptFp8KVCacheMethod
(
BaseKVCacheMethod
):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
):
super
().
__init__
(
quant_config
)
class
ModelOptFp8LinearMethod
(
LinearMethodBase
):
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
activation scale. Future support might be added for dynamic
scales.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn datatype
Args: quant_config: The ModelOpt quantization config.
"""
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
):
self
.
quant_config
=
quant_config
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
weight_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
max_w_scale
,
weight
=
requantize_with_max_scale
(
layer
.
weight
,
layer
.
weight_scale
,
layer
.
logical_widths
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
cutlass_fp8_supported
=
self
.
cutlass_fp8_supported
)
vllm/model_executor/layers/quantization/squeezellm.py
deleted
100644 → 0
View file @
9b902f9e
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
class
SqueezeLLMConfig
(
QuantizationConfig
):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def
__init__
(
self
,
weight_bits
:
int
,
)
->
None
:
self
.
weight_bits
=
weight_bits
if
self
.
weight_bits
!=
4
:
raise
ValueError
(
"Currently, only 4-bit weight quantization is supported for "
f
"SqueezeLLM, but got
{
self
.
weight_bits
}
bits."
)
self
.
pack_factor
=
32
//
self
.
weight_bits
def
__repr__
(
self
)
->
str
:
return
f
"SqueezeLLMConfig(weight_bits=
{
self
.
weight_bits
}
)"
def
get_name
(
self
)
->
str
:
return
"squeezellm"
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
70
@
staticmethod
def
get_config_filenames
()
->
List
[
str
]:
return
[
"quant_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"SqueezeLLMConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
return
cls
(
weight_bits
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
QuantizeMethodBase
]:
if
isinstance
(
layer
,
LinearBase
):
return
SqueezeLLMLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SqueezeLLMLinearMethod
(
QuantizeMethodBase
):
"""Linear method for SqueezeLLM.
Args:
quant_config: The SqueezeLLM quantization config.
"""
def
__init__
(
self
,
quant_config
:
SqueezeLLMConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
input_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
lookup_table
=
Parameter
(
torch
.
empty
(
output_size
,
self
.
quant_config
.
weight_bits
**
2
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
lookup_table
,
{
"output_dim"
:
0
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"lookup_table"
,
lookup_table
)
set_weight_attrs
(
lookup_table
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
lookup_table
=
layer
.
lookup_table
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
is_hip
():
out_f
=
torch
.
zeros
(
out_shape
,
dtype
=
torch
.
float
)
ops
.
squeezellm_gemm
(
reshaped_x
,
qweight
,
out_f
,
lookup_table
)
out
=
out_f
.
to
(
dtype
=
torch
.
float16
)
else
:
# NOTE: The output tensor should be zero-initialized.
out
=
torch
.
zeros
(
out_shape
,
dtype
=
torch
.
float16
)
ops
.
squeezellm_gemm
(
reshaped_x
,
qweight
,
out
,
lookup_table
)
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/utils/marlin_utils.py
View file @
4851c202
...
@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
...
@@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return
s
return
s
def
marlin_moe_permute_scales
(
s
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
group_size
:
int
,
):
num_experts
=
s
.
shape
[
0
]
output
=
torch
.
empty
(
(
num_experts
,
s
.
shape
[
1
],
s
.
shape
[
2
]),
device
=
s
.
device
,
dtype
=
s
.
dtype
,
)
for
e
in
range
(
num_experts
):
output
[
e
]
=
marlin_permute_scales
(
s
[
e
],
size_k
,
size_n
,
group_size
)
return
output
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
def
marlin_zero_points
(
zp
:
torch
.
Tensor
,
size_k
:
int
,
size_n
:
int
,
num_bits
:
int
)
->
torch
.
Tensor
:
num_bits
:
int
)
->
torch
.
Tensor
:
# Permute zero-points in a similar way to scales, but do not use the
# Permute zero-points in a similar way to scales, but do not use the
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py
View file @
4851c202
"""Utility functions used for tests and benchmarks"""
"""Utility functions used for tests and benchmarks"""
from
typing
import
List
from
typing
import
List
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
...
@@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int):
return
perm
return
perm
def
marlin_quantize
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
group_size
:
int
,
def
marlin_quantize
(
w
:
torch
.
Tensor
,
act_order
:
bool
):
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
size_k
,
size_n
=
w
.
shape
size_k
,
size_n
=
w
.
shape
num_bits
=
quant_type
.
size_bits
num_bits
=
quant_type
.
size_bits
...
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
...
@@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int,
# Quantize (and apply act_order if provided)
# Quantize (and apply act_order if provided)
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
w_ref
,
q_w
,
s
,
g_idx
,
rand_perm
=
gptq_quantize_weights
(
w
,
quant_type
,
group_size
,
act_order
)
w
,
quant_type
,
group_size
,
act_order
,
test_perm
)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
# increasing
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
4851c202
"""This file is used for /tests and /benchmarks"""
"""This file is used for /tests and /benchmarks"""
from
typing
import
List
from
typing
import
List
,
Optional
import
numpy
import
numpy
import
torch
import
torch
...
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
...
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
return
32
//
num_bits
return
32
//
num_bits
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
):
def
permute_rows
(
q_w
:
torch
.
Tensor
,
w_ref
:
torch
.
Tensor
,
group_size
:
int
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
assert
q_w
.
shape
==
w_ref
.
shape
assert
q_w
.
shape
==
w_ref
.
shape
orig_device
=
q_w
.
device
orig_device
=
q_w
.
device
...
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
...
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx
[
i
]
=
i
//
group_size
g_idx
[
i
]
=
i
//
group_size
# Simulate act_order by doing a random permutation on K
# Simulate act_order by doing a random permutation on K
rand_perm
=
torch
.
randperm
(
k_size
)
rand_perm
=
test_perm
if
test_perm
is
not
None
else
torch
.
randperm
(
k_size
)
g_idx
=
g_idx
[
rand_perm
].
contiguous
()
g_idx
=
g_idx
[
rand_perm
].
contiguous
()
q_w
=
q_w
[
rand_perm
,
:].
contiguous
()
q_w
=
q_w
[
rand_perm
,
:].
contiguous
()
...
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
...
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
)
)
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
quant_type
:
ScalarType
,
def
gptq_quantize_weights
(
w
:
torch
.
Tensor
,
group_size
:
int
,
act_order
:
bool
):
quant_type
:
ScalarType
,
group_size
:
int
,
act_order
:
bool
,
test_perm
:
Optional
[
torch
.
Tensor
]
=
None
):
size_k
,
_
=
w
.
shape
size_k
,
_
=
w
.
shape
assert
w
.
is_floating_point
(),
"w must be float"
assert
w
.
is_floating_point
(),
"w must be float"
...
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
...
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
),
"For act_order, groupsize = {} must be less than size_k = {}"
.
format
(
group_size
,
size_k
)
group_size
,
size_k
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
)
w_ref
,
w_q
,
g_idx
,
rand_perm
=
permute_rows
(
w_q
,
w_ref
,
group_size
,
test_perm
)
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
return
w_ref
,
w_q
,
w_s
,
g_idx
,
rand_perm
...
...
vllm/model_executor/layers/resampler.py
0 → 100644
View file @
4851c202
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
#
# Copyright 2023 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Shared resampler perceiver network used in multimodal models and
related helpers for sincos positional embeddings.
Example models: Qwen (Qwen-VL), Minicpmv2.0
"""
import
math
from
functools
import
partial
from
typing
import
Callable
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn.init
import
trunc_normal_
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
def
get_abs_pos
(
abs_pos
:
torch
.
Tensor
,
tgt_size
:
Union
[
torch
.
Tensor
,
int
])
->
torch
.
Tensor
:
# abs_pos: L, C
# tgt_size: (H, W)
# return: M, C
src_size
=
int
(
math
.
sqrt
(
abs_pos
.
size
(
0
)))
dtype
=
abs_pos
.
dtype
if
isinstance
(
tgt_size
,
int
):
tgt_size
=
(
tgt_size
,
tgt_size
)
if
(
src_size
==
tgt_size
[
0
]
and
src_size
==
tgt_size
[
1
]):
return
abs_pos
return
(
F
.
interpolate
(
abs_pos
.
float
().
reshape
(
1
,
src_size
,
src_size
,
-
1
).
permute
(
0
,
3
,
1
,
2
),
size
=
(
tgt_size
[
0
],
tgt_size
[
1
]),
mode
=
"bicubic"
,
align_corners
=
False
,
).
permute
(
0
,
2
,
3
,
1
).
flatten
(
0
,
2
).
to
(
dtype
=
dtype
))
# sin/cos positional embedding helpers are adapted from:
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def
get_1d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
pos
:
np
.
ndarray
,
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
))
->
torch
.
Tensor
:
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert
embed_dim
%
2
==
0
omega
=
np
.
arange
(
embed_dim
//
2
,
dtype
=
np
.
float32
)
omega
/=
embed_dim
/
2.0
omega
=
1.0
/
10000
**
omega
# (D/2,)
if
version
==
(
2
,
0
):
pos
=
pos
.
reshape
(
-
1
)
# (M,)
out
=
np
.
einsum
(
"m,d->md"
,
pos
,
omega
)
# (M, D/2), outer product
emb_sin
=
np
.
sin
(
out
)
# (M, D/2)
emb_cos
=
np
.
cos
(
out
)
# (M, D/2)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=
1
)
# (M, D)
else
:
out
=
np
.
einsum
(
"hw,d->hwd"
,
pos
,
omega
)
# (H, W, D/2), outer product
emb_sin
=
np
.
sin
(
out
)
# (H, W, D/2)
emb_cos
=
np
.
cos
(
out
)
# (H, W, D/2)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=-
1
)
# (H, W, D)
return
emb
def
get_2d_sincos_pos_embed_from_grid
(
embed_dim
:
int
,
grid
:
np
.
ndarray
,
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
))
->
torch
.
Tensor
:
assert
embed_dim
%
2
==
0
# use half of dimensions to encode grid_h
emb_h
=
get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
0
],
version
)
# (H*W, D/2) or (H, W, D/2)
emb_w
=
get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
1
],
version
)
# (H*W, D/2) or (H, W, D/2)
if
version
==
(
2
,
0
):
emb
=
np
.
concatenate
([
emb_h
,
emb_w
],
axis
=
1
)
# (H*W, D)
else
:
emb
=
np
.
concatenate
([
emb_h
,
emb_w
],
axis
=-
1
)
# (H, W, D)
return
emb
def
get_2d_sincos_pos_embed
(
embed_dim
:
int
,
grid_size
:
Union
[
int
,
Tuple
[
int
,
int
]],
cls_token
:
bool
=
False
,
version
:
Tuple
[
int
,
int
]
=
(
2
,
0
),
)
->
torch
.
Tensor
:
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if
isinstance
(
grid_size
,
int
):
grid_h_size
,
grid_w_size
=
grid_size
,
grid_size
else
:
grid_h_size
,
grid_w_size
=
grid_size
[
0
],
grid_size
[
1
]
grid_h
=
np
.
arange
(
grid_h_size
,
dtype
=
np
.
float32
)
grid_w
=
np
.
arange
(
grid_w_size
,
dtype
=
np
.
float32
)
grid
=
np
.
meshgrid
(
grid_w
,
grid_h
)
# here w goes first
grid
=
np
.
stack
(
grid
,
axis
=
0
)
assert
isinstance
(
grid
,
np
.
ndarray
)
and
\
grid
.
shape
==
(
2
,
grid_h_size
,
grid_w_size
)
if
version
==
(
2
,
0
):
grid
=
grid
.
reshape
([
2
,
1
,
grid_h_size
,
grid_w_size
])
pos_embed
=
get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
,
version
)
if
cls_token
:
pos_embed
=
np
.
concatenate
([
np
.
zeros
([
1
,
embed_dim
]),
pos_embed
],
axis
=
0
)
else
:
pos_embed
=
get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
,
version
)
return
pos_embed
class
BaseResampler
(
nn
.
Module
):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb.
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def
__init__
(
self
,
num_queries
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
do_post_projection
:
bool
=
True
,
)
->
None
:
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
trunc_normal_
(
self
.
query
,
std
=
0.02
)
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
self
.
kv_proj
=
ReplicatedLinear
(
kv_dim
,
embed_dim
,
bias
=
False
)
else
:
# Maintain the same return value with ReplicatedLinear.forward
self
.
kv_proj
=
lambda
*
args
,
**
kwargs
:
(
# type: ignore # noqa
nn
.
Identity
()(
*
args
,
**
kwargs
),
None
,
)
self
.
attn
=
nn
.
MultiheadAttention
(
embed_dim
,
num_heads
)
self
.
ln_q
=
norm_layer
(
embed_dim
)
self
.
ln_kv
=
norm_layer
(
embed_dim
)
self
.
do_post_projection
=
do_post_projection
self
.
ln_post
=
norm_layer
(
embed_dim
)
if
do_post_projection
else
None
self
.
proj
=
nn
.
Parameter
(
(
embed_dim
**-
0.5
)
*
torch
.
randn
(
embed_dim
,
embed_dim
))
if
do_post_projection
else
None
def
_init_weights
(
self
,
m
:
nn
.
Module
)
->
None
:
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
0.02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
_repeat
(
self
,
query
,
N
:
int
):
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
class
Resampler2
(
BaseResampler
):
"""Resampler-perceiver network to be used for a variety of model types,
e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the
do_post_projection arg, which indicates whether or not there should be
a post layer normalization and projector after the attention. This is
present in minicpmv2.0, but not qwen-vl.
"""
def
__init__
(
self
,
grid_size
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
adaptive
:
bool
=
False
,
do_post_projection
:
bool
=
True
,
)
->
None
:
super
().
__init__
(
grid_size
**
2
,
embed_dim
,
num_heads
,
kv_dim
,
norm_layer
,
do_post_projection
=
do_post_projection
)
self
.
adaptive
=
adaptive
pos_embed_arr
=
get_2d_sincos_pos_embed
(
embed_dim
,
grid_size
,
version
=
(
2
,
0
))
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
from_numpy
(
pos_embed_arr
).
requires_grad_
(
False
))
self
.
apply
(
self
.
_init_weights
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
tgt_sizes
:
Optional
[
torch
.
Tensor
]
=
None
,
attn_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
tgt_sizes
is
None
:
tgt_sizes
=
int
(
math
.
sqrt
(
x
.
size
(
1
)))
if
self
.
adaptive
:
pos_embed_arr
=
get_2d_sincos_pos_embed
(
self
.
embed_dim
,
tgt_sizes
,
version
=
(
2
,
0
))
pos_embed
=
torch
.
from_numpy
(
pos_embed_arr
).
to
(
device
=
x
.
device
,
dtype
=
x
.
dtype
)
else
:
pos_embed
=
get_abs_pos
(
self
.
pos_embed
,
tgt_sizes
).
to
(
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x
,
_
=
self
.
kv_proj
(
x
)
x
=
self
.
ln_kv
(
x
).
permute
(
1
,
0
,
2
)
N
=
x
.
shape
[
1
]
q
=
self
.
ln_q
(
self
.
query
)
out
=
self
.
attn
(
self
.
_repeat
(
q
,
N
)
+
self
.
pos_embed
.
unsqueeze
(
1
),
x
+
pos_embed
.
unsqueeze
(
1
),
x
,
attn_mask
=
attn_mask
,
)[
0
]
x
=
out
.
permute
(
1
,
0
,
2
)
if
self
.
do_post_projection
:
x
=
self
.
ln_post
(
x
)
x
=
x
@
self
.
proj
return
x
vllm/model_executor/layers/rotary_embedding.py
View file @
4851c202
...
@@ -28,7 +28,6 @@ import torch
...
@@ -28,7 +28,6 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -48,21 +47,29 @@ def _apply_rotary_emb(
...
@@ -48,21 +47,29 @@ def _apply_rotary_emb(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
is_neox_style
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
x: [num_tokens, num_heads, head_size]
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
"""
orig_dtype
=
x
.
dtype
cos
=
cos
.
unsqueeze
(
-
2
).
to
(
x
.
dtype
)
x
=
x
.
float
()
sin
=
sin
.
unsqueeze
(
-
2
).
to
(
x
.
dtype
)
x1
,
x2
=
torch
.
chunk
(
x
,
2
,
dim
=-
1
)
if
is_neox_style
:
cos
=
cos
.
unsqueeze
(
-
2
)
x1
,
x2
=
torch
.
chunk
(
x
,
2
,
dim
=-
1
)
sin
=
sin
.
unsqueeze
(
-
2
)
else
:
x1
=
x
[...,
::
2
]
x2
=
x
[...,
1
::
2
]
o1
=
x1
*
cos
-
x2
*
sin
o1
=
x1
*
cos
-
x2
*
sin
o2
=
x2
*
cos
+
x1
*
sin
o2
=
x2
*
cos
+
x1
*
sin
return
torch
.
cat
((
o1
,
o2
),
dim
=-
1
).
to
(
orig_dtype
)
if
is_neox_style
:
return
torch
.
cat
((
o1
,
o2
),
dim
=-
1
)
else
:
return
torch
.
stack
((
o1
,
o2
),
dim
=-
1
).
flatten
(
-
2
)
class
RotaryEmbedding
(
CustomOp
):
class
RotaryEmbedding
(
CustomOp
):
...
@@ -87,10 +94,9 @@ class RotaryEmbedding(CustomOp):
...
@@ -87,10 +94,9 @@ class RotaryEmbedding(CustomOp):
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
dtype
)
cache
=
cache
.
to
(
dtype
)
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
use_native2
=
current_platform
.
is_tpu
()
and
is_neox_style
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# NOTE(woosuk): To exactly match the HF implementation, we need to
...
@@ -119,59 +125,7 @@ class RotaryEmbedding(CustomOp):
...
@@ -119,59 +125,7 @@ class RotaryEmbedding(CustomOp):
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""A PyTorch-native implementation equivalent to forward().
"""A PyTorch-native implementation of forward()."""
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
query
.
dtype
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
self
.
is_neox_style
:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
rotate_fn
=
_rotate_neox
if
self
.
is_neox_style
else
_rotate_gptj
query_rot
=
query_rot
*
cos
+
rotate_fn
(
query_rot
)
*
sin
key_rot
=
key_rot
*
cos
+
rotate_fn
(
key_rot
)
*
sin
if
self
.
rotary_dim
<
self
.
head_size
:
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
else
:
query
=
query_rot
key
=
key_rot
query
=
query
.
flatten
(
-
2
)
key
=
key
.
flatten
(
-
2
)
return
query
,
key
def
forward_native2
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Another PyTorch-native implementation of forward().
This method might perform better than `forward_native()` when compiled.
"""
if
offsets
is
not
None
:
if
offsets
is
not
None
:
positions
=
positions
+
offsets
positions
=
positions
+
offsets
positions
=
positions
.
flatten
()
positions
=
positions
.
flatten
()
...
@@ -183,14 +137,14 @@ class RotaryEmbedding(CustomOp):
...
@@ -183,14 +137,14 @@ class RotaryEmbedding(CustomOp):
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
)
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
)
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
return
query
,
key
...
@@ -203,7 +157,7 @@ class RotaryEmbedding(CustomOp):
...
@@ -203,7 +157,7 @@ class RotaryEmbedding(CustomOp):
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
# are in-place operations that update the query and key tensors.
...
@@ -240,17 +194,6 @@ class RotaryEmbedding(CustomOp):
...
@@ -240,17 +194,6 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def
forward_tpu
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_fn
=
(
self
.
forward_native2
if
self
.
use_native2
else
self
.
forward_native
)
return
forward_fn
(
positions
,
query
,
key
,
offsets
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
=
f
"head_size=
{
self
.
head_size
}
, rotary_dim=
{
self
.
rotary_dim
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
s
+=
f
", max_position_embeddings=
{
self
.
max_position_embeddings
}
"
...
@@ -769,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
...
@@ -769,6 +712,179 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return
new_freqs
return
new_freqs
class
MRotaryEmbedding
(
RotaryEmbedding
):
"""Rotary Embedding with Multimodal Sections."""
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
mrope_section
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
self
.
mrope_section
=
mrope_section
if
self
.
mrope_section
:
assert
sum
(
self
.
mrope_section
)
==
rotary_dim
//
2
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
num_tokens
=
positions
.
shape
[
-
1
]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
positions
.
ndim
==
2
:
assert
self
.
mrope_section
cos
=
torch
.
cat
([
m
[
i
]
for
i
,
m
in
enumerate
(
cos
.
split
(
self
.
mrope_section
,
dim
=-
1
))
],
dim
=-
1
)
sin
=
torch
.
cat
([
m
[
i
]
for
i
,
m
in
enumerate
(
sin
.
split
(
self
.
mrope_section
,
dim
=-
1
))
],
dim
=-
1
)
query_shape
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
@
staticmethod
def
get_input_positions
(
input_tokens
:
List
[
int
],
image_grid_thw
:
Union
[
List
[
List
[
int
]],
torch
.
Tensor
],
video_grid_thw
:
Union
[
List
[
List
[
int
]],
torch
.
Tensor
],
image_token_id
:
int
,
video_token_id
:
int
,
vision_start_token_id
:
int
,
vision_end_token_id
:
int
,
spatial_merge_size
:
int
,
context_len
:
int
=
0
,
)
->
Tuple
[
List
[
List
[
int
]],
int
]:
"""Get mrope input positions and delta value."""
if
isinstance
(
image_grid_thw
,
torch
.
Tensor
):
image_grid_thw
=
image_grid_thw
.
tolist
()
if
isinstance
(
video_grid_thw
,
torch
.
Tensor
):
video_grid_thw
=
video_grid_thw
.
tolist
()
input_tokens_tensor
=
torch
.
tensor
(
input_tokens
)
vision_start_indices
=
torch
.
argwhere
(
input_tokens_tensor
==
vision_start_token_id
).
squeeze
(
1
)
vision_tokens
=
input_tokens_tensor
[
vision_start_indices
+
1
]
image_nums
=
(
vision_tokens
==
image_token_id
).
sum
()
video_nums
=
(
vision_tokens
==
video_token_id
).
sum
()
llm_pos_ids_list
:
list
=
[]
st
=
0
remain_images
,
remain_videos
=
image_nums
,
video_nums
image_index
,
video_index
=
0
,
0
for
_
in
range
(
image_nums
+
video_nums
):
if
image_token_id
in
input_tokens
and
remain_images
>
0
:
ed_image
=
input_tokens
.
index
(
image_token_id
,
st
)
else
:
ed_image
=
len
(
input_tokens
)
+
1
if
video_token_id
in
input_tokens
and
remain_videos
>
0
:
ed_video
=
input_tokens
.
index
(
video_token_id
,
st
)
else
:
ed_video
=
len
(
input_tokens
)
+
1
if
ed_image
<
ed_video
:
t
,
h
,
w
=
(
image_grid_thw
[
image_index
][
0
],
image_grid_thw
[
image_index
][
1
],
image_grid_thw
[
image_index
][
2
],
)
image_index
+=
1
remain_images
-=
1
ed
=
ed_image
else
:
t
,
h
,
w
=
(
video_grid_thw
[
video_index
][
0
],
video_grid_thw
[
video_index
][
1
],
video_grid_thw
[
video_index
][
2
],
)
video_index
+=
1
remain_videos
-=
1
ed
=
ed_video
llm_grid_t
,
llm_grid_h
,
llm_grid_w
=
\
t
,
h
//
spatial_merge_size
,
w
//
spatial_merge_size
text_len
=
ed
-
st
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
w_index
=
torch
.
arange
(
llm_grid_w
).
view
(
1
,
1
,
-
1
).
expand
(
llm_grid_t
,
llm_grid_h
,
-
1
).
flatten
()
llm_pos_ids_list
.
append
(
torch
.
stack
([
t_index
,
h_index
,
w_index
])
+
text_len
+
st_idx
)
st
=
ed
+
llm_grid_t
*
llm_grid_h
*
llm_grid_w
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
len
(
llm_pos_ids_list
)
>
0
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
llm_positions
=
llm_positions
[:,
context_len
:]
mrope_position_delta
=
(
llm_positions
.
max
()
+
1
-
len
(
input_tokens
)).
item
()
return
llm_positions
.
tolist
(),
mrope_position_delta
@
staticmethod
def
get_next_input_positions
(
mrope_position_delta
:
int
,
context_len
:
int
,
seq_len
:
int
,
)
->
List
[
List
[
int
]]:
return
[
list
(
range
(
context_len
+
mrope_position_delta
,
seq_len
+
mrope_position_delta
))
for
_
in
range
(
3
)
]
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
_ROPE_DICT
:
Dict
[
Tuple
,
RotaryEmbedding
]
=
{}
...
@@ -809,7 +925,7 @@ def get_rope(
...
@@ -809,7 +925,7 @@ def get_rope(
# The correct one should be "longrope" but keep "su" here
# The correct one should be "longrope" but keep "su" here
# for backward compatible
# for backward compatible
if
scaling_type
not
in
{
"su"
,
"longrope"
}:
if
scaling_type
not
in
{
"su"
,
"longrope"
}:
scaling_factor
=
rope_scaling
[
"factor"
]
scaling_factor
=
rope_scaling
.
get
(
"factor"
,
1.0
)
if
scaling_type
==
"llama3"
:
if
scaling_type
==
"llama3"
:
low_freq_factor
=
rope_scaling
[
"low_freq_factor"
]
low_freq_factor
=
rope_scaling
[
"low_freq_factor"
]
high_freq_factor
=
rope_scaling
[
"high_freq_factor"
]
high_freq_factor
=
rope_scaling
[
"high_freq_factor"
]
...
@@ -873,6 +989,16 @@ def get_rope(
...
@@ -873,6 +989,16 @@ def get_rope(
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
head_size
,
rotary_dim
,
max_position
,
original_max_position
,
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
base
,
is_neox_style
,
dtype
,
short_factor
,
long_factor
,
**
extra_kwargs
)
**
extra_kwargs
)
elif
scaling_type
==
"mrope"
:
return
MRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
dtype
,
mrope_section
=
rope_scaling
[
"mrope_section"
],
)
else
:
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
_ROPE_DICT
[
key
]
=
rotary_emb
_ROPE_DICT
[
key
]
=
rotary_emb
...
...
vllm/model_executor/model_loader/loader.py
View file @
4851c202
...
@@ -17,6 +17,7 @@ import torch
...
@@ -17,6 +17,7 @@ import torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers
import
AutoModelForCausalLM
,
PretrainedConfig
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
...
@@ -94,8 +95,9 @@ def _get_quantization_config(
...
@@ -94,8 +95,9 @@ def _get_quantization_config(
"""Get the quantization config."""
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
quant_config
=
get_quant_config
(
model_config
,
load_config
)
if
not
current_platform
.
is_tpu
():
capability
=
current_platform
.
get_device_capability
()
# type: ignore
capability
=
current_platform
.
get_device_capability
()
if
capability
is
not
None
:
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
quant_config
.
get_min_capability
():
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
raise
ValueError
(
...
@@ -187,6 +189,11 @@ class BaseModelLoader(ABC):
...
@@ -187,6 +189,11 @@ class BaseModelLoader(ABC):
def
__init__
(
self
,
load_config
:
LoadConfig
):
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
self
.
load_config
=
load_config
@
abstractmethod
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download a model so that it can be immediately loaded."""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
...
@@ -195,7 +202,7 @@ class BaseModelLoader(ABC):
...
@@ -195,7 +202,7 @@ class BaseModelLoader(ABC):
scheduler_config
:
SchedulerConfig
,
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
cache_config
:
CacheConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
"""Load a model with the given configurations."""
...
raise
NotImplementedError
class
DefaultModelLoader
(
BaseModelLoader
):
class
DefaultModelLoader
(
BaseModelLoader
):
...
@@ -244,12 +251,17 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -244,12 +251,17 @@ class DefaultModelLoader(BaseModelLoader):
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
load_format
=
self
.
load_config
.
load_format
use_safetensors
=
False
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
# Some quantized models use .pt files for storing the weights.
if
load_format
==
LoadFormat
.
AUTO
:
if
load_format
==
LoadFormat
.
AUTO
:
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
elif
load_format
==
LoadFormat
.
SAFETENSORS
:
use_safetensors
=
True
use_safetensors
=
True
allow_patterns
=
[
"*.safetensors"
]
allow_patterns
=
[
"*.safetensors"
]
elif
load_format
==
LoadFormat
.
MISTRAL
:
use_safetensors
=
True
allow_patterns
=
[
"consolidated*.safetensors"
]
index_file
=
"consolidated.safetensors.index.json"
elif
load_format
==
LoadFormat
.
PT
:
elif
load_format
==
LoadFormat
.
PT
:
allow_patterns
=
[
"*.pt"
]
allow_patterns
=
[
"*.pt"
]
elif
load_format
==
LoadFormat
.
NPCACHE
:
elif
load_format
==
LoadFormat
.
NPCACHE
:
...
@@ -287,10 +299,10 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -287,10 +299,10 @@ class DefaultModelLoader(BaseModelLoader):
# any files not found in the index.
# any files not found in the index.
if
not
is_local
:
if
not
is_local
:
download_safetensors_index_file_from_hf
(
download_safetensors_index_file_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
model_name_or_path
,
index_file
,
revision
)
self
.
load_config
.
download_dir
,
revision
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
)
hf_weights_files
,
hf_folder
,
index_file
)
else
:
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
hf_weights_files
)
...
@@ -332,6 +344,11 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -332,6 +344,11 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
return
weights_iterator
return
weights_iterator
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
,
fall_back_to_pt
=
True
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
@@ -374,6 +391,9 @@ class DummyModelLoader(BaseModelLoader):
...
@@ -374,6 +391,9 @@ class DummyModelLoader(BaseModelLoader):
raise
ValueError
(
f
"Model loader extra config is not supported for "
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
f
"load format
{
load_config
.
load_format
}
"
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
@@ -464,6 +484,12 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -464,6 +484,12 @@ class TensorizerLoader(BaseModelLoader):
model
=
load_with_tensorizer
(
tensorizer_config
,
**
extra_kwargs
)
model
=
load_with_tensorizer
(
tensorizer_config
,
**
extra_kwargs
)
return
model
.
eval
()
return
model
.
eval
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
with
self
.
tensorizer_config
.
open_stream
():
pass
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
@@ -565,6 +591,9 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -565,6 +591,9 @@ class ShardedStateLoader(BaseModelLoader):
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
@@ -992,6 +1021,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -992,6 +1021,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
set_weight_attrs
(
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
@@ -1067,6 +1099,9 @@ class GGUFModelLoader(BaseModelLoader):
...
@@ -1067,6 +1099,9 @@ class GGUFModelLoader(BaseModelLoader):
return
gguf_quant_weights_iterator
(
model_name_or_path
,
return
gguf_quant_weights_iterator
(
model_name_or_path
,
gguf_to_hf_name_map
)
gguf_to_hf_name_map
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
4851c202
...
@@ -99,6 +99,13 @@ class TensorizerConfig:
...
@@ -99,6 +99,13 @@ class TensorizerConfig:
"Loading a model using Tensorizer with quantization on vLLM"
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors."
)
" is unstable and may lead to errors."
)
def
open_stream
(
self
,
tensorizer_args
:
Optional
[
"TensorizerArgs"
]
=
None
):
if
tensorizer_args
is
None
:
tensorizer_args
=
self
.
_construct_tensorizer_args
()
return
open_stream
(
self
.
tensorizer_uri
,
**
tensorizer_args
.
stream_params
)
def
load_with_tensorizer
(
tensorizer_config
:
TensorizerConfig
,
def
load_with_tensorizer
(
tensorizer_config
:
TensorizerConfig
,
**
extra_kwargs
)
->
nn
.
Module
:
**
extra_kwargs
)
->
nn
.
Module
:
...
...
vllm/model_executor/model_loader/utils.py
View file @
4851c202
...
@@ -43,10 +43,18 @@ def get_model_architecture(
...
@@ -43,10 +43,18 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
]
# for gptq_marlin, only run fused MoE for int4
if
model_config
.
quantization
==
"gptq_marlin"
:
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
and
hf_quant_config
.
get
(
"bits"
)
==
4
:
mixtral_supported
.
append
(
"gptq_marlin"
)
if
(
model_config
.
quantization
is
not
None
if
(
model_config
.
quantization
is
not
None
and
model_config
.
quantization
not
in
mixtral_supported
and
model_config
.
quantization
not
in
mixtral_supported
and
"MixtralForCausalLM"
in
architectures
):
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
architectures
=
[
"QuantMixtralForCausalLM"
]
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
return
ModelRegistry
.
resolve_model_cls
(
architectures
)
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
4851c202
...
@@ -16,7 +16,6 @@ import torch
...
@@ -16,7 +16,6 @@ import torch
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
huggingface_hub
import
HfFileSystem
,
hf_hub_download
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.config
import
LoadConfig
,
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
...
@@ -193,6 +192,13 @@ def get_quant_config(model_config: ModelConfig,
...
@@ -193,6 +192,13 @@ def get_quant_config(model_config: ModelConfig,
if
model_config
.
quantization
==
"bitsandbytes"
:
if
model_config
.
quantization
==
"bitsandbytes"
:
config
[
"adapter_name_or_path"
]
=
model_name_or_path
config
[
"adapter_name_or_path"
]
=
model_name_or_path
elif
model_config
.
quantization
==
"modelopt"
:
if
config
[
"producer"
][
"name"
]
==
"modelopt"
:
return
quant_cls
.
from_config
(
config
)
else
:
raise
ValueError
(
f
"Unsupported quantization config"
f
" found for
{
model_config
.
quantization
}
in
{
f
}
."
)
return
quant_cls
.
from_config
(
config
)
return
quant_cls
.
from_config
(
config
)
...
@@ -251,6 +257,7 @@ def download_weights_from_hf(
...
@@ -251,6 +257,7 @@ def download_weights_from_hf(
def
download_safetensors_index_file_from_hf
(
def
download_safetensors_index_file_from_hf
(
model_name_or_path
:
str
,
model_name_or_path
:
str
,
index_file
:
str
,
cache_dir
:
Optional
[
str
],
cache_dir
:
Optional
[
str
],
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -269,36 +276,37 @@ def download_safetensors_index_file_from_hf(
...
@@ -269,36 +276,37 @@ def download_safetensors_index_file_from_hf(
# Download the safetensors index file.
# Download the safetensors index file.
hf_hub_download
(
hf_hub_download
(
repo_id
=
model_name_or_path
,
repo_id
=
model_name_or_path
,
filename
=
SAFE_WEIGHTS_INDEX_NAME
,
filename
=
index_file
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
revision
=
revision
,
revision
=
revision
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
local_files_only
=
huggingface_hub
.
constants
.
HF_HUB_OFFLINE
,
)
)
# If file not found on remote or locally, we should not fail since
# If file not found on remote or locally, we should not fail since
# only some models will have
SAFE_WEIGHTS_INDEX_NAME
.
# only some models will have
index_file
.
except
huggingface_hub
.
utils
.
EntryNotFoundError
:
except
huggingface_hub
.
utils
.
EntryNotFoundError
:
logger
.
info
(
"No %s found in remote."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in remote."
,
index_file
)
except
huggingface_hub
.
utils
.
LocalEntryNotFoundError
:
except
huggingface_hub
.
utils
.
LocalEntryNotFoundError
:
logger
.
info
(
"No %s found in local cache."
,
SAFE_WEIGHTS_INDEX_NAME
)
logger
.
info
(
"No %s found in local cache."
,
index_file
)
# For models like Mistral-7B-v0.3, there are both sharded
# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# Passing both of these to the weight loader functionality breaks.
# So, we use the
SAFE_WEIGHTS_INDEX_NAME
to
# So, we use the
index_file
to
# look up which safetensors files should be used.
# look up which safetensors files should be used.
def
filter_duplicate_safetensors_files
(
hf_weights_files
:
List
[
str
],
def
filter_duplicate_safetensors_files
(
hf_weights_files
:
List
[
str
],
hf_folder
:
str
)
->
List
[
str
]:
hf_folder
:
str
,
index_file
:
str
)
->
List
[
str
]:
# model.safetensors.index.json is a mapping from keys in the
# model.safetensors.index.json is a mapping from keys in the
# torch state_dict to safetensors file holding that weight.
# torch state_dict to safetensors file holding that weight.
index_file_name
=
os
.
path
.
join
(
hf_folder
,
SAFE_WEIGHTS_INDEX_NAME
)
index_file_name
=
os
.
path
.
join
(
hf_folder
,
index_file
)
if
not
os
.
path
.
isfile
(
index_file_name
):
if
not
os
.
path
.
isfile
(
index_file_name
):
return
hf_weights_files
return
hf_weights_files
# Iterate through the weight_map (weight_name: safetensors files)
# Iterate through the weight_map (weight_name: safetensors files)
# to identify weights that we should use.
# to identify weights that we should use.
with
open
(
index_file_name
)
as
index_file
:
with
open
(
index_file_name
,
"r"
)
as
f
:
weight_map
=
json
.
load
(
index_file
)[
"weight_map"
]
weight_map
=
json
.
load
(
f
)[
"weight_map"
]
weight_files_in_index
=
set
()
weight_files_in_index
=
set
()
for
weight_name
in
weight_map
:
for
weight_name
in
weight_map
:
weight_files_in_index
.
add
(
weight_files_in_index
.
add
(
...
...
vllm/model_executor/models/__init__.py
View file @
4851c202
...
@@ -51,9 +51,10 @@ _GENERATION_MODELS = {
...
@@ -51,9 +51,10 @@ _GENERATION_MODELS = {
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"PhiForCausalLM"
:
(
"phi"
,
"PhiForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Phi3ForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"PhiMoEForCausalLM"
:
(
"phimoe"
,
"PhiMoEForCausalLM"
),
"PhiMoEForCausalLM"
:
(
"phimoe"
,
"PhiMoEForCausalLM"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2ForCausalLM"
:
(
"qwen2"
,
"Qwen2ForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
"Qwen2MoeForCausalLM"
:
(
"qwen2_moe"
,
"Qwen2MoeForCausalLM"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
"RWForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"RWForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"StableLMEpochForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLMEpochForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLmForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
"StableLmForCausalLM"
:
(
"stablelm"
,
"StablelmForCausalLM"
),
...
@@ -81,13 +82,20 @@ _MULTIMODAL_MODELS = {
...
@@ -81,13 +82,20 @@ _MULTIMODAL_MODELS = {
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"InternVLChatModel"
:
(
"internvl"
,
"InternVLChatModel"
),
"LlavaForConditionalGeneration"
:
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
),
"LlavaNextVideoForConditionalGeneration"
:
(
"llava_next_video"
,
"LlavaNextVideoForConditionalGeneration"
),
"MiniCPMV"
:
(
"minicpmv"
,
"MiniCPMV"
),
"MiniCPMV"
:
(
"minicpmv"
,
"MiniCPMV"
),
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
),
"PaliGemmaForConditionalGeneration"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"PixtralForConditionalGeneration"
:
(
"pixtral"
,
"PixtralForConditionalGeneration"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
}
}
_CONDITIONAL_GENERATION_MODELS
=
{
_CONDITIONAL_GENERATION_MODELS
=
{
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
...
...
Prev
1
…
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment