Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9e5c28c5
Unverified
Commit
9e5c28c5
authored
Dec 14, 2023
by
Joao Gante
Committed by
GitHub
Dec 14, 2023
Browse files
Generate: assisted decoding now uses `generate` for the assistant (#28030)
generate refactor
parent
dde6c427
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
71 deletions
+67
-71
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+27
-51
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+1
-1
tests/generation/test_utils.py
tests/generation/test_utils.py
+39
-19
No files found.
src/transformers/generation/candidate_generator.py
View file @
9e5c28c5
...
...
@@ -15,7 +15,7 @@
import
copy
import
warnings
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class
CandidateGenerator
:
"""Abstract base class for all candidate generators that can be applied during assisted generation."""
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
torch
.
LongTensor
:
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
Tuple
[
torch
.
LongTensor
,
Optional
[
torch
.
FloatTensor
]]
:
"""
Fetches the candidates to be tried for the current input.
...
...
@@ -37,8 +37,9 @@ class CandidateGenerator:
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
the model.
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
raise
NotImplementedError
(
f
"
{
self
.
__class__
}
is an abstract class. Only classes inheriting this class can call `get_candidates`."
...
...
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
)
self
.
logits_processor
=
logits_processor
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
torch
.
LongTensor
:
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
Tuple
[
torch
.
LongTensor
,
Optional
[
torch
.
FloatTensor
]]
:
"""
Fetches the candidates to be tried for the current input.
...
...
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
...
...
@@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator):
)
self
.
assistant_kwargs
=
_prepare_token_type_ids
(
self
.
assistant_kwargs
,
new_cur_len
)
# 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
# call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
# assistant cache to secure strong speedups.
candidate_input_ids
=
input_ids
for
_
in
range
(
int
(
self
.
num_assistant_tokens
)):
# 2.1 prepare assistant model inputs
assistant_inputs
=
self
.
assistant_model
.
prepare_inputs_for_generation
(
candidate_input_ids
,
**
self
.
assistant_kwargs
,
)
# 2.2. check if the input ids length is correct
has_past_key_values
=
assistant_inputs
.
get
(
"past_key_values"
,
None
)
is
not
None
if
has_past_key_values
and
assistant_inputs
[
self
.
input_ids_key
].
shape
[
-
1
]
not
in
(
1
,
2
):
raise
ValueError
(
"The length of the input ids in assistant inputs should be 1 or 2"
)
# 2.3. use the assistant model to obtain the next candidate logits
assistant_model_outputs
=
self
.
assistant_model
(
**
assistant_inputs
)
# 2.4. greedily select the next candidate token
if
len
(
self
.
logits_processor
)
>
0
:
assistant_model_outputs
.
logits
[:,
-
1
,
:]
=
self
.
logits_processor
(
candidate_input_ids
,
assistant_model_outputs
.
logits
[:,
-
1
,
:]
)
new_token
=
assistant_model_outputs
.
logits
[:,
-
1
,
:].
argmax
(
dim
=-
1
)
candidate_input_ids
=
torch
.
cat
((
candidate_input_ids
,
new_token
[:,
None
]),
dim
=-
1
)
# 2.5. update assistant model inputs
if
self
.
assistant_kwargs
.
get
(
self
.
attention_key
,
None
)
is
not
None
:
mask
=
self
.
assistant_kwargs
[
self
.
attention_key
]
self
.
assistant_kwargs
[
self
.
attention_key
]
=
torch
.
cat
(
[
mask
,
mask
.
new_ones
((
mask
.
shape
[
0
],
1
))],
dim
=-
1
)
self
.
assistant_kwargs
[
"past_key_values"
]
=
assistant_model_outputs
.
past_key_values
# 2.6. stop assistant generation on EOS
if
self
.
eos_token_id_tensor
is
not
None
:
last_assistant_token_is_eos
=
new_token
.
tile
(
self
.
eos_token_id_tensor
.
shape
[
0
],
1
)
last_assistant_token_is_eos
=
(
~
last_assistant_token_is_eos
.
ne
(
self
.
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
).
bool
()
)
if
last_assistant_token_is_eos
:
break
return
candidate_input_ids
# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs
=
{
self
.
input_ids_key
:
input_ids
,
"do_sample"
:
False
,
"num_beams"
:
1
,
"max_new_tokens"
:
int
(
self
.
num_assistant_tokens
),
"return_dict_in_generate"
:
True
,
"output_scores"
:
True
,
}
assistant_output
=
self
.
assistant_model
.
generate
(
**
assistant_generation_kwargs
,
**
self
.
assistant_kwargs
)
# 3. Update variables for the next round of candidate generation
self
.
assistant_kwargs
[
"past_key_values"
]
=
assistant_output
.
past_key_values
# 4. Prepare variables for output
candidate_logits
=
torch
.
stack
(
assistant_output
.
scores
,
dim
=
1
)
candidate_ids
=
assistant_output
.
sequences
return
candidate_ids
,
candidate_logits
def
update_candidate_strategy
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
num_matches
:
int
):
"""
...
...
src/transformers/generation/utils.py
View file @
9e5c28c5
...
...
@@ -4585,7 +4585,7 @@ class GenerationMixin:
cur_len
=
input_ids
.
shape
[
-
1
]
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids
=
candidate_generator
.
get_candidates
(
input_ids
)
candidate_input_ids
,
candidate_logits
=
candidate_generator
.
get_candidates
(
input_ids
)
candidate_length
=
candidate_input_ids
.
shape
[
1
]
-
input_ids
.
shape
[
1
]
last_assistant_token_is_eos
=
(
~
candidate_input_ids
[:,
-
1
]
...
...
tests/generation/test_utils.py
View file @
9e5c28c5
...
...
@@ -3128,21 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs_assisted
.
tolist
(),
outputs_tti
.
tolist
())
def
test_model_kwarg_assisted_decoding_encoder_decoder
(
self
):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. encoder-decoder assistant model
3. both have a custom input
(e.g. Whisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
class
FakeBart
(
BartForConditionalGeneration
):
def
forward
(
self
,
input_ids
,
foo
=
False
,
**
kwargs
):
outs
=
super
().
forward
(
input_ids
,
**
kwargs
)
def
forward
(
self
,
input_ids
,
past_key_values
,
foo
=
False
,
**
kwargs
):
outs
=
super
().
forward
(
input_ids
,
past_key_values
=
past_key_values
,
**
kwargs
)
if
foo
:
outs
[
"logits"
][:,
:,
:]
=
0.0
return
outs
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
[
"foo"
]
=
foo
return
inputs
...
...
@@ -3160,17 +3165,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertEqual
(
outputs_normal
.
shape
,
(
1
,
20
))
# Should be different with foo
outputs_foo
=
model
.
generate
(
input_ids
,
foo
=
True
,
)
outputs_foo
=
model
.
generate
(
input_ids
,
foo
=
True
)
with
self
.
assertRaises
(
AssertionError
):
self
.
assertListEqual
(
outputs_foo
.
tolist
(),
outputs_normal
.
tolist
())
# Assistant model
assistant
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
)
.
to
(
torch_device
)
assistant
=
FakeBart
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
torch_device
)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted
=
model
.
generate
(
...
...
@@ -3192,25 +3194,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs_assisted
.
tolist
(),
outputs_foo
.
tolist
())
def
test_assisted_decoding_encoder_decoder_shared_encoder
(
self
):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. decoder-only assistant model
3. both have a custom input
(e.g. DistilWhisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
class
FakeBart
(
BartForConditionalGeneration
):
class
FakeBart
Seq2Seq
(
BartForConditionalGeneration
):
def
forward
(
self
,
input_ids
,
foo
=
False
,
**
kwargs
):
outs
=
super
().
forward
(
input_ids
,
**
kwargs
)
if
foo
:
outs
[
"logits"
][:,
:,
:]
=
0.0
return
outs
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
[
"foo"
]
=
foo
return
inputs
class
FakeBartCausalLM
(
BartForCausalLM
):
def
forward
(
self
,
input_ids
,
attention_mask
,
past_key_values
,
foo
=
False
,
**
kwargs
):
outs
=
super
().
forward
(
input_ids
,
attention_mask
,
past_key_values
=
past_key_values
,
**
kwargs
)
if
foo
:
outs
[
"logits"
][:,
:,
:]
=
0.0
return
outs
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
[
"foo"
]
=
foo
return
inputs
model
=
FakeBart
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
model
=
FakeBart
Seq2Seq
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
torch_device
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
)
...
...
@@ -3229,9 +3249,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs_foo
.
tolist
(),
outputs_normal
.
tolist
())
# Assistant model
assistant
=
Bart
For
CausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
torch_device
)
assistant
=
Fake
BartCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
)
.
to
(
torch_device
)
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted
=
model
.
generate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment