gaoqiong / flash-attention · Commits · b4859900

Commit b4859900
authored Jan 07, 2023 by Tri Dao

[Gen] Add timing option

parent 0938298e

Showing 2 changed files with 112 additions and 7 deletions:
    flash_attn/utils/generation.py        +103  -4
    tests/models/test_gpt_generation.py     +9  -3
flash_attn/utils/generation.py  (view file @ b4859900)

 # Copyright (c) 2022, Tri Dao.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
 from typing import Optional
+import time
 from dataclasses import dataclass, field
+from collections import namedtuple

 import torch
 from torch import Tensor
+from torch.profiler import profile, record_function, ProfilerActivity

 from einops import rearrange
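
The torch.profiler import added here backs the profiling hooks that appear, commented out, inside decode further down. For reference, a minimal self-contained sketch of that profiler API on a toy matmul workload; this is illustrative and not taken from the commit, and the trace file name is arbitrary:

    import torch
    from torch.profiler import profile, ProfilerActivity

    x = torch.randn(1024, 1024, device='cuda')
    # Profile CPU + CUDA activity for a few matmuls, recording input shapes.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        for _ in range(10):
            y = x @ x
    # Summarize per-op time and export a Chrome trace viewable in chrome://tracing.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    prof.export_chrome_trace("toy_trace.json")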
@@ -65,7 +69,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
     return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)


-def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fused_ft_kernel=True):
+def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
+           fused_ft_kernel=False, cg=False, timing=False):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
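
A usage sketch of the updated signature, assuming `model` and `input_ids` are set up as in tests/models/test_gpt_generation.py (a flash_attn GPTLMHeadModel on CUDA and a tensor of prompt token ids); this is illustrative, not part of the commit:

    from flash_attn.utils.generation import decode

    out = decode(input_ids, model, max_length=input_ids.shape[1] + 20,
                 top_k=1,                # greedy decoding
                 fused_ft_kernel=True,   # the cg path asserts this is enabled
                 cg=True,                # capture and replay CUDA graphs for the decoding steps
                 timing=True)            # print the total decoding wall-clock time
    print(out.sequences)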
@@ -89,17 +94,31 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fu
     next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
     sequences = [next_token]
     inference_params.sequence_len_offset = seqlen_og
+    if cg:
+        assert fused_ft_kernel
+        run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
+    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
+    if timing:
+        start = time.time()
     while True:
         position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                   dtype=torch.long, device=input_ids.device)
-        logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
-                       inference_params=inference_params).logits[:, -1]
+        if not cg:
+            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
+                           inference_params=inference_params).logits[:, -1]
+        else:
+            logits = run(rearrange(next_token, 'b -> b 1'), position_ids,
+                         inference_params.sequence_len_offset)
         scores.append(logits)
         next_token = sample(logits, top_k=top_k, temperature=temperature)
         sequences.append(next_token)
         inference_params.sequence_len_offset += 1
         if inference_params.sequence_len_offset >= max_length - 1:
             break
+    if timing:
+        print(f'Decoding time: {time.time() - start}')
+    # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))
+    # prof.export_chrome_trace("gpt2s_generation.json")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
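
Since time.time() measures host wall-clock while CUDA kernels execute asynchronously, measurements like the print above are usually bracketed by a device synchronization. A minimal sketch, not part of this commit, of the CUDA-event variant around a hypothetical stand-in workload:

    import torch

    def run_decoding_loop():
        # Hypothetical stand-in for the decoding loop being timed; any GPU work will do.
        a = torch.randn(1024, 1024, device='cuda')
        for _ in range(100):
            a = a @ a
        return a

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_evt.record()
    run_decoding_loop()
    end_evt.record()
    torch.cuda.synchronize()
    print(f'Decoding time: {start_evt.elapsed_time(end_evt) / 1e3:.3f}s')  # elapsed_time() returns ms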
@@ -116,3 +135,83 @@ class GenerationMixin:
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
+
+
+CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length'])
+CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits'])
+
+
+def seqlen_to_seqlen_type(seqlen: int) -> int:
+    """Convert sequence length to a seqlen_type.
+    This is used to determine which cuda graph to use.
+    Arguments:
+        seqlen: int
+    """
+    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
+
+
+def seqlen_type_to_seqlen(seqlen_type: int) -> int:
+    assert seqlen_type in [0, 1, 2]
+    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
+
+
+def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False):
+    """Build a cache of cuda graphs for decoding.
+    Arguments:
+        model: a GPTLMHeadModel
+        batch_size: int
+        seqlen_og: int. Length of the prompt.
+        max_length: int
+    TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also
+    has to own the k_cache and v_cache?
+    Here we assume that the model already has inference_params from the prompt processing.
+    """
+    assert max_length > seqlen_og
+    cg_cache: dict[CgKey, CgVal] = {}
+    device = next(iter(model.parameters())).device
+    sequence_length_offset_og = inference_params.sequence_len_offset
+    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
+    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
+    inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
+                                                     device=device)
+    memory_pool = None
+    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
+        seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
+        input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
+        position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
+        inference_params.lengths_per_sample[:] = seqlen
+        inference_params.sequence_len_offset = seqlen
+        g = torch.cuda.CUDAGraph()
+        # Warmup before capture
+        s = torch.cuda.Stream()
+        s.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(s):
+            for _ in range(2):
+                logits = model(input_ids, position_ids=position_ids,
+                               inference_params=inference_params).logits[:, -1]
+        torch.cuda.current_stream().wait_stream(s)
+        # Captures the graph
+        # To allow capture, automatically sets a side stream as the current stream in the context
+        with torch.cuda.graph(g, pool=memory_pool):
+            logits = model(input_ids, position_ids=position_ids,
+                           inference_params=inference_params).logits[:, -1]
+        if memory_pool is None:
+            memory_pool = g.pool()
+        cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal(g, input_ids, position_ids,
+                                                                inference_params.lengths_per_sample,
+                                                                logits)
+
+    def run(new_input_ids, new_position_ids, seqlen):
+        cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)]
+        inference_params.lengths_per_sample = cg_val.lengths
+        inference_params.lengths_per_sample[:] = seqlen
+        cg_val.input_ids.copy_(new_input_ids)
+        cg_val.position_ids.copy_(new_position_ids)
+        cg_val.graph.replay()
+        output = cg_val.logits
+        return output.clone() if copy_output else output
+
+    inference_params.sequence_len_offset = sequence_length_offset_og
+    return run, cg_cache
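
The capture_cg helper above follows PyTorch's documented stream-capture recipe: warm up on a side stream, capture one forward pass into a torch.cuda.CUDAGraph, then replay it after copying fresh data into the captured input tensors. A self-contained toy version of that pattern, using a plain nn.Linear instead of the GPT model (illustrative only, not part of the commit):

    import torch
    import torch.nn as nn

    model = nn.Linear(128, 128).cuda()
    static_input = torch.zeros(8, 128, device='cuda')   # captured ("static") input buffer

    # Warm up on a side stream so allocator and cuBLAS state are settled before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_output = model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward pass into a CUDA graph.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = model(static_input)

    # Replay: copy new data into the captured input buffer, then replay the graph;
    # static_output is updated in place with the new results.
    static_input.copy_(torch.randn(8, 128, device='cuda'))
    g.replay()
    print(static_output[0, :4])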
tests/models/test_gpt_generation.py  (view file @ b4859900)

@@ -54,8 +54,8 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
     max_length = 30
-    # input_ids = torch.randint(0, 100, (1, 512), dtype=torch.long, device='cuda')
-    # max_length = 512 + 50
+    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
+    # max_length = input_ids.shape[1] + 40

     # Slow generation for reference
     sequences = []
@@ -73,7 +73,13 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     out = model.generate(input_ids=input_ids, max_length=max_length,
                          fused_ft_kernel=fused_ft_kernel,
-                         return_dict_in_generate=True, output_scores=True)
+                         return_dict_in_generate=True, output_scores=True, timing=True)
+    print(out.sequences)
+    if fused_ft_kernel:
+        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                                fused_ft_kernel=fused_ft_kernel, cg=True,
+                                return_dict_in_generate=True, output_scores=True, timing=True)
+        print(out_cg.sequences)

     if not rotary:
         out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
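
Note that the timing output added to this test goes to stdout, so pytest's capture has to be disabled to see it, e.g. `pytest -s tests/models/test_gpt_generation.py -k test_greedy_decode` (the exact node selection depends on the test's parametrization).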