gaoqiong / flash-attention · Commits

Commit 913922ca
authored Sep 04, 2023 by Tri Dao

[Gen] Refactor decoding function

parent 3557e0bb
Showing 8 changed files with 77 additions and 82 deletions
flash_attn/utils/generation.py                 +54  -59
tests/models/test_baichuan.py                   +4   -4
tests/models/test_falcon.py                     +4   -4
tests/models/test_gpt.py                        +5   -5
tests/models/test_gpt_generation_parallel.py    +2   -2
tests/models/test_gptj.py                       +2   -2
tests/models/test_llama.py                      +4   -4
tests/models/test_opt.py                        +2   -2
flash_attn/utils/generation.py (view file @ 913922ca)
...
@@ -84,6 +84,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
     )


+@torch.inference_mode()
 def decode(
     input_ids,
     model,
...
@@ -97,7 +98,7 @@ def decode(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
 ):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
...
@@ -137,73 +138,67 @@ def decode(
             max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
         )

-    def logits_forward_fn(input_ids, position_ids, inference_params):
-        if not cg:
-            return model(
+    def get_logits(input_ids, inference_params):
+        decoding = inference_params.sequence_len_offset > 0
+        if decoding:
+            position_ids = torch.full(
+                (batch_size, 1),
+                inference_params.sequence_len_offset,
+                dtype=torch.long,
+                device=input_ids.device,
+            )
+        else:
+            position_ids = None
+        if not cg or not decoding:
+            logits = model(
                 input_ids,
                 position_ids=position_ids,
                 inference_params=inference_params,
                 num_last_tokens=1,
             ).logits.squeeze(dim=1)
         else:
-            return model._decoding_cache.run(
+            logits = model._decoding_cache.run(
                 input_ids, position_ids, inference_params.sequence_len_offset
             ).clone()
+        return logits[..., :vocab_size] if vocab_size is not None else logits

-    logits_postprocess_fn = (
-        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
-    )
-
-    scores = []
-    with torch.inference_mode():
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            start = time.time()
-        logits = model(
-            input_ids, inference_params=inference_params, num_last_tokens=1
-        ).logits.squeeze(dim=1)
-        logits = logits_postprocess_fn(logits)
-        scores.append(logits if not cg else logits.clone())
-        if teacher_outputs is None or teacher_output_len <= seqlen_og:
-            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+    def sample_tokens(logits, inference_params):
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
-            next_token = teacher_outputs[:, seqlen_og]
-        sequences = [next_token]
-        inference_params.sequence_len_offset = seqlen_og
-        while True:
-            position_ids = torch.full(
-                (batch_size, 1),
-                inference_params.sequence_len_offset,
-                dtype=torch.long,
-                device=input_ids.device,
-            )
-            logits = logits_postprocess_fn(
-                logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
-            )
-            scores.append(logits)
-            if (
-                teacher_outputs is None
-                or teacher_output_len <= inference_params.sequence_len_offset + 1
-            ):
-                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-            else:
-                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
-            sequences.append(next_token)
-            inference_params.sequence_len_offset += 1
-            if eos_token_id is not None and (next_token == eos_token_id).all():
-                break
-            if inference_params.sequence_len_offset >= max_length - 1:
-                break
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
+            token = teacher_outputs[:, inference_params.sequence_len_offset]
+        return rearrange(token, "b -> b 1")
+
+    def should_stop(current_token, inference_params):
+        if inference_params.sequence_len_offset == 0:
+            return False
+        if eos_token_id is not None and (current_token == eos_token_id).all():
+            return True
+        if inference_params.sequence_len_offset >= max_length - 1:
+            return True
+        return False
+
+    start = torch.cuda.Event(enable_timing=enable_timing)
+    end = torch.cuda.Event(enable_timing=enable_timing)
+    if enable_timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        start.record()
+    scores, sequences = [], [input_ids]
+    while not should_stop(sequences[-1], inference_params):
+        scores.append(get_logits(sequences[-1], inference_params))
+        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        sequences.append(sample_tokens(scores[-1], inference_params))
+    if enable_timing:
+        end.record()
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        torch.cuda.synchronize()
+        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
-        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
+        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
     )
...
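Note: the hunk above replaces the old structure (an explicit prompt-processing step, a logits_forward_fn helper, and a while True loop with inline break conditions) with three small helpers, get_logits, sample_tokens, and should_stop, driven by a single loop over sequences. The following is a minimal, runnable sketch of that control flow only, with a random-logits stand-in for the model; DummyInferenceParams and toy_decode are illustrative names, not part of this commit.

# Minimal sketch of the refactored decoding loop (not the library code itself):
# three helpers -- get_logits, sample_tokens, should_stop -- drive one while-loop.
# The model forward pass is a stand-in so the control flow is runnable on its own.
import torch


class DummyInferenceParams:
    """Placeholder for the key/value-cache bookkeeping object (illustrative)."""
    def __init__(self):
        self.sequence_len_offset = 0


def toy_decode(input_ids, vocab_size=16, max_length=8, eos_token_id=None):
    batch_size = input_ids.shape[0]
    inference_params = DummyInferenceParams()

    def get_logits(ids, params):
        # Stand-in for the model forward pass: random logits for the last position.
        return torch.randn(batch_size, vocab_size)

    def sample_tokens(logits, params):
        # Greedy sampling; the real function also supports top-k/top-p and teacher forcing.
        token = logits.argmax(dim=-1)
        return token.unsqueeze(1)  # shape (batch, 1), like rearrange(token, "b -> b 1")

    def should_stop(current_token, params):
        if params.sequence_len_offset == 0:
            return False  # always process the prompt first
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        return params.sequence_len_offset >= max_length - 1

    scores, sequences = [], [input_ids]
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        # The first iteration advances by the prompt length, later ones by 1.
        inference_params.sequence_len_offset += sequences[-1].shape[1]
        sequences.append(sample_tokens(scores[-1], inference_params))
    return torch.cat(sequences, dim=1), scores


if __name__ == "__main__":
    out, scores = toy_decode(torch.zeros(2, 3, dtype=torch.long))
    print(out.shape)  # (2, prompt_len + generated tokens), capped at max_length

Because the first loop iteration consumes the whole prompt (sequence_len_offset advances by the prompt length), prompt processing and decoding share one code path, which is what lets the CUDA-graph branch live entirely inside get_logits in the diff above.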
@@ -280,7 +275,7 @@ def decode_speculative(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
     debug=False,
 ):
     """
...
@@ -446,7 +441,7 @@ def decode_speculative(
     sequences = [input_ids]
     scores = []
     with torch.inference_mode():
-        if timing:
+        if enable_timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
...
@@ -566,7 +561,7 @@ def decode_speculative(
         ).logits
         print((scores[-1] - scores_ref[:, :-1]).abs().max())
-        if timing:
+        if enable_timing:
            if tensor_parallel > 1:
                torch.distributed.barrier()
            torch.cuda.synchronize()
...
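Note: besides the flag rename, decode switches its timing from time.time() bracketed by torch.cuda.synchronize() to torch.cuda.Event records around the decoding loop, with a single synchronize before reading the elapsed time. A small standalone sketch of that measurement pattern, assuming a CUDA device is available; the matmul is an arbitrary stand-in for the timed work, not the decoder.

import torch

# Sketch of the CUDA-event timing pattern the commit adopts (illustrative only).
# Events are recorded on the current stream; elapsed_time() returns milliseconds
# and needs a synchronize so the end event has actually completed.
def timed_matmul(n=2048, enable_timing=True):
    assert torch.cuda.is_available(), "this sketch assumes a CUDA device"
    a = torch.randn(n, n, device="cuda")
    b = torch.randn(n, n, device="cuda")
    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)
    if enable_timing:
        start.record()
    c = a @ b  # stand-in for prompt processing + decoding
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"elapsed: {start.elapsed_time(end):.0f}ms")
    return c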
tests/models/test_baichuan.py (view file @ 913922ca)

...
@@ -289,7 +289,7 @@ def test_baichuan_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -310,7 +310,7 @@ def test_baichuan_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -400,7 +400,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     # Capture graph outside the timing loop
...
@@ -419,7 +419,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
...
tests/models/test_falcon.py (view file @ 913922ca)

...
@@ -245,7 +245,7 @@ def test_falcon_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -264,7 +264,7 @@ def test_falcon_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -351,7 +351,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     # Capture graph outside the timing loop
...
@@ -368,7 +368,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
...
tests/models/test_gpt.py (view file @ 913922ca)

...
@@ -200,7 +200,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
...
@@ -212,7 +212,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out_cg.sequences)
...
@@ -267,7 +267,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
         teacher_outputs=teacher_outputs,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         **kwargs,
     )
     return torch.stack(out.scores, dim=1)
...
@@ -431,7 +431,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         fused_ft_kernel=fused_ft_kernel,
         cg=cg,
         speculative_lookahead=4,
-        timing=True,
+        enable_timing=True,
     )
     print(tokenizer.batch_decode(out.sequences))
     out_og = model.generate(
...
@@ -440,7 +440,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         top_k=5,
         fused_ft_kernel=fused_ft_kernel,
         cg=False,
-        timing=True,
+        enable_timing=True,
         return_dict_in_generate=True,
     )
     print(tokenizer.batch_decode(out_og.sequences))
...
tests/models/test_gpt_generation_parallel.py (view file @ 913922ca)

...
@@ -114,7 +114,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     if fused_ft_kernel:
...
@@ -127,7 +127,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out_cg.sequences)
...
tests/models/test_gptj.py (view file @ 913922ca)

...
@@ -144,7 +144,7 @@ def test_gptj_generation(model_name):
         # eos_token_id=eos_token_id, fused_ft_kernel=False,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -163,7 +163,7 @@ def test_gptj_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
tests/models/test_llama.py (view file @ 913922ca)

...
@@ -295,7 +295,7 @@ def test_llama_generation(model_name, checkpoint_format):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -314,7 +314,7 @@ def test_llama_generation(model_name, checkpoint_format):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -403,7 +403,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     # Capture graph outside the timing loop
...
@@ -420,7 +420,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
...
tests/models/test_opt.py (view file @ 913922ca)

...
@@ -158,7 +158,7 @@ def test_opt_generation(model_name):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
...
@@ -179,7 +179,7 @@ def test_opt_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
...
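Note: every test change above is the same call-site rename, timing=True becoming enable_timing=True in model.generate(...). A hedged usage sketch of the updated call, assuming a flash-attention model and tokenizer are already loaded the way these tests set them up; the prompt text and variable names are illustrative, not part of the commit.

# Call-site change reflected by the test diffs above (sketch; model/tokenizer
# loading follows the tests' setup and is assumed, not shown here).
input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to("cuda")

out = model.generate(
    input_ids=input_ids,
    max_length=30,
    fused_ft_kernel=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,  # previously: timing=True
)
print(tokenizer.batch_decode(out.sequences.tolist()))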