gaoqiong / flash-attention / Commits / 311d6606

Commit 311d6606, authored Apr 20, 2023 by Tri Dao
[Gen] Fix FT kernel smem size, CG when batch size changed
parent 96d10f65

Showing 4 changed files with 106 additions and 16 deletions:
  csrc/ft_attention/decoder_masked_multihead_attention.cu   +7   -8
  flash_attn/modules/mha.py                                  +15  -5
  flash_attn/utils/generation.py                             +4   -3
  tests/models/test_gpt_generation_cg.py                     +80  -0
csrc/ft_attention/decoder_masked_multihead_attention.cu @ 311d6606

...
@@ -29,14 +29,13 @@
 #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
     size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
-    dim3 grid(params.num_heads, params.batch_size); \
-    mmha::masked_multihead_attention_kernel<T, \
-                                            Dh, \
-                                            Dh_MAX, \
-                                            THDS_PER_KEY, \
-                                            THDS_PER_VALUE, \
-                                            THDS_PER_BLOCK, \
-                                            DO_CROSS_ATTENTION><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+    auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
+                                                          THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
+    if (smem_sz >= 48 * 1024) { \
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
+    } \
+    dim3 grid(params.num_heads, params.batch_size); \
+    kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
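Why the opt-in is needed: CUDA gives a kernel at most 48 KB of dynamic shared memory per block by default, and requesting more requires cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, ...) before launch, which the revised macro now does whenever smem_sz crosses that threshold. The following is a rough, illustrative estimate only, assuming the dominant shared-memory term is about one float per cached timestep for the attention logits (the kernel's actual smem_size_in_bytes has additional terms); it shows how long contexts, like the 15000-token case in the new test, push past 48 KB:

# Rough, illustrative estimate only: when does per-block dynamic shared memory likely
# exceed the 48 KB default limit? The ~4 bytes per cached timestep figure is an assumption
# about the dominant term, not the kernel's exact smem_size_in_bytes formula.
DEFAULT_SMEM_LIMIT = 48 * 1024          # bytes available without cudaFuncSetAttribute

def approx_smem_bytes(timestep, bytes_per_logit=4):
    return timestep * bytes_per_logit   # one float-sized logit per cached position

for timestep in (2048, 8192, 12288, 15000):
    smem = approx_smem_bytes(timestep)
    print(f"timestep={timestep:6d}  ~{smem / 1024:6.1f} KiB  opt-in needed: {smem >= DEFAULT_SMEM_LIMIT}")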
flash_attn/modules/mha.py @ 311d6606

...
@@ -490,10 +490,15 @@ class MHA(nn.Module):
         else:
             assert inference_params.fused_ft_kernel
             assert ft_attention is not None
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + qkv.shape[0]
+            k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
+                                  if inference_params.lengths_per_sample is not None else None)
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                *inference_params.key_value_memory_dict[self.layer_idx],
-                inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
+                lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
...
@@ -605,11 +610,16 @@ class ParallelMHA(nn.Module):
         else:
             assert inference_params.fused_ft_kernel
             assert ft_attention is not None
+            batch_start = inference_params.batch_size_offset
+            batch_end = batch_start + qkv.shape[0]
+            k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+            lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
+                                  if inference_params.lengths_per_sample is not None else None)
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                *inference_params.key_value_memory_dict[self.layer_idx],
-                inference_params.lengths_per_sample,
-                inference_params.sequence_len_offset, self.rotary_emb_dim,
+                k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
+                lengths_per_sample, inference_params.sequence_len_offset,
+                self.rotary_emb_dim,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
...
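Both hunks apply the same change to MHA and ParallelMHA: rather than handing the fused kernel the whole preallocated key/value cache, the cache and lengths_per_sample are sliced to the rows of the live batch using inference_params.batch_size_offset and the actual qkv batch size. Below is a minimal, self-contained sketch of that slicing; ToyInferenceParams and kv_for_live_batch are illustrative stand-ins (with a simplified cache layout), not flash-attn's InferenceParams:

import torch
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ToyInferenceParams:
    """Simplified stand-in for flash-attn's InferenceParams; field names mirror the diff above."""
    batch_size_offset: int = 0
    sequence_len_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[torch.Tensor] = None

def kv_for_live_batch(params, layer_idx, batch_size):
    """Return views of the preallocated KV cache restricted to the current batch's rows."""
    batch_start = params.batch_size_offset
    batch_end = batch_start + batch_size
    k_cache, v_cache = params.key_value_memory_dict[layer_idx]
    lengths = (params.lengths_per_sample[batch_start:batch_end]
               if params.lengths_per_sample is not None else None)
    return k_cache[batch_start:batch_end], v_cache[batch_start:batch_end], lengths

# Cache preallocated for up to 4 sequences and 128 positions (layout simplified to
# (batch, seqlen, nheads, headdim)); only 2 sequences are live in this batch.
params = ToyInferenceParams()
params.key_value_memory_dict[0] = (torch.zeros(4, 128, 12, 64), torch.zeros(4, 128, 12, 64))
k, v, lengths = kv_for_live_batch(params, layer_idx=0, batch_size=2)
print(k.shape, v.shape)  # torch.Size([2, 128, 12, 64]) for both; slicing creates views, not copies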
flash_attn/utils/generation.py @ 311d6606

...
@@ -238,15 +238,16 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
-        if s_type not in cache.callables:
+        if (batch_size, s_type) not in cache.callables:
             max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
-            cache.callables[s_type] = capture_graph(
+            cache.callables[batch_size, s_type] = capture_graph(
                 model, cache.inference_params, batch_size, max_seqlen_,
                 mempool=cache.mempool, n_warmups=n_warmups)

     def dispatch(input_ids, position_ids, seqlen):
-        return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
+        batch_size = input_ids.shape[0]
+        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

     cache.run = dispatch
     cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
...
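This is the "CG when batch size changed" half of the commit: the cache of captured CUDA-graph callables is now keyed by (batch_size, seqlen_type) rather than seqlen_type alone, and dispatch reads the batch size from input_ids, so generating with a new batch size triggers a fresh capture instead of replaying a graph recorded for different tensor shapes. Below is a minimal sketch of that dispatch logic with the graph capture stubbed out; the bucketing in seqlen_to_seqlen_type is an assumed illustration, not the library's exact thresholds:

# Sketch only: `capture` stands in for capture_graph, which in flash-attn records a CUDA
# graph for one fixed (batch_size, max_seqlen) shape. Cache keys are (batch_size, seqlen_type).
from typing import Callable, Dict, Tuple

def seqlen_to_seqlen_type(seqlen):
    # Assumed bucketing for illustration: one captured graph serves a whole range of lengths.
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

class GraphCache:
    def __init__(self):
        self.callables: Dict[Tuple[int, int], Callable] = {}

    def update(self, batch_size, seqlen_og, max_seqlen, capture):
        for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
            # Keying on (batch_size, s_type) means a new batch size gets its own capture.
            if (batch_size, s_type) not in self.callables:
                self.callables[batch_size, s_type] = capture(batch_size, s_type)

    def run(self, input_ids, seqlen):
        batch_size = len(input_ids)  # the real dispatch uses input_ids.shape[0]
        return self.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, seqlen)

# Usage: the stub just remembers the shape it was "captured" for.
cache = GraphCache()
capture = lambda bs, st: (lambda ids, seqlen: f"replay graph for batch_size={bs}, bucket={st}")
cache.update(batch_size=1, seqlen_og=10, max_seqlen=20, capture=capture)
cache.update(batch_size=3, seqlen_og=10, max_seqlen=50, capture=capture)
print(cache.run([7], 15))        # batch of 1 -> graph captured for batch_size=1
print(cache.run([7, 8, 9], 40))  # batch of 3 -> separately captured graph, not the batch-1 one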
tests/models/test_gpt_generation_cg.py @ 311d6606 (new file, mode 100644)

import os
import re
import time

import torch
import pytest

from einops import rearrange

from transformers import GPT2Config

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.generation import update_graph_cache


def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
    out = model.generate(
        input_ids=input_ids, max_length=max_length, fused_ft_kernel=True,
        teacher_outputs=teacher_outputs, return_dict_in_generate=True, output_scores=True,
        timing=True, **kwargs
    )
    return torch.stack(out.scores, dim=1)


@pytest.mark.parametrize('seqlen,maxlen', [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize('rotary', [None, "interleaved", "block"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
    """Check that decoding with CUDA graph is the same as decoding without CUDA graph.
    """
    dtype = torch.float16
    device = 'cuda'
    rtol, atol = 3e-3, 3e-1
    config = GPT2Config.from_pretrained(model_name)
    config.n_positions = 16 * 1024
    assert seqlen <= maxlen <= config.n_positions
    if rotary is not None:
        config.n_positions = 0
        config.rotary_emb_dim = 32
        config.rotary_emb_interleaved = rotary == "interleaved"
    config.residual_in_fp32 = True
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 1
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)
    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
                                    device=device)
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)

    # Try increasing batch size and seqlen, then decrease them to see if it's still correct
    batch_size = 3
    maxlen += 30
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)
    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
                                    device=device)
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)

    batch_size = 2
    maxlen -= 35
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)
    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
                                    device=device)
    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
    assert torch.equal(logits, logits_cg)
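The test's rotary parametrization maps onto the neox_rotary_style flag passed to the fused kernel in mha.py: "block" (interleaved=False, GPT-NeoX style) pairs channel i with channel i + d/2, while "interleaved" (GPT-J style) pairs adjacent channels 2i and 2i+1. Below is a minimal reference sketch of the two layouts in plain PyTorch, for intuition only and not flash-attn's implementation:

import torch

def rotary_block(x, cos, sin):
    """'block' layout (GPT-NeoX style): channel i is paired with channel i + d/2."""
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

def rotary_interleaved(x, cos, sin):
    """'interleaved' layout (GPT-J style): channel 2i is paired with channel 2i + 1."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2)

# Same rotation angles, two different channel pairings.
positions, rotary_dim = 4, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
angles = torch.arange(positions)[:, None] * inv_freq[None, :]   # (positions, rotary_dim / 2)
cos, sin = angles.cos(), angles.sin()
x = torch.randn(positions, rotary_dim)
print(rotary_block(x, cos, sin).shape, rotary_interleaved(x, cos, sin).shape)  # both (4, 8)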