Commit c8948361 (unverified), authored Jul 08, 2023 by Andre Slavescu, committed by GitHub Jul 08, 2023

[Model] Add support for GPT-J (#226)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Parent: 75beba29
Showing 10 changed files with 269 additions and 7 deletions
README.md                                      +1    -0
csrc/attention/attention_kernels.cu            +4    -4
docs/source/models/supported_models.rst        +3    -0
tests/kernels/test_attention.py                +2    -2
vllm/model_executor/layers/attention.py        +1    -1
vllm/model_executor/layers/sampler.py          +3    -0
vllm/model_executor/model_loader.py            +1    -0
vllm/model_executor/models/__init__.py         +2    -0
vllm/model_executor/models/gpt_j.py            +251  -0
vllm/model_executor/models/mpt.py              +1    -0
README.md

...
@@ -44,6 +44,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
 - LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
 - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
...
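With this entry in the README, a GPT-J checkpoint can be served like any other supported architecture. A minimal usage sketch, assuming the `LLM`/`SamplingParams` Python API of this vLLM release; the prompt and sampling values are illustrative only:

from vllm import LLM, SamplingParams

# Load the GPT-J checkpoint named in the README entry above.
llm = LLM(model="EleutherAI/gpt-j-6b")

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)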
csrc/attention/attention_kernels.cu

...
@@ -382,7 +382,7 @@ void single_query_cached_kv_attention_launcher(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
     // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
-    // 32, 160, 192, 256.
+    // 32, 160, 192.
     // case 32:
     //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
     //   break;
...
@@ -407,9 +407,9 @@ void single_query_cached_kv_attention_launcher(
     // case 192:
     //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
     //   break;
-    // case 256:
-    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-    //   break;
+    case 256:
+      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+      break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
...
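Head size 256 now gets a compiled kernel instantiation because GPT-J uses fewer, wider attention heads than the previously supported models. A quick back-of-the-envelope check, assuming the published GPT-J-6B configuration (n_embd=4096, 16 attention heads); the comparison figures are from the public GPT-NeoX-20B and LLaMA-13B configs:

# Assumed GPT-J-6B configuration values.
n_embd = 4096
num_attention_heads = 16

head_size = n_embd // num_attention_heads
print(head_size)  # 256 -> requires the new `case 256` kernel instantiation

# For comparison: GPT-NeoX-20B is 6144 // 64 = 96 and LLaMA-13B is 5120 // 40 = 128,
# both already covered by the existing head sizes.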
docs/source/models/supported_models.rst

...
@@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`GPTBigCodeForCausalLM`
     - StarCoder, SantaCoder, WizardCoder
     - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
+  * - :code:`GPTJForCausalLM`
+    - GPT-J
+    - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
   * - :code:`GPTNeoXForCausalLM`
     - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
     - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
...
tests/kernels/test_attention.py

...
@@ -286,7 +286,7 @@ def test_single_query_cached_kv_attention() -> None:
     torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for block_size in [8, 16, 32]:
-            for head_size in [64, 80, 96, 128]:
+            for head_size in [64, 80, 96, 112, 128, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
...
@@ -304,7 +304,7 @@ def test_multi_query_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [64, 80, 96, 128]:
+        for head_size in [64, 80, 96, 112, 128, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             run_multi_query_kv_attention(
...
vllm/model_executor/layers/attention.py

...
@@ -12,7 +12,7 @@ from vllm import cache_ops
 from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata

-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


 class PagedAttention(nn.Module):
...
vllm/model_executor/layers/sampler.py

...
@@ -38,12 +38,15 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Dict[int, SequenceOutputs]:
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)

         # Get the logits for the next tokens.
         logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
         logits = gather_from_tensor_model_parallel_region(logits)
         # Remove paddings in vocab (if any).
         logits = logits[:, :self.vocab_size]
...
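The new `embedding_bias` argument exists because GPT-J's language-model head carries a bias term, which `GPTJForCausalLM.forward` (in gpt_j.py below) passes in as `self.lm_head.bias`; the other supported models keep the default of `None`. A minimal sketch of the logits computation with the optional bias; the shapes and random tensors are illustrative only, with 50400 being GPT-J's vocabulary size:

import torch

vocab_size, hidden_size, num_tokens = 50400, 4096, 3
embedding = torch.randn(vocab_size, hidden_size)   # analogous to lm_head.weight
embedding_bias = torch.randn(vocab_size)           # analogous to lm_head.bias (GPT-J only)
hidden_states = torch.randn(num_tokens, hidden_size)

logits = torch.matmul(hidden_states, embedding.t())  # [num_tokens, vocab_size]
if embedding_bias is not None:
    logits += embedding_bias                          # broadcast over the token dimension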
vllm/model_executor/model_loader.py

...
@@ -14,6 +14,7 @@ _MODEL_REGISTRY = {
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
+    "GPTJForCausalLM": GPTJForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
...
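Registering the class in `_MODEL_REGISTRY` is what makes dispatch work: the loader matches the `architectures` field of the HuggingFace config against the registry keys. The hunk does not show that lookup, so the following is only a rough sketch with a hypothetical helper name, not the actual code in model_loader.py:

from transformers import AutoConfig

def resolve_model_class(model_name: str, registry: dict):
    # HF configs list their model classes under `architectures`,
    # e.g. ["GPTJForCausalLM"] for EleutherAI/gpt-j-6b.
    config = AutoConfig.from_pretrained(model_name)
    for arch in getattr(config, "architectures", []) or []:
        if arch in registry:
            return registry[arch]
    raise ValueError(f"Model architectures {config.architectures} are not supported.")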
vllm/model_executor/models/__init__.py

 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
+from vllm.model_executor.models.gpt_j import GPTJForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.mpt import MPTForCausalLM
...
@@ -10,6 +11,7 @@ __all__ = [
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
+    "GPTJForCausalLM",
     "GPTNeoXForCausalLM",
     "LlamaForCausalLM",
     "MPTForCausalLM",
...
vllm/model_executor/models/gpt_j.py (new file, mode 100644)

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
# Copyright 2023 The vLLM team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import GPTJConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GPTJAttention(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        self.qkv_proj = ColumnParallelLinear(config.hidden_size,
                                             3 * config.hidden_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False)
        self.out_proj = RowParallelLinear(config.hidden_size,
                                          config.hidden_size,
                                          bias=False,
                                          input_is_parallel=True,
                                          perform_initialization=False)

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        scaling = self.head_size**-0.5
        assert config.rotary
        assert config.rotary_dim % 2 == 0
        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
                                           scaling, config.rotary_dim)
        self.warmup = False

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output


class GPTJMLP(nn.Module):

    def __init__(self, intermediate_size: int, config: GPTJConfig):
        super().__init__()
        hidden_size = config.n_embd
        self.fc_in = ColumnParallelLinear(hidden_size,
                                          intermediate_size,
                                          gather_output=False,
                                          perform_initialization=False)
        self.fc_out = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        input_is_parallel=True,
                                        perform_initialization=False)
        self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc_out(hidden_states)
        return hidden_states


class GPTJBlock(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        if config.n_inner is None:
            inner_dim = 4 * config.n_embd
        else:
            inner_dim = config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJAttention(config)
        self.mlp = GPTJMLP(inner_dim, config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        mlp_output = self.mlp(hidden_states)
        hidden_states = attn_output + mlp_output + residual
        return hidden_states


class GPTJModel(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.n_embd
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          self.embed_dim,
                                          perform_initialization=False)
        self.h = nn.ModuleList(
            [GPTJBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class GPTJForCausalLM(nn.Module):

    def __init__(self, config: GPTJConfig):
        super().__init__()
        self.config = config
        assert not config.tie_word_embeddings
        self.transformer = GPTJModel(config)
        self.lm_head = ColumnParallelLinear(config.n_embd,
                                            config.vocab_size,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata, self.lm_head.bias)
        return next_tokens

    _column_parallel_weights = [
        "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
        "lm_head.bias"
    ]
    _row_parallel_weights = ["out_proj.weight", "fc_out.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "attn.bias" in name or "attn.masked_bias" in name:
                continue

            is_attention_weight = False
            for stride_id, att_weight_name in enumerate(
                    ["q_proj", "k_proj", "v_proj"]):
                if att_weight_name not in name:
                    continue
                param = state_dict[name.replace(att_weight_name, "qkv_proj")]
                shard_size = param.shape[1]
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
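`load_weights` above stitches the separate `q_proj`/`k_proj`/`v_proj` matrices of the HuggingFace checkpoint into the fused `qkv_proj` parameter, keeping only the stripe owned by the current tensor-parallel rank. A small self-contained sketch of that stitching with toy tensors (tensor-parallel size 1 and tiny dimensions; the names are hypothetical and this is not the vLLM code itself):

import torch

hidden_size = 8          # toy dimension; GPT-J-6B uses 4096
tp_rank, tp_size = 0, 1  # single-GPU case for illustration

# Separate projection weights as they appear in the HF checkpoint.
hf_weights = {
    "q_proj.weight": torch.randn(hidden_size, hidden_size),
    "k_proj.weight": torch.randn(hidden_size, hidden_size),
    "v_proj.weight": torch.randn(hidden_size, hidden_size),
}

# Fused parameter as created by ColumnParallelLinear(hidden_size, 3 * hidden_size).
qkv_proj = torch.empty(3 * hidden_size // tp_size, hidden_size)

shard_size = hidden_size // tp_size
for stride_id, name in enumerate(["q_proj", "k_proj", "v_proj"]):
    loaded = hf_weights[f"{name}.weight"]
    # Keep only this rank's rows, then copy them into the q/k/v stripe.
    shard = loaded[shard_size * tp_rank:shard_size * (tp_rank + 1)]
    qkv_proj[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(shard)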
vllm/model_executor/models/mpt.py

# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple
...