Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dbe55885
Unverified
Commit
dbe55885
authored
Jul 18, 2024
by
Robert Shaw
Committed by
GitHub
Jul 18, 2024
Browse files
[ Misc ] non-uniform quantization via `compressed-tensors` for `Llama` (#6515)
parent
d4201e06
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
300 additions
and
90 deletions
+300
-90
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
...ta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+32
-12
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+58
-34
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+0
-1
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+118
-21
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+23
-4
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+19
-6
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+22
-8
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+15
-4
No files found.
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
0 → 100644
View file @
dbe55885
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
model_name
:
"
nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.758
-
name
:
"
exact_match,flexible-extract"
value
:
0.759
limit
:
1000
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
dbe55885
...
@@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
...
@@ -2,4 +2,5 @@ Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
vllm/model_executor/layers/fused_moe/layer.py
View file @
dbe55885
...
@@ -158,6 +158,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -158,6 +158,7 @@ class FusedMoE(torch.nn.Module):
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
...
vllm/model_executor/layers/linear.py
View file @
dbe55885
...
@@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
...
@@ -171,6 +171,8 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: If true, skip adding bias but instead return it.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -179,15 +181,19 @@ class ReplicatedLinear(LinearBase):
...
@@ -179,15 +181,19 @@ class ReplicatedLinear(LinearBase):
bias
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
)
# All the linear layer supports quant method.
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
quant_method
.
create_weights
(
self
,
[
self
.
output_size
],
self
.
input_size
,
self
.
input_size
,
[
self
.
output_size
],
self
.
output_size
,
self
.
params_dtype
)
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
,
prefix
=
prefix
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
...
@@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -239,6 +245,8 @@ class ColumnParallelLinear(LinearBase):
quant_config: Quantization configure.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -249,7 +257,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -249,7 +257,8 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
):
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
)
...
@@ -276,7 +285,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -276,7 +285,8 @@ class ColumnParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
,
prefix
=
prefix
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
torch
.
empty
(
self
.
output_size_per_partition
,
...
@@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -348,6 +358,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -357,7 +369,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -357,7 +369,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
:
bool
=
False
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
@@ -367,7 +380,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -367,7 +380,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
=
gather_output
,
gather_output
=
gather_output
,
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -487,6 +501,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -497,7 +513,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -497,7 +513,8 @@ class QKVParallelLinear(ColumnParallelLinear):
bias
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
self
.
total_num_heads
=
total_num_heads
...
@@ -529,7 +546,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -529,7 +546,8 @@ class QKVParallelLinear(ColumnParallelLinear):
gather_output
=
False
,
gather_output
=
False
,
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -688,7 +706,8 @@ class RowParallelLinear(LinearBase):
...
@@ -688,7 +706,8 @@ class RowParallelLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
Optional
[
str
]
=
None
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
)
...
@@ -706,7 +725,8 @@ class RowParallelLinear(LinearBase):
...
@@ -706,7 +725,8 @@ class RowParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
,
prefix
=
prefix
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
raise
ValueError
(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
"results can lead to incorrect results"
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
dbe55885
...
@@ -8,23 +8,25 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
...
@@ -8,23 +8,25 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensors
W4A16Sparse24
,
CompressedTensorsScheme
,
CompressedTensors
Unquantized
,
CompressedTensorsW
8A8Fp8
,
CompressedTensorsW8A8
Int
8
,
CompressedTensorsW
4A16Sparse24
,
CompressedTensorsW8A8
Fp
8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A8Int8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
find_
first_name_or_class_
mat
ch
,
QuantizationType
,
find_
matched_target
,
is_activation_quantization_for
mat
,
is_activation_quantization_format
)
should_ignore_layer
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
class
CompressedTensorsConfig
(
QuantizationConfig
):
class
CompressedTensorsConfig
(
QuantizationConfig
):
def
__init__
(
self
,
layer_quant_details
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
def
__init__
(
self
,
target_scheme_map
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
quant_format
:
str
):
quant_format
:
str
):
self
.
ignore
=
ignore
self
.
ignore
=
ignore
self
.
layer_quant_details
=
layer_quant_details
self
.
quant_format
=
quant_format
self
.
quant_format
=
quant_format
# Map from [target -> scheme]
self
.
target_scheme_map
=
target_scheme_map
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
...
@@ -51,7 +53,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -51,7 +53,7 @@ class CompressedTensorsConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
layer_quant_details
:
Dict
[
str
,
Any
]
=
dict
()
target_scheme_map
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
...
@@ -63,21 +65,21 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -63,21 +65,21 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
# quant_config and also store the details for later use.
for
key
,
quant_config
in
config
[
"config_groups"
].
items
():
for
_
,
quant_config
in
config
[
"config_groups"
].
items
():
targets
=
quant_config
.
get
(
"targets"
)
targets
=
quant_config
.
get
(
"targets"
)
for
target
in
targets
:
for
target
in
targets
:
layer_quant_details
[
target
]
=
{}
target_scheme_map
[
target
]
=
{}
layer_quant_details
[
target
][
target_scheme_map
[
target
][
"weights"
]
=
QuantizationArgs
.
parse_obj
(
"weights"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"weights"
))
quant_config
.
get
(
"weights"
))
try
:
try
:
layer_quant_details
[
target
][
target_scheme_map
[
target
][
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"input_activations"
))
quant_config
.
get
(
"input_activations"
))
except
Exception
:
except
Exception
:
layer_quant_details
[
target
][
"input_activations"
]
=
None
target_scheme_map
[
target
][
"input_activations"
]
=
None
return
cls
(
layer_quant_details
=
layer_quant_details
,
return
cls
(
target_scheme_map
=
target_scheme_map
,
ignore
=
ignore
,
ignore
=
ignore
,
quant_format
=
quant_format
)
quant_format
=
quant_format
)
...
@@ -167,7 +169,8 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -167,7 +169,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return
(
is_channel_group
and
input_quant_none
and
is_symmetric
return
(
is_channel_group
and
input_quant_none
and
is_symmetric
and
is_static
)
and
is_static
)
def
_get_schema
(
self
,
weight_quant
:
BaseModel
,
def
_get_scheme_from_parts
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
# Detect If Mixed Precision
# Detect If Mixed Precision
...
@@ -205,26 +208,47 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -205,26 +208,47 @@ class CompressedTensorsConfig(QuantizationConfig):
raise
NotImplementedError
(
raise
NotImplementedError
(
"No compressed-tensors compatible scheme was found."
)
"No compressed-tensors compatible scheme was found."
)
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
)
->
"CompressedTensorsScheme"
:
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
,
layer_name
:
Optional
[
str
]
=
None
)
->
"CompressedTensorsScheme"
:
"""
compressed-tensors supports non-uniform quantization in the following way:
layer_type_name
=
find_first_name_or_class_match
(
ignore: List of layer_names or nn.Module names to be ignored.
name
=
""
,
targets of config_groups: There can be N config_groups which each
module
=
layer
,
have a quantization scheme. Each config_group has a list of targets
targets
=
self
.
layer_quant_details
.
keys
(),
which can be a full layer_name, a regex for a layer_name, or
check_contains
=
True
)
an nn.Module name.
if
layer_type_name
is
None
:
We first check whether a layer is in the ignore group and use
raise
ValueError
(
f
"Could not matching target for layer
{
layer
}
"
)
CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the
layer
layer_quant_details
:
Dict
[
str
,
Any
]
=
self
.
layer_quant_details
.
get
(
We then detect whether a layer_name is found in any target and
layer_type_name
,
None
)
use the quantization scheme corresponding to the matched target
if
layer_quant_details
is
None
:
to select the CompressedTensorsScheme used for inference.
raise
ValueError
(
"""
f
"Could not find quantization details for
{
layer
}
."
)
scheme
=
self
.
_get_schema
(
# Check if the layer is skipped for quantization.
weight_quant
=
layer_quant_details
[
"weights"
],
# TODO (@robertgshaw2): support module names
input_quant
=
layer_quant_details
[
"input_activations"
])
if
should_ignore_layer
(
layer_name
,
ignore
=
self
.
ignore
):
return
CompressedTensorsUnquantized
()
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
targets
=
self
.
target_scheme_map
.
keys
())
# Find the quant_scheme
scheme
=
self
.
target_scheme_map
[
matched_target
]
return
self
.
_get_scheme_from_parts
(
weight_quant
=
scheme
[
"weights"
],
input_quant
=
scheme
[
"input_activations"
])
# Raise error if device does not support the scheme
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
# (e.g. fp8 needs ada lovelace)
...
@@ -250,11 +274,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -250,11 +274,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
Use the CompressedTensorsScheme associated with each layer to create
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
the necessary parameters for the layer. See LinearMethodBase for param
details
details
"""
"""
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer_name
=
extra_weight_attrs
.
get
(
"prefix"
)
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
=
layer
)
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
,
layer
_name
)
scheme
.
create_weights
(
scheme
.
create_weights
(
layer
=
layer
,
layer
=
layer
,
input_size
=
input_size
,
input_size
=
input_size
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
View file @
dbe55885
...
@@ -33,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
...
@@ -33,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
input_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
),
dtype
=
params_dtype
),
requires_grad
=
False
)
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
dbe55885
...
@@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool:
...
@@ -86,25 +86,106 @@ def is_activation_quantization_format(format: str) -> bool:
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
def
find_first_name_or_class_match
(
# fused_name: List[shard_name]
name
:
str
,
_FUSED_LAYER_NAME_MAPPING
=
{
module
:
Module
,
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
targets
:
Iterable
[
str
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
check_contains
:
bool
=
False
)
->
Optional
[
str
]:
}
def
should_ignore_layer
(
layer_name
:
Optional
[
str
],
ignore
:
Iterable
[
str
])
->
bool
:
if
layer_name
is
None
:
return
False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name
=
layer_name
.
split
(
"."
)[
-
1
]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if
proj_name
in
_FUSED_LAYER_NAME_MAPPING
:
shard_proj_names
=
_FUSED_LAYER_NAME_MAPPING
[
proj_name
]
# Convert fused_name --> [shard_names]
shard_names
=
[
layer_name
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer
=
None
for
shard_name
in
shard_names
:
should_ignore_shard
=
check_equal_or_regex_match
(
layer_name
=
shard_name
,
targets
=
ignore
)
# If shard_idx=0, set layer ignore to match shard.
if
should_ignore_layer
is
None
:
should_ignore_layer
=
should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif
should_ignore_shard
!=
should_ignore_layer
:
raise
ValueError
(
f
"Found a different quantization schemes for "
f
"
{
shard_proj_names
}
in
{
layer_name
}
. vLLM "
"requires all to use the same scheme."
)
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else
:
should_ignore_layer
=
check_equal_or_regex_match
(
layer_name
=
layer_name
,
targets
=
ignore
)
assert
should_ignore_layer
is
not
None
return
should_ignore_layer
def
check_equal_or_regex_match
(
layer_name
:
str
,
targets
:
Iterable
[
str
])
->
bool
:
"""
Checks whether a layer_name is exactly equal to any target in the list,
or is a regex match for any target that starts with 're:'.
"""
for
target
in
targets
:
if
_is_equal_or_regex_match
(
layer_name
,
target
):
return
True
return
False
def
find_matched_target
(
layer_name
:
Optional
[
str
],
module
:
Module
,
targets
:
Iterable
[
str
])
->
str
:
"""
"""
Helper function to map the quantization details listed in the config
Helper function to look up which "target" in the compressed-tensors
for a given list of targets against each model layer. First uses the
config that a layer corresponds to.
layer name to try and find a match. If no name match is found, uses
the layer class name. Returns None otherwise.
Recall that a compressed-tensors config has a concept of
config_groups, where each layer can be quantized with a different
scheme.
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
:param name: layer name
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
:param layer_name: layer name
:param module: torch.nn.Module
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
"""
return
_find_first_match
(
name
,
targets
)
or
_find_first_match
(
if
layer_name
is
None
:
module
.
__class__
.
__name__
,
targets
,
check_contains
)
layer_name
=
""
matched_target
=
(
_find_first_match
(
layer_name
,
targets
)
or
_find_first_match
(
module
.
__class__
.
__name__
,
targets
,
True
))
if
matched_target
is
None
:
raise
ValueError
(
f
"Unable to find matching target for
{
module
}
in the "
"compressed-tensors config."
)
return
matched_target
def
_find_first_match
(
value
:
str
,
def
_find_first_match
(
value
:
str
,
...
@@ -121,13 +202,29 @@ def _find_first_match(value: str,
...
@@ -121,13 +202,29 @@ def _find_first_match(value: str,
"""
"""
for
target
in
targets
:
for
target
in
targets
:
if
_is_equal_or_regex_match
(
value
,
target
,
check_contains
=
check_contains
):
return
target
return
None
def
_is_equal_or_regex_match
(
value
:
str
,
target
:
str
,
check_contains
:
bool
=
False
)
->
bool
:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if
target
.
startswith
(
"re:"
):
if
target
.
startswith
(
"re:"
):
pattern
=
target
[
3
:]
pattern
=
target
[
3
:]
if
re
.
match
(
pattern
,
value
):
if
re
.
match
(
pattern
,
value
):
return
target
return
True
elif
check_contains
:
elif
check_contains
:
if
target
.
lower
()
in
value
.
lower
():
if
target
.
lower
()
in
value
.
lower
():
return
target
return
True
elif
target
==
value
:
elif
target
==
value
:
return
target
return
True
return
Non
e
return
Fals
e
vllm/model_executor/models/gpt2.py
View file @
dbe55885
...
@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
...
@@ -51,6 +51,7 @@ class GPT2Attention(nn.Module):
config
:
GPT2Config
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
...
@@ -68,12 +69,14 @@ class GPT2Attention(nn.Module):
total_num_heads
,
total_num_heads
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_attn"
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_proj"
,
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
...
@@ -101,6 +104,7 @@ class GPT2MLP(nn.Module):
intermediate_size
:
int
,
intermediate_size
:
int
,
config
:
GPT2Config
,
config
:
GPT2Config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
...
@@ -109,12 +113,14 @@ class GPT2MLP(nn.Module):
intermediate_size
,
intermediate_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_fc"
,
)
)
self
.
c_proj
=
RowParallelLinear
(
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
intermediate_size
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.c_proj"
,
)
)
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
self
.
act
=
get_act_fn
(
config
.
activation_function
,
quant_config
,
intermediate_size
)
intermediate_size
)
...
@@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
...
@@ -133,6 +139,7 @@ class GPT2Block(nn.Module):
config
:
GPT2Config
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
hidden_size
=
config
.
hidden_size
hidden_size
=
config
.
hidden_size
...
@@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
...
@@ -140,9 +147,15 @@ class GPT2Block(nn.Module):
hidden_size
)
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
GPT2Attention
(
config
,
cache_config
,
quant_config
)
self
.
attn
=
GPT2Attention
(
config
,
cache_config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
quant_config
)
self
.
mlp
=
GPT2MLP
(
inner_dim
,
config
,
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
...
@@ -175,6 +188,7 @@ class GPT2Model(nn.Module):
config
:
GPT2Config
,
config
:
GPT2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
...
@@ -186,7 +200,9 @@ class GPT2Model(nn.Module):
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
:
GPT2Block
(
config
,
cache_config
,
quant_config
))
lambda
prefix
:
GPT2Block
(
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.h"
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
def
forward
(
...
@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -229,7 +245,10 @@ class GPT2LMHeadModel(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
GPT2Model
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
GPT2Model
(
config
,
cache_config
,
quant_config
,
prefix
=
"transformer"
)
self
.
lm_head
=
self
.
transformer
.
wte
self
.
lm_head
=
self
.
transformer
.
wte
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
...
...
vllm/model_executor/models/llama.py
View file @
dbe55885
...
@@ -62,17 +62,20 @@ class LlamaMLP(nn.Module):
...
@@ -62,17 +62,20 @@ class LlamaMLP(nn.Module):
hidden_act
:
str
,
hidden_act
:
str
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
input_size
=
hidden_size
,
input_size
=
hidden_size
,
output_sizes
=
[
intermediate_size
]
*
2
,
output_sizes
=
[
intermediate_size
]
*
2
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
)
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
output_size
=
hidden_size
,
output_size
=
hidden_size
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
"Only silu is supported for now."
)
...
@@ -99,6 +102,7 @@ class LlamaAttention(nn.Module):
...
@@ -99,6 +102,7 @@ class LlamaAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -132,12 +136,14 @@ class LlamaAttention(nn.Module):
...
@@ -132,12 +136,14 @@ class LlamaAttention(nn.Module):
total_num_kv_heads
=
self
.
total_num_kv_heads
,
total_num_kv_heads
=
self
.
total_num_kv_heads
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
hidden_size
,
output_size
=
hidden_size
,
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
...
@@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -176,6 +182,7 @@ class LlamaDecoderLayer(nn.Module):
config
:
LlamaConfig
,
config
:
LlamaConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -203,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -203,6 +210,7 @@ class LlamaDecoderLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
bias
=
attention_bias
,
bias
=
attention_bias
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
)
self
.
mlp
=
LlamaMLP
(
self
.
mlp
=
LlamaMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
...
@@ -210,6 +218,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -210,6 +218,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
prefix
=
f
"
{
prefix
}
.mlp"
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
...
@@ -253,6 +262,7 @@ class LlamaModel(nn.Module):
...
@@ -253,6 +262,7 @@ class LlamaModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -272,9 +282,11 @@ class LlamaModel(nn.Module):
...
@@ -272,9 +282,11 @@ class LlamaModel(nn.Module):
self
.
embed_tokens
=
PPMissingLayer
()
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
:
LlamaDecoderLayer
(
config
=
config
,
lambda
prefix
:
LlamaDecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
))
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
else
:
else
:
...
@@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -370,7 +382,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self
.
model
=
LlamaModel
(
config
,
self
.
model
=
LlamaModel
(
config
,
cache_config
,
cache_config
,
quant_config
,
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
,
prefix
=
"model"
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
...
...
vllm/model_executor/models/mixtral.py
View file @
dbe55885
...
@@ -67,7 +67,8 @@ class MixtralMoE(nn.Module):
...
@@ -67,7 +67,8 @@ class MixtralMoE(nn.Module):
intermediate_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
):
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -76,7 +77,8 @@ class MixtralMoE(nn.Module):
...
@@ -76,7 +77,8 @@ class MixtralMoE(nn.Module):
num_experts
,
num_experts
,
bias
=
False
,
bias
=
False
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
None
)
quant_config
=
None
,
prefix
=
f
"
{
prefix
}
.gate"
)
self
.
experts
=
FusedMoE
(
num_experts
=
num_experts
,
self
.
experts
=
FusedMoE
(
num_experts
=
num_experts
,
top_k
=
top_k
,
top_k
=
top_k
,
...
@@ -86,7 +88,8 @@ class MixtralMoE(nn.Module):
...
@@ -86,7 +88,8 @@ class MixtralMoE(nn.Module):
reduce_results
=
True
,
reduce_results
=
True
,
renormalize
=
True
,
renormalize
=
True
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
tp_size
=
tp_size
)
tp_size
=
tp_size
,
prefix
=
f
"
{
prefix
}
.experts"
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# NOTE: hidden_states can have either 1D or 2D shape.
# NOTE: hidden_states can have either 1D or 2D shape.
...
@@ -109,6 +112,7 @@ class MixtralAttention(nn.Module):
...
@@ -109,6 +112,7 @@ class MixtralAttention(nn.Module):
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -139,12 +143,14 @@ class MixtralAttention(nn.Module):
...
@@ -139,12 +143,14 @@ class MixtralAttention(nn.Module):
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
@@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -182,6 +188,7 @@ class MixtralDecoderLayer(nn.Module):
config
:
MixtralConfig
,
config
:
MixtralConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -194,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
...
@@ -194,13 +201,15 @@ class MixtralDecoderLayer(nn.Module):
num_kv_heads
=
config
.
num_key_value_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
block_sparse_moe
=
MixtralMoE
(
self
.
block_sparse_moe
=
MixtralMoE
(
num_experts
=
config
.
num_local_experts
,
num_experts
=
config
.
num_local_experts
,
top_k
=
config
.
num_experts_per_tok
,
top_k
=
config
.
num_experts_per_tok
,
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
intermediate_size
=
config
.
intermediate_size
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.block_sparse_moe"
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
...
@@ -243,6 +252,7 @@ class MixtralModel(nn.Module):
...
@@ -243,6 +252,7 @@ class MixtralModel(nn.Module):
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -258,8 +268,11 @@ class MixtralModel(nn.Module):
...
@@ -258,8 +268,11 @@ class MixtralModel(nn.Module):
)
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
:
MixtralDecoderLayer
(
config
.
num_hidden_layers
,
config
,
cache_config
,
quant_config
=
quant_config
))
lambda
prefix
:
MixtralDecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
@@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -331,7 +344,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self
.
model
=
MixtralModel
(
config
,
self
.
model
=
MixtralModel
(
config
,
cache_config
,
cache_config
,
quant_config
,
quant_config
,
lora_config
=
lora_config
)
lora_config
=
lora_config
,
prefix
=
"model"
)
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
...
...
vllm/model_executor/models/utils.py
View file @
dbe55885
from
typing
import
Callable
,
Dict
,
List
,
Tuple
from
typing
import
Dict
,
List
,
Protocol
,
Tuple
import
torch
import
torch
from
torch.func
import
functional_call
from
torch.func
import
functional_call
...
@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
...
@@ -45,6 +45,15 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
return
inputs_embeds
return
inputs_embeds
class
LayerFn
(
Protocol
):
def
__call__
(
self
,
prefix
=
""
,
)
->
torch
.
nn
.
Module
:
...
class
PPMissingLayer
(
torch
.
nn
.
Identity
):
class
PPMissingLayer
(
torch
.
nn
.
Identity
):
"""
"""
A placeholder layer for missing layers in a pipeline parallel model.
A placeholder layer for missing layers in a pipeline parallel model.
...
@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
...
@@ -119,7 +128,9 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
def
make_layers
(
def
make_layers
(
num_hidden_layers
:
int
,
layer_fn
:
Callable
[[],
torch
.
nn
.
Module
]
num_hidden_layers
:
int
,
layer_fn
:
LayerFn
,
prefix
:
str
,
)
->
Tuple
[
int
,
int
,
torch
.
nn
.
ModuleList
]:
)
->
Tuple
[
int
,
int
,
torch
.
nn
.
ModuleList
]:
"""Make a list of layers with the given layer function, taking
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
pipeline parallelism into account.
...
@@ -131,8 +142,8 @@ def make_layers(
...
@@ -131,8 +142,8 @@ def make_layers(
get_pp_group
().
world_size
)
get_pp_group
().
world_size
)
modules
=
torch
.
nn
.
ModuleList
(
modules
=
torch
.
nn
.
ModuleList
(
[
PPMissingLayer
()
for
_
in
range
(
start_layer
)]
+
[
[
PPMissingLayer
()
for
_
in
range
(
start_layer
)]
+
[
maybe_offload_to_cpu
(
layer_fn
())
maybe_offload_to_cpu
(
layer_fn
(
prefix
=
f
"
{
prefix
}
.
{
idx
}
"
))
for
_
in
range
(
start_layer
,
end_layer
)
for
idx
in
range
(
start_layer
,
end_layer
)
]
+
[
PPMissingLayer
()
for
_
in
range
(
end_layer
,
num_hidden_layers
)])
]
+
[
PPMissingLayer
()
for
_
in
range
(
end_layer
,
num_hidden_layers
)])
return
start_layer
,
end_layer
,
modules
return
start_layer
,
end_layer
,
modules
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment