Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
3da42d24
Commit
3da42d24
authored
Apr 20, 2023
by
Tri Dao
Browse files
[GPT] Add option to only return the logit for the last token
parent
311d6606
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
8 deletions
+12
-8
flash_attn/models/gpt.py
flash_attn/models/gpt.py
+6
-2
flash_attn/utils/generation.py
flash_attn/utils/generation.py
+6
-6
No files found.
flash_attn/models/gpt.py
View file @
3da42d24
...
...
@@ -426,20 +426,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
if
self
.
process_group
is
not
None
:
sync_shared_params
(
self
,
self
.
process_group
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
inference_params
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
inference_params
=
None
,
last_token_only
=
False
):
"""
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
last_token_only: whether to return the logit for the last token only,
of shape (batch_size, vocab_size)
"""
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
=
position_ids
,
inference_params
=
inference_params
)
if
last_token_only
:
hidden_states
=
hidden_states
[:,
-
1
]
if
self
.
project_out
is
not
None
:
hidden_states
=
self
.
project_out
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
# During inference, we want the full logit for sampling
if
isinstance
(
self
.
lm_head
,
ColumnParallelLinear
)
and
inference_params
is
not
None
:
lm_logits
,
_
=
all_gather_raw
(
lm_logits
,
self
.
lm_head
.
process_group
)
lm_logits
=
rearrange
(
lm_logits
,
'(n b)
s
d -> b
s
(n d)'
,
b
=
hidden_states
.
shape
[
0
])
lm_logits
=
rearrange
(
lm_logits
,
'(n b)
...
d -> b
...
(n d)'
,
b
=
hidden_states
.
shape
[
0
])
CausalLMOutput
=
namedtuple
(
'CausalLMOutput'
,
[
'logits'
])
return
CausalLMOutput
(
logits
=
lm_logits
)
...
...
flash_attn/utils/generation.py
View file @
3da42d24
...
...
@@ -112,7 +112,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
torch
.
distributed
.
barrier
()
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
logits
=
model
(
input_ids
,
inference_params
=
inference_params
).
logits
[:,
-
1
]
logits
=
model
(
input_ids
,
inference_params
=
inference_params
,
last_token_only
=
True
).
logits
if
vocab_size
is
not
None
:
logits
=
logits
[...,
:
vocab_size
]
scores
.
append
(
logits
if
not
cg
else
logits
.
clone
())
...
...
@@ -127,7 +127,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
if
not
cg
:
logits
=
model
(
rearrange
(
next_token
,
'b -> b 1'
),
position_ids
=
position_ids
,
inference_params
=
inference_params
).
logits
[:,
-
1
]
inference_params
=
inference_params
,
last_token_only
=
True
).
logits
else
:
logits
=
model
.
_decoding_cache
.
run
(
rearrange
(
next_token
,
'b -> b 1'
),
position_ids
,
inference_params
.
sequence_len_offset
)
...
...
@@ -269,8 +269,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
s
.
wait_stream
(
torch
.
cuda
.
current_stream
())
with
torch
.
cuda
.
stream
(
s
):
for
_
in
range
(
n_warmups
):
logits
=
model
(
input_ids
,
position_ids
=
position_ids
,
inference_params
=
inference_params
).
logits
[:,
-
1
]
logits
=
model
(
input_ids
,
position_ids
=
position_ids
,
inference_params
=
inference_params
,
last_token_only
=
True
).
logits
s
.
synchronize
()
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
# which requires that graph launch and non-captured launch to not overlap (I think,
...
...
@@ -282,8 +282,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
# To allow capture, automatically sets a side stream as the current stream in the context
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
pool
=
mempool
):
logits
=
model
(
input_ids
,
position_ids
=
position_ids
,
inference_params
=
inference_params
).
logits
[:,
-
1
]
logits
=
model
(
input_ids
,
position_ids
=
position_ids
,
inference_params
=
inference_params
,
last_token_only
=
True
).
logits
def
run
(
new_input_ids
,
new_position_ids
,
seqlen
):
inference_params
.
lengths_per_sample
[:]
=
seqlen
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment