gaoqiong / flash-attention · Commits

Commit ba2fe7f3, authored Apr 20, 2023 by Tri Dao

[Gen] Move allocate_inference_cache to within the model

Parent: 3da42d24
Showing 4 changed files with 37 additions and 7 deletions (+37 -7)
flash_attn/models/gpt.py        +8  -0
flash_attn/modules/block.py     +3  -0
flash_attn/modules/mha.py       +16 -0
flash_attn/utils/generation.py  +10 -7
flash_attn/models/gpt.py

@@ -335,6 +335,10 @@ class GPTModel(GPTPreTrainedModel):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+                for i, layer in enumerate(self.layers)}
+
     def forward(self, input_ids, position_ids=None, inference_params=None):
         # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
         # dimensions so that we can split on it easily, in case of small batch size.

@@ -426,6 +430,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
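To make the new API concrete, here is a minimal usage sketch. It assumes the model can be constructed from a plain transformers GPT2Config with default options and that a CUDA device is available; the config sizes below are illustrative, not taken from the commit. The head model forwards the call to its GPTModel, which returns one cache entry per Block, keyed by layer index.

import torch
from transformers import GPT2Config

from flash_attn.models.gpt import GPTLMHeadModel

# Hypothetical usage sketch: allocate the per-layer inference cache directly
# from the model instead of from a standalone utility function.
config = GPT2Config(n_embd=1024, n_head=16, n_layer=24)   # illustrative sizes
model = GPTLMHeadModel(config).to(device='cuda', dtype=torch.float16)

batch_size, max_seqlen = 4, 2048
# GPTLMHeadModel.allocate_inference_cache delegates to self.transformer,
# which collects {layer_index: cache} over every Block in self.layers.
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype=torch.float16)
print(sorted(inf_cache.keys())[:3])   # -> [0, 1, 2], one entry per layer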
flash_attn/modules/block.py

@@ -105,6 +105,9 @@ class Block(nn.Module):
             for p in self.norm2.parameters():
                 p._shared_params = True
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
                 mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.
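The Block-level change is pure delegation: the block does not know the cache layout, so it hands the request down to its attention mixer, which does. A toy sketch of that pattern (ToyMixer and ToyBlock are stand-ins invented for illustration, not the repo's classes):

import torch

class ToyMixer:
    # Stand-in for MHA: it owns the cache layout because it knows nheads/head_dim.
    def __init__(self, num_heads, head_dim):
        self.num_heads, self.head_dim = num_heads, head_dim

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        dtype = torch.float16 if dtype is None else dtype
        return torch.empty(batch_size, max_seqlen, 2, self.num_heads, self.head_dim, dtype=dtype)

class ToyBlock:
    # Stand-in for Block: no layout knowledge, so it just forwards the call.
    def __init__(self, mixer):
        self.mixer = mixer

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

block = ToyBlock(ToyMixer(num_heads=16, head_dim=64))
print(block.allocate_inference_cache(2, 128).shape)   # torch.Size([2, 128, 2, 16, 64])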
flash_attn/modules/mha.py

@@ -416,6 +416,22 @@ class MHA(nn.Module):
                                           attention_dropout=dropout)
         self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
+        dtype = self.out_proj.weight.dtype if dtype is None else dtype
+        device = self.out_proj.weight.device
+        if not fused_ft_kernel:
+            return torch.empty(batch_size, max_seqlen, 2, self.num_heads, self.head_dim,
+                               dtype=dtype, device=device)
+        else:
+            assert dtype in [torch.float16, torch.bfloat16, torch.float32]
+            packsize = 4 if dtype == torch.float32 else 8
+            assert self.head_dim % packsize == 0
+            k_cache = torch.empty(batch_size, self.num_heads, self.head_dim // packsize,
+                                  max_seqlen, packsize, dtype=dtype, device=device)
+            v_cache = torch.empty(batch_size, self.num_heads, max_seqlen, self.head_dim,
+                                  dtype=dtype, device=device)
+            return k_cache, v_cache
+
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
         """
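In the fused-kernel branch above, packsize is the number of elements in a 16-byte vector (4 for float32, 8 for float16/bfloat16), and head_dim must be divisible by it, presumably so the kernel can issue aligned vectorized loads on the K cache. A small shape sketch with illustrative sizes (plain CPU tensors, no model involved):

import torch

batch_size, max_seqlen, num_heads, head_dim = 2, 1024, 16, 64
dtype = torch.float16

# 16 bytes per vectorized load: 4 fp32 elements or 8 fp16/bf16 elements.
packsize = 4 if dtype == torch.float32 else 8
assert head_dim % packsize == 0

# K is stored packed along the head dimension; V keeps the usual layout.
k_cache = torch.empty(batch_size, num_heads, head_dim // packsize, max_seqlen, packsize,
                      dtype=dtype)
v_cache = torch.empty(batch_size, num_heads, max_seqlen, head_dim, dtype=dtype)
print(k_cache.shape)   # torch.Size([2, 16, 8, 1024, 8])
print(v_cache.shape)   # torch.Size([2, 16, 1024, 64])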
flash_attn/utils/generation.py

@@ -167,8 +167,8 @@ class GenerationMixin:
         return output if return_dict_in_generate else output.sequences
 
 
-def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
-                      device, dtype=torch.float16):
+def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
+                             device, dtype=torch.float16):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
     packsize = 4 if dtype == torch.float32 else 8
     assert headdim % packsize == 0

@@ -226,14 +226,17 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
         headdim = getattr(model.config, 'head_dim',
                           model.config.hidden_size // model.config.num_attention_heads)
-        kv_cache = allocate_kv_cache(
-            batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
-            model.config.num_hidden_layers, device, dtype
-        )
+        if hasattr(model, 'allocate_inference_cache'):
+            inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+        else:
+            inf_cache = allocate_inference_cache(
+                batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
+                model.config.num_hidden_layers, device, dtype
+            )
         lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
         cache.inference_params = InferenceParams(
             max_sequence_len=max_seqlen, max_batch_size=batch_size,
-            sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache, fused_ft_kernel=True,
+            sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True,
             lengths_per_sample=lengths_per_sample
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
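The net effect in update_graph_cache is a dispatch: a model that implements allocate_inference_cache owns its cache layout, and the old shape-based helper is kept only as a fallback for models that do not. A minimal sketch of that dispatch with stand-in functions (allocate_inference_cache_fallback and build_cache are invented here for illustration, not part of the repo):

import torch

def allocate_inference_cache_fallback(batch_size, max_seqlen, nheads, headdim,
                                      num_layers, device, dtype):
    # Stand-in for the module-level allocate_inference_cache helper:
    # build one (k_cache, v_cache) pair per layer from shapes alone.
    packsize = 4 if dtype == torch.float32 else 8
    return {
        i: (torch.empty(batch_size, nheads, headdim // packsize, max_seqlen, packsize,
                        dtype=dtype, device=device),
            torch.empty(batch_size, nheads, max_seqlen, headdim, dtype=dtype, device=device))
        for i in range(num_layers)
    }

def build_cache(model, config, batch_size, max_seqlen, device, dtype, tensor_parallel=1):
    # Same dispatch as the patched update_graph_cache: ask the model first,
    # fall back to the shape-based helper otherwise.
    if hasattr(model, 'allocate_inference_cache'):
        return model.allocate_inference_cache(batch_size, max_seqlen, dtype)
    headdim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    return allocate_inference_cache_fallback(
        batch_size, max_seqlen, config.num_attention_heads // tensor_parallel, headdim,
        config.num_hidden_layers, device, dtype)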