Commit e41f0670 (unverified)
Add support for BLOOM (#331)
Authored Jul 03, 2023 by Woosuk Kwon; committed by GitHub on Jul 03, 2023
Parent: d6fa1be3

Showing 11 changed files with 479 additions and 18 deletions (+479, -18).
README.md                                  +1    -0
csrc/attention.cpp                         +3    -1
csrc/attention/attention_kernels.cu        +19   -6
docs/source/models/supported_models.rst    +3    -0
tests/kernels/test_attention.py            +1    -0
vllm/model_executor/input_metadata.py      +4    -2
vllm/model_executor/layers/attention.py    +128  -5
vllm/model_executor/model_loader.py        +2    -3
vllm/model_executor/models/__init__.py     +2    -0
vllm/model_executor/models/bloom.py        +316  -0
vllm/model_executor/models/gpt_neox.py     +0    -1
README.md

@@ -41,6 +41,7 @@ vLLM is flexible and easy to use with:
 vLLM seamlessly supports many Huggingface models, including the following architectures:
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
...
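As a point of reference for the new entry, here is a minimal offline-inference sketch with a BLOOM-family checkpoint. The checkpoint name and sampling settings are illustrative, not part of this commit, and it assumes the vllm package from this repository is built and installed.

from vllm import LLM, SamplingParams

# Any BLOOM-family checkpoint should work; a small one keeps the example cheap.
llm = LLM(model="bigscience/bloom-560m")
params = SamplingParams(temperature=0.8, max_tokens=64)

outputs = llm.generate(["Deep learning is"], params)
print(outputs[0].outputs[0].text)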
csrc/attention.cpp

 #include <torch/extension.h>
+#include <c10/util/Optional.h>

 void single_query_cached_kv_attention(
   torch::Tensor& out,
...
@@ -9,7 +10,8 @@ void single_query_cached_kv_attention(
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int block_size,
-  int max_context_len);
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
...
csrc/attention/attention_kernels.cu

@@ -80,6 +80,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
...
@@ -91,6 +92,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
   const int seq_idx = blockIdx.y;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

   // A vector type to store a part of a key or a query.
   // The vector size is configured in such a way that the threads in a thread group
...
@@ -167,12 +169,14 @@ __global__ void single_query_cached_kv_attention_kernel(
       // Compute dot product.
       // This includes a reduction across the threads in the same thread group.
-      const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
+      const bool mask = token_idx >= context_len;
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;

       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
         // NOTE(woosuk): It is required to zero out the masked logits.
-        const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
...
@@ -328,6 +332,7 @@ __global__ void single_query_cached_kv_attention_kernel(
     block_tables_ptr,                                                    \
     context_lens_ptr,                                                    \
     max_num_blocks_per_seq,                                              \
+    alibi_slopes_ptr,                                                    \
     query_stride);

// TODO(woosuk): Tune NUM_THREADS.
...
@@ -343,7 +348,8 @@ void single_query_cached_kv_attention_launcher(
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
-  int max_context_len) {
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
...
@@ -353,6 +359,11 @@ void single_query_cached_kv_attention_launcher(
   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);

+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes
+      ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+      : nullptr;
+
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
...
@@ -411,7 +422,8 @@ void single_query_cached_kv_attention_launcher(
     scale,                                                               \
     block_tables,                                                        \
     context_lens,                                                        \
-    max_context_len);
+    max_context_len,                                                     \
+    alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
...
@@ -458,7 +470,8 @@ void single_query_cached_kv_attention(
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
-  int max_context_len) {
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes) {
   if (query.dtype() == at::ScalarType::Float) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
   } else if (query.dtype() == at::ScalarType::Half) {
...
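The kernel-side change boils down to the single line qk += alibi_slope * (token_idx - context_len). Below is a short PyTorch sketch of the same computation on the host, for one head and one generation token (all values made up). It also shows why measuring the distance from context_len rather than from the query position context_len - 1 is harmless: the two biases differ only by a constant per row, and softmax is shift-invariant.

import torch

context_len = 6
slope = 0.0625                      # per-head ALiBi slope (made-up value)
scores = torch.randn(context_len)   # scale * (q . k_j) for j = 0..context_len-1

token_idx = torch.arange(context_len, dtype=torch.float32)

# Bias as computed in the kernel: slope * (token_idx - context_len).
biased_kernel = scores + slope * (token_idx - context_len)

# "Textbook" ALiBi bias, measured from the query position (context_len - 1).
biased_alibi = scores + slope * (token_idx - (context_len - 1))

# The two differ by a constant (exactly slope), so the softmax weights match.
assert torch.allclose(torch.softmax(biased_kernel, dim=-1),
                      torch.softmax(biased_alibi, dim=-1))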
docs/source/models/supported_models.rst

@@ -14,6 +14,9 @@ Alongside each architecture, we include some popular models that use it.
   * - Architecture
     - Models
     - Example HuggingFace Models
+  * - :code:`BloomForCausalLM`
+    - BLOOM, BLOOMZ, BLOOMChat
+    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
...
tests/kernels/test_attention.py

@@ -216,6 +216,7 @@ def run_single_query_cached_kv_attention(
         context_lens,
         block_size,
         max_context_len,
+        None,  # ALiBi slopes.
     )

     ref_output = torch.empty_like(query)
...
vllm/model_executor/input_metadata.py

 from typing import Dict, List, Tuple

 import torch
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+from xformers.ops import AttentionBias

 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceData
...
@@ -38,7 +38,6 @@ class InputMetadata:
         self.max_context_len = max_context_len
         self.block_tables = block_tables

-        self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
...
@@ -50,6 +49,9 @@ class InputMetadata:
         assert block_tables.shape[0] == self.num_generation_tokens
         assert context_lens.shape[0] == self.num_generation_tokens

+        # Set during the execution of the first attention op.
+        self.attn_bias: List[AttentionBias] = []
+
     def __repr__(self) -> str:
         # Print only useful metadata.
         return (f'InputMetadata('
...
vllm/model_executor/layers/attention.py

 """Multi-head attention."""
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.nn as nn
 from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)

 from vllm import attention_ops
 from vllm import cache_ops
...
@@ -53,13 +55,21 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        prompt_lens = input_metadata.prompt_lens
+        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
+        input_metadata.attn_bias.append(attn_bias)
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        attn_bias: xops.AttentionBias,
+        input_metadata: InputMetadata,
     ) -> torch.Tensor:
         """Normal attention for the prompt tokens.
...
@@ -68,13 +78,14 @@ class PagedAttention(nn.Module):
             query: shape = [num_prompt_tokens, num_heads, head_size]
             key: shape = [num_prompt_tokens, num_heads, head_size]
             value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
         """
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
             value.unsqueeze(0),
-            attn_bias=attn_bias,
+            attn_bias=input_metadata.attn_bias[0],
             p=0.0,
             scale=self.scale,
             op=self.attn_op,
...
@@ -112,6 +123,7 @@ class PagedAttention(nn.Module):
             input_metadata.context_lens,
             block_size,
             input_metadata.max_context_len,
+            None,  # alibi_slopes
         )

     def forward(
...
@@ -154,12 +166,13 @@ class PagedAttention(nn.Module):
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
                 key[:num_prompt_tokens],
                 value[:num_prompt_tokens],
-                input_metadata.attn_bias,
+                input_metadata,
             )

         # Wait until the cache op is done.
...
@@ -219,7 +232,8 @@ class PagedAttentionWithRoPE(PagedAttention):
         cache = torch.cat((cos, sin), dim=-1)

         # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model. Make it more robust.
+        # initializing the model.
+        # TODO(woosuk): Make it more robust.
         torch_dtype = torch.get_default_dtype()
         cache = cache.to(torch_dtype)
         # Embedding size: [max_position, rotary_dim]
...
@@ -271,3 +285,112 @@ class PagedAttentionWithRoPE(PagedAttention):
             input_metadata,
             cache_event,
         )
+
+
+class PagedAttentionWithALiBi(PagedAttention):
+    """PagedAttention with ALiBi attention bias."""
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        slopes: List[float],
+    ) -> None:
+        super().__init__(num_heads, head_size, scale)
+        assert len(slopes) == num_heads
+
+        slopes = torch.tensor(slopes, dtype=torch.float32)
+        self.register_buffer("alibi_slopes", slopes, persistent=False)
+
+    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
+        if input_metadata.attn_bias:
+            # Already set by a previous layer.
+            return
+        # Generates ALiBi mask for each prompt.
+        for prompt_len in input_metadata.prompt_lens:
+            bias = torch.arange(prompt_len)
+            bias = bias[None, :] - bias[:, None]
+            bias = bias.to(self.alibi_slopes.device)
+
+            # When using custom attention bias, xformers requires the bias to
+            # be sliced from a tensor whose length is a multiple of 8.
+            padded_len = (prompt_len + 7) // 8 * 8
+            bias = torch.empty(
+                self.num_heads,
+                padded_len,
+                padded_len,
+                device=self.alibi_slopes.device,
+            )[:, :prompt_len, :prompt_len].copy_(bias)
+            bias.mul_(self.alibi_slopes[:, None, None])
+            attn_bias = LowerTriangularMaskWithTensorBias(bias)
+            input_metadata.attn_bias.append(attn_bias)
+
+    def multi_query_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Attention with ALiBi bias for the prompt tokens.
+
+        Args:
+            output: shape = [num_prompt_tokens, num_heads, head_size]
+            query: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_heads, head_size]
+            value: shape = [num_prompt_tokens, num_heads, head_size]
+            input_metadata: metadata for paged attention.
+        """
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
+        # lengths with custom attention bias, we process each prompt one by
+        # one. This is inefficient, especially when we have many short prompts.
+        start = 0
+        for i, prompt_len in enumerate(input_metadata.prompt_lens):
+            end = start + prompt_len
+            out = xops.memory_efficient_attention_forward(
+                query[None, start:end],
+                key[None, start:end],
+                value[None, start:end],
+                attn_bias=input_metadata.attn_bias[i],
+                p=0.0,
+                scale=self.scale,
+                op=self.attn_op,
+            )
+            # TODO(woosuk): Unnecessary copy. Optimize.
+            output[start:end].copy_(out.squeeze(0))
+            start += prompt_len
+        return output
+
+    def single_query_cached_kv_attention(
+        self,
+        output: torch.Tensor,
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        """PagedAttention with ALiBi bias for the generation tokens.
+
+        Args:
+            output: shape = [num_generation_tokens, num_heads, head_size]
+            query: shape = [num_generation_tokens, num_heads, head_size]
+            key_cache: shape = [num_blocks, num_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            input_metadata: metadata for paged attention.
+        """
+        block_size = value_cache.shape[3]
+        attention_ops.single_query_cached_kv_attention(
+            output,
+            query,
+            key_cache,
+            value_cache,
+            self.scale,
+            input_metadata.block_tables,
+            input_metadata.context_lens,
+            block_size,
+            input_metadata.max_context_len,
+            self.alibi_slopes,
+        )
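For reference, here is a standalone sketch of the bias tensor that PagedAttentionWithALiBi.set_attn_bias builds for a single prompt, using plain PyTorch only (in the layer, the result is then wrapped in xformers' LowerTriangularMaskWithTensorBias). The head count, slopes, and prompt length are made-up values.

import torch

num_heads, prompt_len = 4, 5
alibi_slopes = torch.tensor([0.25, 0.0625, 0.015625, 0.00390625])

# Relative distance (j - i) between key position j and query position i.
bias = torch.arange(prompt_len)
bias = bias[None, :] - bias[:, None]              # [prompt_len, prompt_len]

# xformers expects the bias to be sliced out of a tensor whose length is a
# multiple of 8, hence the padded buffer that the real layer allocates too.
padded_len = (prompt_len + 7) // 8 * 8
padded = torch.empty(num_heads, padded_len, padded_len)
bias = padded[:, :prompt_len, :prompt_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])            # per-head slope scaling

print(bias.shape)  # torch.Size([4, 5, 5]) -- a view into the padded buffer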
vllm/model_executor/model_loader.py

@@ -6,13 +6,12 @@ import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.config import ModelConfig
-from vllm.model_executor.models import (GPT2LMHeadModel, GPTBigCodeForCausalLM,
-                                        GPTNeoXForCausalLM, LlamaForCausalLM,
-                                        OPTForCausalLM)
+from vllm.model_executor.models import *  # pylint: disable=wildcard-import
 from vllm.model_executor.weight_utils import initialize_dummy_weights

 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
+    "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
...
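The new registry key matches the architecture string that HuggingFace configs advertise, which is how a BLOOM checkpoint gets routed to the new model class. A quick, illustrative check (the checkpoint name is only an example and fetching its config needs network access):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigscience/bloom-560m")
print(config.architectures)  # ['BloomForCausalLM'] -- the key added to _MODEL_REGISTRY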
vllm/model_executor/models/__init__.py

+from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
...
@@ -5,6 +6,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM

 __all__ = [
+    "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
     "GPTNeoXForCausalLM",
...
vllm/model_executor/models/bloom.py (new file)

# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The CacheFlow team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import BloomConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
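As a sanity check on the helper above: for a power-of-two head count the slopes are exactly 2**-1 through 2**-n, and for other head counts the remainder is filled in from a finer geometric sequence, matching the usual ALiBi recipe. A small sketch, assuming this module is importable once the commit is applied and the package is built:

from vllm.model_executor.models.bloom import _get_alibi_slopes

# 8 heads: closest power of two is 8, so base = 2**-1 and slopes are 2**-1 .. 2**-8.
print(_get_alibi_slopes(8))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

# 12 heads: the first 8 slopes are the same as above; the extra 4 use
# base 2**-0.5 with odd powers (1, 3, 5, 7).
print(_get_alibi_slopes(12)[8:])
# tensor([0.7071, 0.3536, 0.1768, 0.0884])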
class BloomAttention(nn.Module):

    def __init__(self, config: BloomConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = ColumnParallelLinear(
            self.hidden_size,
            3 * self.hidden_size,
            bias=True,
            gather_output=False,
            perform_initialization=False,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            input_is_parallel=True,
            perform_initialization=False,
        )

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.dense(attn_output)
        return output
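The constructor slices the global slope table so each tensor-parallel rank carries only the slopes of the heads it owns, mirroring the column-parallel split of the QKV projection. A toy illustration with made-up sizes (8 heads, 2-way tensor parallelism):

import torch

total_num_heads, tp_world_size = 8, 2
num_heads = total_num_heads // tp_world_size  # heads per rank

all_slopes = torch.tensor([2.0 ** -(i + 1) for i in range(total_num_heads)])

for tp_rank in range(tp_world_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    print(tp_rank, all_slopes[head_start:head_end].tolist())
# 0 [0.5, 0.25, 0.125, 0.0625]
# 1 [0.03125, 0.015625, 0.0078125, 0.00390625]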
class BloomMLP(nn.Module):

    def __init__(self, config: BloomConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  gather_output=False,
                                                  perform_initialization=False)
        self.act = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
                                               hidden_size,
                                               input_is_parallel=True,
                                               perform_initialization=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.dense_h_to_4h(x)
        x = self.act(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):

    def __init__(self, config: BloomConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config)
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output
class BloomModel(nn.Module):

    def __init__(self, config: BloomConfig):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, self.embed_dim, perform_initialization=False)
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList(
            [BloomBlock(config) for _ in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(hidden_states)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
class BloomForCausalLM(nn.Module):

    def __init__(self, config: BloomConfig):
        super().__init__()
        self.config = config
        self.transformer = BloomModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.word_embeddings.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["word_embeddings.weight",
                                "dense_h_to_4h.weight", "dense_h_to_4h.bias"]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = state_dict[name]

            if "query_key_value" in name:
                # NOTE(woosuk): BLOOM's fused QKV has the shape of
                # [num_heads * 3 * head_size, hidden_size], while the
                # required shape is [3 * num_heads * head_size, hidden_size].
                # Thus, we need weight conversion.
                shard_size = param.shape[0]
                start = shard_size * tp_rank
                end = shard_size * (tp_rank + 1)
                loaded_weight = loaded_weight[start:end]

                num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // num_heads
                if "query_key_value.weight" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size,
                                                       hidden_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif "query_key_value.bias" in name:
                    loaded_weight = loaded_weight.view(-1, 3, head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected weight name: {name}")
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
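The weight-conversion note in load_weights is easiest to see on a tiny tensor: BLOOM stores the fused QKV rows grouped per head as [q0, k0, v0, q1, k1, v1, ...], while the parallel linear layer expects all Q rows first, then all K, then all V. A self-contained sketch with made-up sizes:

import torch

num_heads, head_size, hidden_size = 2, 3, 6  # tiny made-up sizes

# BLOOM checkpoint layout: for each head, its Q, K and V blocks are adjacent
# along dim 0, i.e. rows are ordered [q0, k0, v0, q1, k1, v1].
w = torch.arange(num_heads * 3 * head_size * hidden_size, dtype=torch.float32)
w = w.reshape(num_heads * 3 * head_size, hidden_size)

# Conversion used in load_weights: regroup rows to [q0, q1, k0, k1, v0, v1].
converted = (w.view(-1, 3, head_size, hidden_size)
              .transpose(0, 1)
              .reshape(-1, hidden_size))

# Head 0's Q block stays first; head 1's Q block moves up next to it.
assert torch.equal(converted[:head_size], w[:head_size])                   # q0
assert torch.equal(converted[head_size:2 * head_size],
                   w[3 * head_size:4 * head_size])                         # q1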
vllm/model_executor/models/gpt_neox.py

@@ -80,7 +80,6 @@ class GPTNeoXAttention(nn.Module):
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-
         attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
...