Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
f95c2fc1
"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "483201d484f6f0842580a2ea04a1a3ab96e1c6f5"
Commit
f95c2fc1
authored
Jan 07, 2023
by
Tri Dao
Browse files
[Gen] Remove commented code
parent
b4859900
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
3 deletions
+0
-3
flash_attn/utils/generation.py
flash_attn/utils/generation.py
+0
-3
No files found.
flash_attn/utils/generation.py
View file @
f95c2fc1
...
@@ -97,7 +97,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
...
@@ -97,7 +97,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if
cg
:
if
cg
:
assert
fused_ft_kernel
assert
fused_ft_kernel
run
,
cg_cache
=
capture_cg
(
model
,
inference_params
,
batch_size
,
seqlen_og
,
max_length
)
run
,
cg_cache
=
capture_cg
(
model
,
inference_params
,
batch_size
,
seqlen_og
,
max_length
)
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
if
timing
:
if
timing
:
start
=
time
.
time
()
start
=
time
.
time
()
while
True
:
while
True
:
...
@@ -117,8 +116,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
...
@@ -117,8 +116,6 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
break
break
if
timing
:
if
timing
:
print
(
f
'Decoding time:
{
time
.
time
()
-
start
}
'
)
print
(
f
'Decoding time:
{
time
.
time
()
-
start
}
'
)
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))
# prof.export_chrome_trace("gpt2s_generation.json")
output_cls
=
GreedySearchDecoderOnlyOutput
if
top_k
==
1
else
SampleDecoderOnlyOutput
output_cls
=
GreedySearchDecoderOnlyOutput
if
top_k
==
1
else
SampleDecoderOnlyOutput
return
output_cls
(
return
output_cls
(
sequences
=
torch
.
cat
([
input_ids
,
torch
.
stack
(
sequences
,
dim
=
1
)],
dim
=
1
),
sequences
=
torch
.
cat
([
input_ids
,
torch
.
stack
(
sequences
,
dim
=
1
)],
dim
=
1
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment