gaoqiong / flash-attention · Commits · 1c9ef9b3

Commit 1c9ef9b3, authored Apr 13, 2023 by Tri Dao
Parent: 6f6e9a9a

[Gen] Measure prompt processing + decoding time, not just decoding

Showing 1 changed file, flash_attn/utils/generation.py, with 6 additions and 2 deletions (+6, -2).
@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                         fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if timing:
+            if tensor_parallel > 1:
+                torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
+        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
     if timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
         torch.cuda.synchronize()
-        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
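The substance of the commit is where the clock starts: the timer (preceded by torch.cuda.synchronize(), and by torch.distributed.barrier() under tensor parallelism) now runs before the first forward pass, so the reported time covers prompt processing plus decoding rather than decoding alone. The synchronize matters because CUDA kernels launch asynchronously, and the barrier keeps tensor-parallel ranks from starting their clocks at different moments. A minimal sketch of that timing pattern, kept runnable on CPU; the `timed` helper and its signature are illustrative and not part of the flash-attention repo:

```python
import time

import torch


def timed(fn, *args, barrier=False, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_ms).

    Synchronizes the GPU around the measurement so the wall clock covers
    completed kernel execution, not just the (asynchronous) kernel launch.
    Illustrative helper, not an API from flash-attention.
    """
    if barrier and torch.distributed.is_available() and torch.distributed.is_initialized():
        # Align all ranks before starting the clock (tensor-parallel case).
        torch.distributed.barrier()
    if torch.cuda.is_available():
        # Drain pending GPU work so earlier kernels aren't billed to fn.
        torch.cuda.synchronize()
    start = time.time()
    result = fn(*args, **kwargs)
    if barrier and torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
    if torch.cuda.is_available():
        # Wait for fn's kernels to actually finish before reading the clock.
        torch.cuda.synchronize()
    return result, (time.time() - start) * 1000.0


# Usage: time one matmul (on GPU this would otherwise only measure launch cost).
out, ms = timed(torch.mm, torch.ones(4, 4), torch.ones(4, 4))
print(f'elapsed: {ms:.0f}ms')
```

Starting the clock before the prompt forward pass, as the diff does, folds the (often dominant, for long prompts) prefill cost into the same measurement window.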