norm / vllm · Commits

Commit d4bc1a4d
authored Feb 23, 2023 by Woosuk Kwon

Add unoptimized OPT Attention

parent b56b6ca0

Showing 2 changed files with 177 additions and 14 deletions (+177 -14):
  cacheflow/models/attention.py  (+118 -0)
  cacheflow/models/opt.py        (+59 -14)
cacheflow/models/attention.py  0 → 100644
View file @ d4bc1a4d
from typing import Optional, Tuple

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic attention mask.
        self.attention_mask = xops.LowerTriangularMask()
    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        out = xops.memory_efficient_attention(
            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out, non_blocking=True)
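
    # NOTE: The cache layouts are not declared in this file; they are
    # inferred from the indexing in the method below. value_cache appears to
    # be [num_blocks, num_heads, block_size, head_size], while key_cache is
    # 5-D, plausibly [num_blocks, num_heads, head_size // x, block_size, x]
    # for some packing factor x, since key_cache[block, :, :, offset, :] is
    # viewed back to (num_heads, head_size).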
    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i]
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

            keys = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.view(num_heads, head_size)
                keys.append(k)
            keys = torch.stack(keys, dim=-1)
            logits = q @ keys
            attention_weights = torch.softmax(logits, dim=-1)

            values = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            values = torch.stack(values, dim=-1)

            out = attention_weights @ values
            output[i].copy_(out, non_blocking=True)
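
    # NOTE: forward() packs all sequences into one flat token batch: prompt
    # tokens come first (handled by multi_query_kv_attention), followed by
    # one token per generating sequence (handled by
    # single_query_cached_kv_attention), as the slicing below shows.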
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:], query[start_idx:],
                key_cache, value_cache, input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)
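
The core of single_query_cached_kv_attention is paged address translation:
logical token position j of a sequence lives in physical block
block_table[j // block_size] at offset j % block_size. A minimal,
self-contained sketch of that gather (illustrative sizes and names, not part
of the commit):

import torch

num_blocks, num_heads, block_size, head_size = 4, 2, 8, 16
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)

block_table = [3, 0]  # logical blocks 0 and 1 live in physical blocks 3 and 0
context_len = 10      # spans two logical blocks: 8 tokens + 2 tokens

values = []
for j in range(context_len):
    block_number = block_table[j // block_size]  # physical block of token j
    block_offset = j % block_size                # slot within that block
    values.append(value_cache[block_number, :, block_offset, :])
values = torch.stack(values, dim=-1)  # [num_heads, head_size, context_len]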
cacheflow/models/opt.py
View file @ d4bc1a4d
"""1D OPT model compatible with HuggingFace weights."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
OPTConfig
from
transformers
import
PreTrainedModel
from
cacheflow.models
import
InputMetadata
from
cacheflow.models.attention
import
OPTCacheFlowAttention
from
cacheflow.models.sample
import
Sampler
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
OPTLearnedPositionalEmbedding
(
nn
.
Embedding
):
...
...
@@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.scaling = self.head_dim**-0.5
 
         # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.attn = OPTCacheFlowAttention(scale=self.scaling)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        q = self.q_proj(hidden_states) * self.scaling
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        q = self.q_proj(hidden_states)
         k = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)
-        # TODO
-        attn_output = None
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(
+            q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output = self.out_proj(attn_output)
         return output
...
@@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
         self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim,
                              bias=config.enable_bias)
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
...
@@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
...
@@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
             inputs_embeds = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
 
-        for layer in self.layers:
-            hidden_states = layer(hidden_states)
+        for i in range(len(self.layers)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states, kv_caches[i], input_metadata, cache_event)
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
...
@@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
-        return self.decoder(input_ids, positions)
+        return self.decoder(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
 
 
 class OPTForCausalLM(OPTPreTrainedModel):
...
@@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.model = OPTModel(config)
-        # the lm_head weight is automatically tied to the embed tokens weight
         self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size,
                                  bias=False)
+        self.sampler = Sampler(embedding=self.lm_head.weight)
 
         # Initialize weights and apply final processing
         self.post_init()
...
@@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
-    ) -> torch.Tensor:
-        hidden_states = self.model.decoder(input_ids, positions)
-        logits = self.lm_head(hidden_states).contiguous()
-        return logits
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, Tuple[int, int]]:
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(hidden_states, input_metadata)
+        return next_tokens
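
For callers, the new signatures imply a per-layer KV cache allocated up
front. A hedged sketch of that allocation (sizes are made up; the
value_cache shape follows the indexing in attention.py, and the key_cache
packing factor x is an assumption, not fixed by this commit):

import torch

num_layers, num_blocks = 12, 128
num_heads, head_size, block_size = 12, 64, 8
x = 8  # hypothetical packing factor implied by key_cache's 5-D indexing

kv_caches = []
for _ in range(num_layers):
    key_cache = torch.empty(num_blocks, num_heads, head_size // x,
                            block_size, x)
    value_cache = torch.empty(num_blocks, num_heads, block_size, head_size)
    kv_caches.append((key_cache, value_cache))

# cache_events may be None, in which case OPTDecoder.forward passes
# cache_event=None to every layer.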