Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
360bd67c
Unverified
Commit
360bd67c
authored
Aug 06, 2024
by
Isotr0py
Committed by
GitHub
Aug 05, 2024
Browse files
[Core] Support loading GGUF model (#5191)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
ef527be0
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
431 additions
and
18 deletions
+431
-18
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+25
-1
vllm/model_executor/layers/quantization/gguf.py
vllm/model_executor/layers/quantization/gguf.py
+165
-0
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+54
-5
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+92
-2
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+46
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+7
-0
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+1
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+32
-8
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+9
-1
No files found.
vllm/model_executor/layers/quantization/base_config.py
View file @
360bd67c
import
inspect
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
import
torch
from
torch
import
nn
...
...
@@ -23,6 +24,14 @@ class QuantizeMethodBase(ABC):
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
# Not required functions
def embedding(self, layer: torch.nn.Module, *args,
              **kwargs) -> torch.Tensor:
    """Gather embeddings in the layer based on indices in the input tensor.

    This is an optional hook: the base implementation only raises, so a
    quantization method has to override it to provide embedding lookups.

    Expects create_weights to have been called before on the layer.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
def
process_weights_after_loading
(
self
,
layer
:
nn
.
Module
)
->
None
:
"""Process the weight after loading.
...
...
@@ -31,6 +40,21 @@ class QuantizeMethodBase(ABC):
return
def method_has_implemented_embedding(
        method_class: Type[QuantizeMethodBase]) -> bool:
    """Tell whether ``method_class`` provides its own ``embedding``.

    Quantization methods are not required to implement ``embedding``. The
    check is done statically (no descriptor or ``__getattr__`` invocation):
    the attribute found on ``method_class`` must exist and must be a
    different object from the one defined on ``QuantizeMethodBase``.

    Args:
        method_class: The quantization method class to inspect.

    Returns:
        True if ``embedding`` was overridden, False otherwise.
    """
    inherited = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    candidate = inspect.getattr_static(method_class, "embedding", None)

    if candidate is None:
        return False
    return candidate is not inherited
class
QuantizationConfig
(
ABC
):
"""Base class for quantization configs."""
...
...
vllm/model_executor/layers/quantization/gguf.py
0 → 100644
View file @
360bd67c
from
typing
import
Any
,
Dict
,
List
,
Optional
import
gguf
import
torch
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.utils
import
set_weight_attrs
class GGUFConfig(QuantizationConfig):
    """Quantization config for GGUF checkpoints.

    The config carries no state of its own; it only selects which
    quantize method is attached to each supported layer type.
    """

    def __init__(self) -> None:
        pass

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        """Return the identifier of this quantization scheme."""
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this scheme."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this scheme."""
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config files to look for (none for GGUF)."""
        # no extra configs.
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        """Build the config; ``config`` itself is not read.

        Raises:
            ValueError: if the tensor-parallel world size is above 1.
        """
        tp_world_size = get_tensor_model_parallel_world_size()
        if tp_world_size > 1:
            raise ValueError(
                "GGUF quantization hasn't supported tensor parallelism yet.")
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get ``GGUFLinearMethod``, vocab-parallel embeddings
        get ``GGUFEmbeddingMethod``, and every other layer gets ``None``.
        ``prefix`` is accepted for interface compatibility and not used.
        """
        # LinearBase is checked first, exactly as before.
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        if isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for GGUF)."""
        return []
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                  qweight_type: int) -> torch.Tensor:
    """Multiply ``x`` by a GGUF-packed weight.

    Two paths are used depending on the GGML type id:

    * ``qweight_type < 16``: the packed weight goes straight to the
      ``ggml_mul_mat_a8`` custom op (the k-quants path).
    * ``qweight_type >= 16``: the weight is dequantized first and a
      regular matmul against its transpose is performed (the IQ-matrix
      path).

    Args:
        x: Input activations.
        qweight: Packed weight; dim 0 is the output dimension and dim 1
            holds the packed bytes of each row.
        qweight_type: GGML quantization type id of ``qweight``.

    Returns:
        The product of ``x`` and the (de)quantized weight.
    """
    out_rows = qweight.shape[0]

    if qweight_type < 16:
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type, out_rows)

    # Recover the logical column count from the packed byte width:
    # every `type_size` bytes encode `block_size` elements.
    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    out_cols = qweight.shape[1] // type_size * block_size
    dequantized = ops.ggml_dequantize(qweight, qweight_type, out_rows,
                                      out_cols)
    return x @ dequantized.T
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the ``qweight`` and ``qweight_type`` parameters.

        ``qweight`` is created uninitialized (its shape is not fixed
        here); the logical shape is recorded in its ``tensor_shape``
        attribute. ``qweight_type`` is a uint8 tensor with one slot per
        output partition. ``extra_weight_attrs`` is applied to both
        parameters after the GGUF-specific attributes.
        """
        logical_shape = (sum(output_partition_sizes),
                         input_size_per_partition)

        # Packed weight data; shard bookkeeping starts out empty.
        qweight = UninitializedParameter(requires_grad=False)
        qweight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "tensor_shape": logical_shape,
            "is_gguf_weight": True,
            "shard_size": {},
            "shard_id": [],
        }
        set_weight_attrs(qweight, qweight_attrs)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        # Quantization type id(s) of the packed weight.
        type_storage = torch.empty(len(output_partition_sizes),
                                   dtype=torch.uint8)
        qweight_type = Parameter(type_storage, requires_grad=False)
        qweight_type_attrs = {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "shard_weight_type": {},
            "ignore_warning": True,
        }
        set_weight_attrs(qweight_type, qweight_type_attrs)
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the GGUF linear layer to ``x``.

        When the weight carries shard bookkeeping (both ``shard_id`` and
        ``shard_size`` non-empty), each shard is sliced out of ``qweight``
        and multiplied with its own quantization type, and the partial
        results are concatenated along axis 1. Otherwise the whole weight
        is multiplied in one call. ``bias``, if given, is added in place
        to the result.
        """
        shard_sizes = getattr(layer.qweight, "shard_size", None)
        shard_ids = getattr(layer.qweight, "shard_id", None)

        if shard_ids and shard_sizes:
            # Shards containing "q" are always processed in q, k, v order;
            # anything else keeps the recorded order.
            ordered_ids = ["q", "k", "v"] if "q" in shard_ids else shard_ids

            partial_outputs = []
            row_start = 0
            # dequantize shard weights respectively
            for shard in ordered_ids:
                row_count = shard_sizes[shard][0]
                col_count = shard_sizes[shard][1]
                shard_weight = layer.qweight[row_start:row_start + row_count,
                                             :col_count].contiguous()
                shard_type = layer.qweight_type.shard_weight_type[shard]
                partial_outputs.append(
                    _fuse_mul_mat(x, shard_weight, shard_type))
                row_start += row_count
            out = torch.cat(partial_outputs, axis=1)
        else:
            out = _fuse_mul_mat(x, layer.qweight,
                                layer.qweight_type.weight_type)

        if bias is not None:
            out.add_(bias)
        return out
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for the indices in ``x``.

        For type ids below 2 the stored weight is passed directly to
        ``torch.embedding``. For every other type the packed rows selected
        by ``x`` are gathered first and only those rows are dequantized.

        Returns:
            A tensor of shape ``(*x.shape, hidden_size)`` on the
            dequantizing path; the ``torch.embedding`` result otherwise.
        """
        packed_weight = layer.qweight
        type_id = layer.qweight_type.weight_type

        # Every `type_size` packed bytes encode `block_size` elements.
        block_size, type_size = gguf.GGML_QUANT_SIZES[type_id]
        hidden_size = packed_weight.shape[1] // type_size * block_size

        if type_id < 2:
            return torch.embedding(packed_weight, x)

        flat_indices = x.flatten()
        packed_rows = torch.index_select(packed_weight,
                                         dim=0,
                                         index=flat_indices)
        dequantized = ops.ggml_dequantize(packed_rows, type_id, hidden_size,
                                          flat_indices.shape[0])
        return dequantized.view(*x.shape, hidden_size)
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
360bd67c
...
...
@@ -3,19 +3,46 @@ from typing import List, Optional, Sequence, Tuple
import
torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
from
vllm.model_executor.utils
import
set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE
=
64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for embedding layer.

        Registers a single ``weight`` parameter on ``layer`` with shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and
        dtype ``params_dtype``. ``extra_weight_attrs`` is attached to the
        parameter after it has been registered.
        """
        num_rows = sum(output_partition_sizes)
        storage = torch.empty(num_rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Use the layer's weight as a linear projection of ``x``."""
        return F.linear(x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        """Look up rows of the layer's weight for the indices ``input_``."""
        return F.embedding(input_, layer.weight)
def
pad_vocab_size
(
vocab_size
:
int
,
pad_to
:
int
=
DEFAULT_VOCAB_PADDING_SIZE
)
->
int
:
"""Pad the vocab size to the given value."""
...
...
@@ -199,7 +226,19 @@ class VocabParallelEmbedding(torch.nn.Module):
if
quant_config
is
not
None
:
linear_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
linear_method
=
UnquantizedEmbeddingMethod
()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer
=
type
(
self
.
__class__
)
is
VocabParallelEmbedding
linear_method_implements_embedding
=
method_has_implemented_embedding
(
type
(
linear_method
))
if
is_embedding_layer
and
not
linear_method_implements_embedding
:
raise
NotImplementedError
(
f
"The class
{
type
(
linear_method
).
__name__
}
must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod."
)
self
.
linear_method
:
QuantizeMethodBase
=
linear_method
if
params_dtype
is
None
:
...
...
@@ -306,6 +345,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
# If the parameter is a gguf weight, then load it directly.
if
getattr
(
param
,
"is_gguf_weight_type"
,
None
):
param
.
data
.
copy_
(
loaded_weight
)
param
.
weight_type
=
loaded_weight
.
item
()
return
elif
isinstance
(
param
,
UninitializedParameter
):
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if
output_dim
is
None
:
...
...
@@ -344,7 +391,8 @@ class VocabParallelEmbedding(torch.nn.Module):
else
:
masked_input
=
input_
# Get the embeddings.
output_parallel
=
F
.
embedding
(
masked_input
.
long
(),
self
.
weight
)
output_parallel
=
self
.
linear_method
.
embedding
(
self
,
masked_input
.
long
())
# Mask the output embedding.
if
self
.
tp_size
>
1
:
output_parallel
.
masked_fill_
(
input_mask
.
unsqueeze
(
-
1
),
0
)
...
...
@@ -389,6 +437,7 @@ class ParallelLMHead(VocabParallelEmbedding):
super
().
__init__
(
num_embeddings
,
embedding_dim
,
params_dtype
,
org_num_embeddings
,
padding_size
,
quant_config
,
prefix
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
num_embeddings_per_partition
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
360bd67c
...
...
@@ -10,11 +10,13 @@ from abc import ABC, abstractmethod
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
gguf
import
huggingface_hub
import
numpy
as
np
import
torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
...
...
@@ -31,8 +33,9 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
get_gguf_extra_tensor_names
,
get_quant_config
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.interfaces
import
(
has_inner_state
,
supports_lora
,
supports_vision
)
...
...
@@ -948,6 +951,90 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
model
.
eval
()
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader for checkpoints stored in the GGUF format.

    The model path must point at an existing local file. Tensor names
    found in that file are translated to Hugging Face parameter names
    before being handed to the model's ``load_weights``.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader has no tunables, so any extra config is an error.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return the path unchanged after checking it is a file.

        Raises:
            ValueError: if ``model_name_or_path`` is not an existing file.
        """
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build a mapping from GGUF tensor names to HF parameter names.

        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        The HF names come from the state dict of a dummy model created on
        the "meta" device from the HF config.

        Raises:
            RuntimeError: if the model type has no entry in
                ``gguf.MODEL_ARCH_NAMES``.
        """
        hf_config = model_config.hf_config

        # hack: ggufs have a different name than transformers
        model_type = hf_config.model_type
        if model_type == "cohere":
            model_type = "command-r"

        # First architecture whose registered name equals the model type.
        arch = next((arch_id
                     for arch_id, arch_name in gguf.MODEL_ARCH_NAMES.items()
                     if arch_name == model_type), None)
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        name_map = gguf.get_tensor_name_map(arch, hf_config.num_hidden_layers)

        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(hf_config)

        gguf_to_hf_name_map = {}
        for hf_name in dummy_model.state_dict():
            # Split "<module path>.<weight|bias|...>" on the last dot.
            module_path, param_kind = hf_name.rsplit(".", 1)
            gguf_module = name_map.get_name(module_path)
            gguf_to_hf_name_map[f"{gguf_module}.{param_kind}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return an iterator over ``(name, tensor)`` pairs of the file."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   multimodal_config: Optional[MultiModalConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Instantiate the model and load its weights from the GGUF file.

        ``parallel_config`` and ``scheduler_config`` are accepted for
        interface compatibility and are not used here.
        """
        local_model_path = self._prepare_weights(model_config.model)
        gguf_weights_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights
        absent_in_file = get_gguf_extra_tensor_names(local_model_path,
                                                     gguf_weights_map)
        if "lm_head.weight" in absent_in_file:
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, multimodal_config,
                                          cache_config)
            weights = self._get_weights_iterator(local_model_path,
                                                 gguf_weights_map)
            model.load_weights(weights)
        return model
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
...
...
@@ -966,4 +1053,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if
load_config
.
load_format
==
LoadFormat
.
BITSANDBYTES
:
return
BitsAndBytesModelLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
GGUF
:
return
GGUFModelLoader
(
load_config
)
return
DefaultModelLoader
(
load_config
)
vllm/model_executor/model_loader/weight_utils.py
View file @
360bd67c
...
...
@@ -6,9 +6,10 @@ import json
import
os
import
tempfile
from
collections
import
defaultdict
from
typing
import
Any
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
filelock
import
gguf
import
huggingface_hub.constants
import
numpy
as
np
import
torch
...
...
@@ -121,6 +122,11 @@ def get_quant_config(model_config: ModelConfig,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# GGUF doesn't have config file
if
model_config
.
quantization
==
"gguf"
:
return
quant_cls
.
from_config
({})
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
...
...
@@ -409,6 +415,45 @@ def pt_weights_iterator(
torch
.
cuda
.
empty_cache
()
def get_gguf_extra_tensor_names(
        gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
    """List HF parameter names whose GGUF tensor is absent from the file.

    Args:
        gguf_file: Path of the GGUF file to inspect.
        gguf_to_hf_name_map: Mapping from expected GGUF tensor names to
            Hugging Face parameter names.

    Returns:
        The Hugging Face names of every expected tensor that the file does
        not contain, in no particular order.
    """
    reader = gguf.GGUFReader(gguf_file)
    present_in_file = {tensor.name for tensor in reader.tensors}
    absent = set(gguf_to_hf_name_map) - present_in_file
    return [gguf_to_hf_name_map[gguf_name] for gguf_name in absent]
def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """
    Iterate over the quant weights in the model gguf files and convert
    them to torch tensors

    The file is walked twice. The first pass yields, for every tensor
    whose type is not "F32", a ``(<name with "weight" replaced by
    "qweight_type">, type-id tensor)`` pair. The second pass yields the
    tensor data itself, using "qweight" in place of "weight" in the name
    for non-"F32" tensors. All type ids are therefore produced before any
    tensor data.

    Tensors stored in the file that have no entry in
    ``gguf_to_hf_name_map`` are skipped instead of raising ``KeyError``,
    so a file may carry tensors the model does not declare.

    Args:
        gguf_file: Path of the GGUF file to read.
        gguf_to_hf_name_map: Mapping from GGUF tensor names to Hugging
            Face parameter names.

    Yields:
        ``(parameter name, tensor)`` pairs.
    """
    reader = gguf.GGUFReader(gguf_file)

    # Pass 1: quantization type ids.
    for tensor in reader.tensors:
        hf_name = gguf_to_hf_name_map.get(tensor.name)
        if hf_name is None:
            # No counterpart in the model; nothing to load for it.
            continue
        weight_type = tensor.tensor_type
        if weight_type.name != "F32":
            weight_type_name = hf_name.replace("weight", "qweight_type")
            yield weight_type_name, torch.tensor(weight_type)

    # Pass 2: tensor data.
    for tensor in reader.tensors:
        hf_name = gguf_to_hf_name_map.get(tensor.name)
        if hf_name is None:
            continue
        if tensor.tensor_type.name != "F32":
            hf_name = hf_name.replace("weight", "qweight")
        yield hf_name, torch.tensor(tensor.data)
def
kv_cache_scales_loader
(
filename
:
str
,
tp_rank
:
int
,
tp_size
:
int
,
num_hidden_layers
:
int
,
model_type
:
Optional
[
str
])
->
Iterable
[
Tuple
[
int
,
float
]]:
...
...
vllm/model_executor/models/llama.py
View file @
360bd67c
...
...
@@ -140,6 +140,7 @@ class LlamaAttention(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
input_size
=
self
.
total_num_heads
*
self
.
head_dim
,
output_size
=
hidden_size
,
...
...
@@ -148,12 +149,17 @@ class LlamaAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
is_neox_style
=
True
if
quant_config
is
not
None
and
quant_config
.
get_name
()
==
"gguf"
:
is_neox_style
=
False
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
is_neox_style
=
is_neox_style
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
...
...
@@ -279,6 +285,7 @@ class LlamaModel(nn.Module):
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
)
else
:
self
.
embed_tokens
=
PPMissingLayer
()
...
...
vllm/model_executor/models/qwen2.py
View file @
360bd67c
...
...
@@ -238,6 +238,7 @@ class Qwen2Model(nn.Module):
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
...
...
vllm/transformers_utils/config.py
View file @
360bd67c
import
contextlib
from
typing
import
Dict
,
Optional
,
Type
from
pathlib
import
Path
from
typing
import
Dict
,
Optional
,
Type
,
Union
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers.models.auto.modeling_auto
import
(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
...
...
@@ -36,18 +39,29 @@ for name, cls in _CONFIG_REGISTRY.items():
AutoConfig
.
register
(
name
,
cls
)
def
get_config
(
model
:
str
,
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
)
->
PretrainedConfig
:
def
get_config
(
model
:
Union
[
str
,
Path
],
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
rope_scaling
:
Optional
[
dict
]
=
None
,
rope_theta
:
Optional
[
float
]
=
None
,
**
kwargs
,
)
->
PretrainedConfig
:
# Separate model folder from file path for GGUF models
is_gguf
=
Path
(
model
).
is_file
()
and
Path
(
model
).
suffix
==
".gguf"
if
is_gguf
:
kwargs
[
"gguf_file"
]
=
Path
(
model
).
name
model
=
Path
(
model
).
parent
try
:
config
=
AutoConfig
.
from_pretrained
(
model
,
trust_remote_code
=
trust_remote_code
,
revision
=
revision
,
code_revision
=
code_revision
)
code_revision
=
code_revision
,
**
kwargs
)
except
ValueError
as
e
:
if
(
not
trust_remote_code
and
"requires you to execute the configuration file"
in
str
(
e
)):
...
...
@@ -64,12 +78,22 @@ def get_config(model: str,
config
=
config_class
.
from_pretrained
(
model
,
revision
=
revision
,
code_revision
=
code_revision
)
# Special architecture mapping check for GGUF models
if
is_gguf
:
if
config
.
model_type
not
in
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
:
raise
RuntimeError
(
f
"Can't get gguf config for
{
config
.
model_type
}
."
)
model_type
=
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
[
config
.
model_type
]
config
.
update
({
"architectures"
:
[
model_type
]})
for
key
,
value
in
[(
"rope_scaling"
,
rope_scaling
),
(
"rope_theta"
,
rope_theta
)]:
if
value
is
not
None
:
logger
.
info
(
"Updating %s from %r to %r"
,
key
,
getattr
(
config
,
key
,
None
),
value
)
config
.
update
({
key
:
value
})
return
config
...
...
vllm/transformers_utils/tokenizer.py
View file @
360bd67c
import
os
from
pathlib
import
Path
from
typing
import
Optional
,
Union
import
huggingface_hub
...
...
@@ -55,7 +56,7 @@ def get_cached_tokenizer(
def
get_tokenizer
(
tokenizer_name
:
str
,
tokenizer_name
:
Union
[
str
,
Path
]
,
*
args
,
tokenizer_mode
:
str
=
"auto"
,
trust_remote_code
:
bool
=
False
,
...
...
@@ -91,6 +92,13 @@ def get_tokenizer(
if
"truncation_side"
not
in
kwargs
:
kwargs
[
"truncation_side"
]
=
"left"
# Separate model folder from file path for GGUF models
is_gguf
=
Path
(
tokenizer_name
).
is_file
()
and
Path
(
tokenizer_name
).
suffix
==
".gguf"
if
is_gguf
:
kwargs
[
"gguf_file"
]
=
Path
(
tokenizer_name
).
name
tokenizer_name
=
Path
(
tokenizer_name
).
parent
try
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
tokenizer_name
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment