flash-attention · commit f8aea6ea
Authored Aug 26, 2023 by Tri Dao
Parent: 7a3bd55f

[GPT] Generalize last_token_only arg to num_last_tokens
Showing 3 changed files, with 28 additions and 19 deletions (+28 -19):

  flash_attn/models/gpt.py        +4   -5
  flash_attn/utils/generation.py  +24  -12
  tests/models/test_gpt.py        +0   -2
flash_attn/models/gpt.py
@@ -621,18 +621,17 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             batch_size, max_seqlen, dtype=dtype, **kwargs
         )

-    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
+    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
         https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
-        last_token_only: whether to return the logit for the last token only,
-            of shape (batch_size, vocab_size)
+        num_last_tokens: if > 0, only return the logits for the last n tokens
         """
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        if last_token_only:
-            hidden_states = hidden_states[:, -1]
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
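The shape contract behind the rename is easy to check in isolation. The following sketch uses plain tensors only, not the repo's model class (the names batch, seqlen, vocab and logits_full are illustrative); it mimics the hidden_states[:, -num_last_tokens:] slice from the hunk above and shows why call sites that previously passed last_token_only=True now pass num_last_tokens=1 and squeeze the length-1 dimension.

import torch

# Emulate what forward() now does before the LM head, on random stand-in data.
batch, seqlen, vocab = 2, 16, 50257
logits_full = torch.randn(batch, seqlen, vocab)   # stand-in for lm_head output

# old: last_token_only=True  -> logits_full[:, -1]    shape (batch, vocab)
# new: num_last_tokens=n     -> logits_full[:, -n:]   shape (batch, n, vocab)
num_last_tokens = 1
logits_tail = logits_full[:, -num_last_tokens:]
assert logits_tail.shape == (batch, num_last_tokens, vocab)

# Callers that relied on last_token_only=True now squeeze the seqlen dimension,
# as the updated call sites in generation.py below do.
logits_last = logits_tail.squeeze(dim=1)
assert logits_last.shape == (batch, vocab)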
flash_attn/utils/generation.py
@@ -27,11 +27,19 @@ class InferenceParams:
     lengths_per_sample: Optional[Tensor] = None


+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
+def modify_logits_for_top_k_filtering(logits, top_k):
+    """Set the logits for none top-k values to -inf."""
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
 def modify_logits_for_top_p_filtering(logits, top_p):
     """Set the logits for none top-p values to -inf."""
-    if top_p <= 0.0:
+    if top_p <= 0.0 or top_p >= 1.0:
         return
     # First sort and calculate cumulative sum of probabilities.
     sorted_logits, sorted_indices = torch.sort(logits, descending=False)
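The top-k helper added above can be exercised on its own. The sketch below simply repeats its two-line body on a toy batch of logits (nothing is imported from the package) and shows the in-place masking to -inf; note also that the top-p helper's early return now treats both top_p <= 0.0 and top_p >= 1.0 as no-ops.

import torch

# Keep only the top-k logits per row; everything else gets -inf so softmax
# assigns it zero probability mass. Mirrors the helper added in this commit.
def modify_logits_for_top_k_filtering(logits, top_k):
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
modify_logits_for_top_k_filtering(logits, top_k=2)   # modifies logits in place
print(logits)                         # tensor([[2., -inf, 1., -inf]])
print(torch.softmax(logits, dim=-1))  # mass only on the two surviving entries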
@@ -58,14 +66,16 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
     if top_k > 0:
         top_k = min(top_k, logits.size(-1))  # Safety check
         logits_top, indices = torch.topk(logits, top_k, dim=-1)
-        logits_top /= temperature
+        if temperature != 1.0:
+            logits_top /= temperature
         modify_logits_for_top_p_filtering(logits_top, top_p)
         return indices[
             torch.arange(indices.shape[0], device=indices.device),
             torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
         ]
     else:
-        logits_top = logits / temperature
+        # Clone so that when we modify for top_p we don't change the original logits
+        logits_top = logits / temperature if temperature != 1.0 else logits.clone()
         modify_logits_for_top_p_filtering(logits_top, top_p)
         return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
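To see how the pieces of sample() fit together after this change, here is a hedged, self-contained re-implementation of its top-k branch (sample_sketch is an illustrative name, not the repo's function, and the top-p step is left as a placeholder): top-k selection, temperature scaling now skipped when temperature == 1.0, softmax, a multinomial draw, then a gather through indices to map the chosen candidate position back to a vocabulary id.

import torch

def sample_sketch(logits, top_k=1, temperature=1.0):
    top_k = min(top_k, logits.size(-1))                       # safety check
    logits_top, indices = torch.topk(logits, top_k, dim=-1)   # candidate logits + vocab ids
    if temperature != 1.0:
        logits_top /= temperature
    # (top-p filtering of logits_top would go here, as in the real sample())
    probs = torch.softmax(logits_top, dim=-1)
    choice = torch.multinomial(probs, num_samples=1).squeeze(dim=-1)
    # map the position chosen among the top-k candidates back to a vocabulary id
    return indices[torch.arange(indices.shape[0], device=indices.device), choice]

logits = torch.randn(4, 50257)                     # (batch, vocab)
next_token = sample_sketch(logits, top_k=10, temperature=0.8)
print(next_token.shape)                            # torch.Size([4])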
@@ -131,8 +141,8 @@ def decode(
             input_ids,
             position_ids=position_ids,
             inference_params=inference_params,
-            last_token_only=True,
-        ).logits
+            num_last_tokens=1,
+        ).logits.squeeze(dim=1)
     else:
         return model._decoding_cache.run(
             input_ids, position_ids, inference_params.sequence_len_offset
@@ -149,7 +159,9 @@ def decode(
             torch.distributed.barrier()
         torch.cuda.synchronize()
         start = time.time()
-        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
+        logits = model(
+            input_ids, inference_params=inference_params, num_last_tokens=1
+        ).logits.squeeze(dim=1)
         logits = logits_postprocess_fn(logits)
         scores.append(logits if not cg else logits.clone())
     if teacher_outputs is None or teacher_output_len <= seqlen_og:
@@ -165,9 +177,9 @@ def decode(
                 dtype=torch.long,
                 device=input_ids.device,
             )
-        logits = logits_postprocess_fn(
-            logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
-        )
+        logits = logits_postprocess_fn(
+            logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
+        )
         scores.append(logits)
         if (teacher_outputs is None
@@ -357,7 +369,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
             input_ids,
             position_ids=position_ids,
            inference_params=inference_params,
-            last_token_only=True,
+            num_last_tokens=1,
         ).logits
     s.synchronize()
     # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
@@ -374,8 +386,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
             input_ids,
             position_ids=position_ids,
             inference_params=inference_params,
-            last_token_only=True,
-        ).logits
+            num_last_tokens=1,
+        ).logits.squeeze(dim=1)

    def run(new_input_ids, new_position_ids, seqlen):
        inference_params.lengths_per_sample[:] = seqlen
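The two capture_graph hunks sit inside CUDA-graph capture, which is why those call sites keep the same num_last_tokens=1 shape contract as the eager path. Below is a hedged, generic sketch of the capture/replay pattern involved, using a toy module and the standard PyTorch CUDA-graph API; it is not the repo's capture_graph and it requires a CUDA device. It also illustrates why decode() appends logits.clone() when running with graphs: the graph's output buffer is overwritten on every replay.

import torch

device = "cuda"
model = torch.nn.Linear(64, 64, device=device)
static_input = torch.randn(8, 64, device=device)   # fixed buffer reused by the graph

# Warm up on a side stream before capture (mirrors the s.synchronize() step above).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)             # output buffer owned by the graph

# Replay with new data: copy into the static buffer, then replay. The output buffer
# is rewritten in place on each replay, hence the clone() before keeping a result.
static_input.copy_(torch.randn(8, 64, device=device))
graph.replay()
result = static_output.clone()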
tests/models/test_gpt.py
@@ -355,8 +355,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
         config.fused_dropout_add_ln = True
-    # fused_ft_kernel currently doesn't work with multiple tokens at a time
     # if not rotary, we load the weight from HF but ignore the position embeddings.
     # The model would be nonsense but it doesn't matter for the test.
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
     model.eval()