vllm · Commit de60a3fb (Unverified)
Authored Dec 19, 2023 by avideci; committed by GitHub on Dec 19, 2023

Added DeciLM-7b and DeciLM-7b-instruct (#2062)

Parent: 21d5daa4
Showing 5 changed files with 129 additions and 0 deletions
README.md (+1 -0)
docs/source/models/supported_models.rst (+3 -0)
tests/models/test_models.py (+1 -0)
vllm/model_executor/models/__init__.py (+1 -0)
vllm/model_executor/models/decilm.py (+123 -0)
README.md
@@ -54,6 +54,7 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
 - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
 - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
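With DeciLM in the supported list, the checkpoint can be served through vLLM's standard Python entry point. A minimal usage sketch (greedy decoding; `trust_remote_code=True` is an assumption here, since the DeciLM repo on the Hub shipped a custom config class at the time):

```python
from vllm import LLM, SamplingParams

# Load the newly supported DeciLM checkpoint. trust_remote_code is assumed
# to be required because the Hub repo carries custom configuration code.
llm = LLM(model="Deci/DeciLM-7B", trust_remote_code=True)

params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy decoding
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```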
docs/source/models/supported_models.rst
@@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`ChatGLMModel`
     - ChatGLM
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
+  * - :code:`DeciLMForCausalLM`
+    - DeciLM
+    - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
tests/models/test_models.py
@@ -8,6 +8,7 @@ MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
     "mistralai/Mistral-7B-v0.1",
+    "Deci/DeciLM-7b",
     "tiiuae/falcon-7b",
     "gpt2",
     "bigcode/tiny_starcoder_py",
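Adding `"Deci/DeciLM-7b"` to `MODELS` enrolls it in the suite's HuggingFace-vs-vLLM comparison. The real fixtures live in vLLM's test harness; what follows is only a rough sketch, under the assumption that the test checks greedy-decoding parity between the two backends:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

MODEL = "Deci/DeciLM-7b"
PROMPT = "vLLM is a high-throughput inference engine that"

# Greedy reference completion from HuggingFace transformers.
tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
hf = AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype=torch.float16, trust_remote_code=True).cuda()
ids = tok(PROMPT, return_tensors="pt").input_ids.cuda()
out = hf.generate(ids, max_new_tokens=32, do_sample=False)
hf_text = tok.decode(out[0, ids.shape[1]:], skip_special_tokens=True)

# Greedy completion from vLLM, which now routes to DeciLMForCausalLM.
llm = LLM(model=MODEL, dtype="float16", trust_remote_code=True)
vllm_text = llm.generate(
    [PROMPT],
    SamplingParams(temperature=0.0, max_tokens=32))[0].outputs[0].text

assert hf_text.strip() == vllm_text.strip()
```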
vllm/model_executor/models/__init__.py
@@ -17,6 +17,7 @@ _MODELS = {
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
+    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
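The `_MODELS` registry keys are the `architectures` strings from a checkpoint's HuggingFace config; each value names the module under `vllm/model_executor/models/` and the class inside it. A hedged sketch of the lookup this table enables (`resolve_model_cls` is illustrative, not vLLM's actual helper):

```python
import importlib

# Subset of the registry from the diff above.
_MODELS = {"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM")}

def resolve_model_cls(architecture: str):
    """Map a HF `architectures` entry to the vLLM model class."""
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)

# Deci/DeciLM-7b's config lists "DeciLMForCausalLM", so loading it now
# resolves to the class defined in the new decilm.py.
model_cls = resolve_model_cls("DeciLMForCausalLM")
```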
vllm/model_executor/models/decilm.py (new file, mode 100644)
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Optional

import torch
from transformers import PretrainedConfig

from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)


class DeciLMForCausalLM(LlamaForCausalLM):
    """
    Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
    Based on the llama executor.

    The main difference is that DeciLM uses Variable Grouped Query Attention.
    The constant number of GQA heads in the decoder is overridden with a value
    per layer.

    Usually, in the HuggingFace implementation, instead of
    "config.num_key_value_heads", we use
    "config.num_key_value_heads_per_layer[i]" which varies.

    Currently, PagedAttention does not work well with variable GQA, so we
    normalize the weights upon loading, and use uniform GQA with the max value
    instead.
    """
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        # Build the underlying Llama model with uniform GQA, using the
        # largest per-layer KV-head count for every layer.
        config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
        delattr(config, "num_key_value_heads_per_layer")
        super().__init__(config=config, linear_method=linear_method)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            # Repeat K/V projections of layers with fewer KV heads so that
            # every layer matches the uniform (max) KV-head count.
            if "k_proj" in name or "v_proj" in name:
                loaded_weight = self._degroup_weight(loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
        hidden_size = self.config.hidden_size
        head_size = self.config.hidden_size // self.config.num_attention_heads
        target_num_kv_heads = self.config.num_key_value_heads
        num_kv_heads = loaded_weight.shape[0] // head_size
        n_repeats = target_num_kv_heads / num_kv_heads
        assert n_repeats == int(n_repeats)

        n_repeats = int(n_repeats)
        loaded_weight = loaded_weight.view(num_kv_heads, head_size,
                                           hidden_size)
        loaded_weight = torch.repeat_interleave(loaded_weight,
                                                repeats=n_repeats,
                                                dim=0)
        loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
                                              hidden_size)

        return loaded_weight
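The `_degroup_weight` reshape-and-repeat is the crux of the commit: K/V projection weights of layers with fewer KV heads are repeated head-wise until every layer exposes the maximum KV-head count, so PagedAttention sees uniform GQA. A standalone toy check of that logic (all dimensions made up for illustration):

```python
import torch

hidden_size, head_size = 64, 16           # toy model: 4 attention heads
num_kv_heads, target_num_kv_heads = 2, 4  # layer has 2 KV heads; max is 4

k_proj = torch.randn(num_kv_heads * head_size, hidden_size)

# The same three steps as DeciLMForCausalLM._degroup_weight:
n_repeats = target_num_kv_heads // num_kv_heads
w = k_proj.view(num_kv_heads, head_size, hidden_size)
w = torch.repeat_interleave(w, repeats=n_repeats, dim=0)
w = w.reshape(target_num_kv_heads * head_size, hidden_size)

assert w.shape == (target_num_kv_heads * head_size, hidden_size)
# Each original KV head now appears n_repeats times, back to back.
assert torch.equal(w[:head_size], w[head_size:2 * head_size])
```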