norm / vllm · Commits

Commit d4bc1a4d, authored Feb 23, 2023 by Woosuk Kwon
Parent: b56b6ca0

    Add unoptimized OPT Attention
Showing 2 changed files with 177 additions and 14 deletions:

    cacheflow/models/attention.py  +118  -0
    cacheflow/models/opt.py         +59  -14
cacheflow/models/attention.py (new file, mode 100644)
from typing import Optional, Tuple

import torch
import torch.nn as nn
import xformers.ops as xops

from cacheflow import ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

        # Shape-agnostic attention mask.
        self.attention_mask = xops.LowerTriangularMask()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        out = xops.memory_efficient_attention(
            query, key, value, attn_bias=self.attention_mask, scale=self.scale)
        # FIXME(woosuk): Directly write the attention output.
        output.copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        block_size = value_cache.shape[2]
        block_tables = input_metadata.block_tables

        # FIXME(woosuk): Replace the following with a custom op.
        for i in range(input_metadata.num_generation_tokens):
            q = query[i]
            block_table = block_tables[i]
            context_len = int(input_metadata.context_lens[i])

            keys = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                k = key_cache[block_number, :, :, block_offset, :]
                k = k.view(num_heads, head_size)
                keys.append(k)
            keys = torch.stack(keys, dim=-1)
            logits = q @ keys
            attention_weights = torch.softmax(logits, dim=-1)

            values = []
            for j in range(context_len):
                block_number = block_table[j // block_size]
                block_offset = j % block_size
                v = value_cache[block_number, :, block_offset, :]
                values.append(v)
            values = torch.stack(values, dim=-1)
            out = attention_weights @ values
            output[i].copy_(out, non_blocking=True)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[3]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        output = torch.empty_like(query)
        start_idx = 0
        for i in range(input_metadata.num_prompts):
            prompt_len = input_metadata.prompt_lens[i]
            out = output[start_idx:start_idx + prompt_len]
            q = query[start_idx:start_idx + prompt_len]
            k = key[start_idx:start_idx + prompt_len]
            v = value[start_idx:start_idx + prompt_len]
            self.multi_query_kv_attention(out, q, k, v)
            start_idx += prompt_len

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[start_idx:], query[start_idx:], key_cache, value_cache,
                input_metadata)

        # Reshape the output tensor.
        return output.view(-1, num_heads * head_size)
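
The generation branch above walks a block table to gather scattered KV entries: token position j lives in physical block block_table[j // block_size], at offset j % block_size within that block. A minimal standalone sketch of that paged lookup, with illustrative sizes that are not taken from this commit:

import torch

# Illustrative sizes only; not values from this commit.
num_blocks, num_heads, head_size, block_size = 8, 4, 16, 4
context_len = 6

# Same value-cache layout as above: [num_blocks, num_heads, block_size, head_size].
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
# The block table maps logical block index j // block_size to a physical block.
block_table = [5, 2]

values = []
for j in range(context_len):
    block_number = block_table[j // block_size]  # which physical block
    block_offset = j % block_size                # slot within that block
    values.append(value_cache[block_number, :, block_offset, :])
values = torch.stack(values, dim=-1)
print(values.shape)  # torch.Size([4, 16, 6]): [num_heads, head_size, context_len]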
cacheflow/models/opt.py
"""1D OPT model compatible with HuggingFace weights."""
"""1D OPT model compatible with HuggingFace weights."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
OPTConfig
from
transformers
import
OPTConfig
from
transformers
import
PreTrainedModel
from
transformers
import
PreTrainedModel
from
cacheflow.models
import
InputMetadata
from
cacheflow.models.attention
import
OPTCacheFlowAttention
from
cacheflow.models.sample
import
Sampler
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
OPTLearnedPositionalEmbedding
(
nn
.
Embedding
):
class
OPTLearnedPositionalEmbedding
(
nn
.
Embedding
):
 ...
@@ -31,17 +39,27 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // num_heads
         self.scaling = self.head_dim**-0.5

+        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        q = self.q_proj(hidden_states) * self.scaling
+        self.attn = OPTCacheFlowAttention(scale=self.scaling)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        q = self.q_proj(hidden_states)
         k = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)
-        # TODO
-        attn_output = None
+        key_cache, value_cache = kv_cache
+        attn_output = self.attn(
+            q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output = self.out_proj(attn_output)
         return output
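
One behavioral detail worth noting: the old forward scaled the queries (`* self.scaling`), while the new code leaves the projection unscaled and passes `scale=self.scaling` into `OPTCacheFlowAttention` instead. The two are equivalent, since scaling the queries before the dot product is the same as scaling the attention logits after it; a quick check:

import torch

torch.manual_seed(0)
q, k = torch.randn(3, 8), torch.randn(3, 8)
scale = 8 ** -0.5

old = torch.softmax((q * scale) @ k.T, dim=-1)  # scale queries first
new = torch.softmax((q @ k.T) * scale, dim=-1)  # scale logits inside the op
print(torch.allclose(old, new))  # True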
 ...
@@ -66,13 +84,23 @@ class OPTDecoderLayer(nn.Module):
         self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim,
                              bias=config.enable_bias)
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
         # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
         if self.do_layer_norm_before:
             hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
 ...
@@ -145,6 +173,9 @@ class OPTDecoder(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
 ...
@@ -153,8 +184,14 @@ class OPTDecoder(OPTPreTrainedModel):
             inputs_embeds = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds

-        for layer in self.layers:
-            hidden_states = layer(hidden_states)
+        for i in range(len(self.layers)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states, kv_caches[i], input_metadata, cache_event)

         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
 ...
@@ -175,8 +212,12 @@ class OPTModel(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
-        return self.decoder(input_ids, positions)
+        return self.decoder(
+            input_ids, positions, kv_caches, input_metadata, cache_events)


 class OPTForCausalLM(OPTPreTrainedModel):
 ...
@@ -185,9 +226,9 @@ class OPTForCausalLM(OPTPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.model = OPTModel(config)
         # the lm_head weight is automatically tied to the embed tokens weight
         self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size,
                                  bias=False)
+        self.sampler = Sampler(embedding=self.lm_head.weight)
         # Initialize weights and apply final processing
         self.post_init()
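
`Sampler` itself is not part of this diff (it lives in cacheflow/models/sample.py, which this commit does not touch), so how it uses the tied `lm_head.weight` is not shown here. A hypothetical sketch of the usual pattern — logits computed against the embedding matrix, then a token picked per position; the greedy argmax is purely illustrative:

import torch

# Hypothetical: Sampler's actual definition is outside this diff.
hidden_states = torch.randn(5, 512)         # [num_tokens, word_embed_proj_dim]
embedding = torch.randn(50272, 512)         # lm_head.weight: [vocab_size, dim]
logits = hidden_states @ embedding.t()      # [num_tokens, vocab_size]
next_tokens = torch.argmax(logits, dim=-1)  # greedy choice, for illustration only
print(next_tokens.shape)  # torch.Size([5])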
 ...
@@ -196,7 +237,11 @@ class OPTForCausalLM(OPTPreTrainedModel):
         self,
         input_ids: torch.LongTensor,
         positions: torch.LongTensor,
-    ) -> torch.Tensor:
-        hidden_states = self.model.decoder(input_ids, positions)
-        logits = self.lm_head(hidden_states).contiguous()
-        return logits
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, Tuple[int, int]]:
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, input_metadata, cache_events)
+        next_tokens = self.sampler(hidden_states, input_metadata)
+        return next_tokens
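
The new entry points all thread a per-layer list of `KVCache` tuples. A hedged sketch of how a caller might allocate them: the value-cache layout `[num_blocks, num_heads, block_size, head_size]` follows `single_query_cached_kv_attention` in attention.py, while the key cache is only known to be 5-D with the block offset at dim 3 and dims 2 and 4 multiplying to `head_size`; the packing factor `x` below is an assumption, not something this commit specifies:

import torch

# Illustrative sizes; not values from this commit.
num_layers, num_blocks, num_heads, head_size, block_size = 2, 16, 4, 32, 8
x = 8  # assumed packing factor; only (head_size // x) * x == head_size is required

kv_caches = []
for _ in range(num_layers):
    # key_cache[block, :, :, offset, :] is later viewed as [num_heads, head_size].
    key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x)
    # value_cache[block, :, offset, :] yields [num_heads, head_size].
    value_cache = torch.empty(num_blocks, num_heads, block_size, head_size)
    kv_caches.append((key_cache, value_cache))
print(len(kv_caches), kv_caches[0][0].shape)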