gaoqiong / flash-attention · Commits · e02fd588

Commit e02fd588, authored Jan 07, 2023 by Tri Dao

[Gen] Implement top-k and top-p sampling
parent 11be742a

Showing 2 changed files with 57 additions and 12 deletions:

    flash_attn/utils/generation.py       +57  -10
    tests/models/test_gpt_generation.py   +0   -2
flash_attn/utils/generation.py
@@ -8,7 +8,7 @@ from torch import Tensor
 from einops import rearrange

-from transformers.generation import GreedySearchDecoderOnlyOutput
+from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


 @dataclass
@@ -24,13 +24,58 @@ class InferenceParams:
     lengths_per_sample: Optional[Tensor] = None


-def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
-    """Greedy decoding. This is a very simple implementation.
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
+def modify_logits_for_top_p_filtering(logits, top_p):
+    """Set the logits for non-top-p values to -inf."""
+    if top_p <= 0.0:
+        return
+    # First sort and calculate cumulative sum of probabilities.
+    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+    # Remove tokens whose cumulative probability is at most 1 - top_p
+    # (i.e., the low-probability tail outside the nucleus)
+    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+    # Scatter the sorted mask back to the original indexing
+    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+    logits = logits.masked_fill(indices_to_remove, float('-inf'))
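Two remarks on the added helper. First, because the sort is ascending, the mask removes the low-probability tail whose cumulative mass is at most 1 - top_p, so the surviving nucleus always retains at least top_p of the probability mass. Second, Tensor.masked_fill is out-of-place: the final assignment rebinds only the local name logits, so as committed the filtering is not visible to the caller; the in-place masked_fill_ would be needed for it to take effect (later revisions of this file appear to switch to it). Below is a minimal sketch of the intended behavior using an in-place fill; the helper name top_p_filter_ and the toy tensor are illustrative, not from the commit.

import torch

def top_p_filter_(logits, top_p):
    # Same nucleus-filtering idea as the commit, but filling in place so the
    # caller's tensor is actually modified.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove the low-probability tail: entries whose cumulative mass <= 1 - top_p
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Map the mask from sorted order back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits.masked_fill_(indices_to_remove, float('-inf'))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # probs ~ [0.61, 0.22, 0.14, 0.03]
top_p_filter_(logits, top_p=0.9)
print(logits)  # tensor([[2.0, 1.0, 0.5, -inf]]): only the ~0.03 tail is dropped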
+
+
+def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
+    """Sample from top-k logits.
+    Arguments:
+        logits: Tensor of shape (batch_size, vocab_size)
+    """
+    if top_k == 1:  # Short-circuit for greedy decoding
+        return logits.argmax(dim=-1)
+    else:
+        if top_p > 0.0:
+            assert top_p <= 1.0, 'top-p should be in (0, 1].'
+        if top_k > 0:
+            top_k = min(top_k, logits.size(-1))  # Safety check
+            logits_top, indices = torch.topk(logits, top_k, dim=-1)
+            logits_top /= temperature
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return indices[
+                torch.arange(indices.shape[0], device=indices.device),
+                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
+            ]
+        else:
+            logits_top = logits / temperature
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return torch.multinomial(torch.softmax(logits_top, dim=-1),
+                                     num_samples=1).squeeze(dim=-1)
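Note how the added sample function composes the knobs: logits are first restricted to the top_k candidates, then divided by temperature, and only then top-p filtered, so top_p measures mass on the tempered, truncated distribution. A hypothetical usage sketch (shapes illustrative; assumes sample and torch are in scope as in this file):

import torch

batch_size, vocab_size = 4, 50257
logits = torch.randn(batch_size, vocab_size)

tok_greedy = sample(logits, top_k=1)                    # argmax; shape (batch_size,)
tok_topk   = sample(logits, top_k=50, temperature=0.8)  # sample among the 50 largest logits
tok_topp   = sample(logits, top_k=0, top_p=0.9)         # pure top-p (nucleus) sampling
tok_both   = sample(logits, top_k=50, top_p=0.9)        # top-k first, then top-p within those 50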
+
+
+def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
+           fused_ft_kernel=True):
+    """Decoding, either greedy or with top-k or top-p sampling.
+    If top-k = 0, don't limit the number of candidates (pure sampling).
+    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
+    then top-p.
     We assume that all sequences in the same batch have the same length.
     Arguments:
         input_ids: (batch, seq_len)
         max_length: int
-    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
+    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
         sequences: (batch, max_length)
         scores: tuples of (batch, vocab_size)
     """
@@ -41,7 +86,7 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         scores.append(logits)
-        next_token = logits.argmax(dim=-1)
+        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         sequences = [next_token]
         inference_params.sequence_len_offset = seqlen_og
         while True:
@@ -50,12 +95,13 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
             logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
             scores.append(logits)
-            next_token = logits.argmax(dim=-1)
+            next_token = sample(logits, top_k=top_k, temperature=temperature)
             sequences.append(next_token)
             inference_params.sequence_len_offset += 1
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
-    return GreedySearchDecoderOnlyOutput(
+    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
+    return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
         scores=tuple(scores)
     )
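Note a small asymmetry between the two hunks above: the prompt-step call forwards top_k, top_p, and temperature, while the in-loop call forwards only top_k and temperature. As committed, top-p filtering therefore applies only when sampling the first generated token.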
@@ -63,9 +109,10 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
 class GenerationMixin:

-    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
-                 **kwargs):
-        output = greedy_decode(input_ids, self, max_length, **kwargs)
+    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
+                 return_dict_in_generate=False, output_scores=False, **kwargs):
+        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
+                        temperature=temperature, **kwargs)
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
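With this change the sampling knobs surface through GenerationMixin.generate. A hypothetical usage sketch; model construction is elided (see the repo's tests for loading a GPT-2 checkpoint), and input_ids stands for a (batch, seqlen) LongTensor of prompt tokens:

# Greedy decoding (the previous behavior): top_k defaults to 1.
seqs = model.generate(input_ids, max_length=50)

# Top-k + top-p sampling with temperature; request the score tuple as well.
out = model.generate(input_ids, max_length=50, top_k=40, top_p=0.9, temperature=0.8,
                     return_dict_in_generate=True, output_scores=True)
# out is a SampleDecoderOnlyOutput: out.sequences is (batch, max_length) including
# the prompt, and out.scores is a tuple of per-step (batch, vocab_size) logits.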
tests/models/test_gpt_generation.py
@@ -11,14 +11,12 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt import remap_state_dict_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained
-from flash_attn.utils.generation import greedy_decode


 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
 # @pytest.mark.parametrize('fused_ft_kernel', [False])
-# @pytest.mark.parametrize('optimized', [True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('rotary', [False, True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
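The test change is removal-only (+0 -2): it drops the direct greedy_decode import, which no longer resolves after the rename to decode, and one duplicated commented-out parametrize line. Since no test body lines change, test_greedy_decode presumably already drives decoding through model.generate rather than the standalone function.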