chenpangpang / transformers / Commits

Unverified commit 9e5c28c5, authored Dec 14, 2023 by Joao Gante, committed by GitHub on Dec 14, 2023

Generate: assisted decoding now uses `generate` for the assistant (#28030)

generate refactor

Parent: dde6c427
Changes: 3 changed files with 67 additions and 71 deletions

- src/transformers/generation/candidate_generator.py (+27, -51)
- src/transformers/generation/utils.py (+1, -1)
- tests/generation/test_utils.py (+39, -19)
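For context: assisted (speculative) decoding pairs the main model with a smaller assistant that drafts candidate tokens, which the main model then verifies in one forward pass. A minimal usage sketch of the feature this commit refactors; the checkpoint names are illustrative placeholders, not part of the commit:

```python
# Minimal sketch of assisted generation with a draft model (illustrative
# checkpoints; any compatible pair sharing a tokenizer works).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")  # smaller draft model

inputs = tokenizer("Assisted decoding works by", return_tensors="pt")
# Passing `assistant_model` enables assisted decoding; after this commit the
# assistant drafts its candidates through its own `generate` call.
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```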
src/transformers/generation/candidate_generator.py
@@ -15,7 +15,7 @@
 import copy
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
 class CandidateGenerator:
     """Abstract base class for all candidate generators that can be applied during assisted generation."""

-    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
@@ -37,8 +37,9 @@ class CandidateGenerator:
                 Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

         Return:
-            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
-            the model.
+            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
+            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
+            vocabulary_size)` containing the logits associated to each candidate.
         """
         raise NotImplementedError(
             f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
         )
         self.logits_processor = logits_processor

-    def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
                 Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

         Return:
-            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
+            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
+            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
+            vocabulary_size)` containing the logits associated to each candidate.
         """
         # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
         # (which implicitly contains the number of accepted candidates from the previous round)
@@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator):
             )
             self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

-        # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
-        # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
-        # assistant cache to secure strong speedups.
-        candidate_input_ids = input_ids
-        for _ in range(int(self.num_assistant_tokens)):
-            # 2.1 prepare assistant model inputs
-            assistant_inputs = self.assistant_model.prepare_inputs_for_generation(
-                candidate_input_ids,
-                **self.assistant_kwargs,
-            )
-
-            # 2.2. check if the input ids length is correct
-            has_past_key_values = assistant_inputs.get("past_key_values", None) is not None
-            if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2):
-                raise ValueError("The length of the input ids in assistant inputs should be 1 or 2")
-
-            # 2.3. use the assistant model to obtain the next candidate logits
-            assistant_model_outputs = self.assistant_model(**assistant_inputs)
-
-            # 2.4. greedily select the next candidate token
-            if len(self.logits_processor) > 0:
-                assistant_model_outputs.logits[:, -1, :] = self.logits_processor(
-                    candidate_input_ids, assistant_model_outputs.logits[:, -1, :]
-                )
-            new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1)
-            candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1)
-
-            # 2.5. update assistant model inputs
-            if self.assistant_kwargs.get(self.attention_key, None) is not None:
-                mask = self.assistant_kwargs[self.attention_key]
-                self.assistant_kwargs[self.attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
-            self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values
-
-            # 2.6. stop assistant generation on EOS
-            if self.eos_token_id_tensor is not None:
-                last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1)
-                last_assistant_token_is_eos = (
-                    ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool()
-                )
-                if last_assistant_token_is_eos:
-                    break
-
-        return candidate_input_ids
+        # 2. Forecast next N tokens using the assistant model.
+        assistant_generation_kwargs = {
+            self.input_ids_key: input_ids,
+            "do_sample": False,
+            "num_beams": 1,
+            "max_new_tokens": int(self.num_assistant_tokens),
+            "return_dict_in_generate": True,
+            "output_scores": True,
+        }
+        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
+
+        # 3. Update variables for the next round of candidate generation
+        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
+
+        # 4. Prepare variables for output
+        candidate_logits = torch.stack(assistant_output.scores, dim=1)
+        candidate_ids = assistant_output.sequences
+        return candidate_ids, candidate_logits

     def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
         """
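One non-obvious detail in the new return value: `generate` returns `scores` as a tuple holding one `(batch_size, vocab_size)` tensor per generated token, so stacking along `dim=1` yields exactly the `(batch_size, candidate_length, vocabulary_size)` logits tensor the updated docstring promises. A toy illustration:

```python
# Toy shapes only; mirrors the `torch.stack(assistant_output.scores, dim=1)`
# line in the diff above.
import torch

batch_size, vocab_size, candidate_length = 2, 10, 5
# `generate` returns one (batch_size, vocab_size) score tensor per step
scores = tuple(torch.randn(batch_size, vocab_size) for _ in range(candidate_length))

candidate_logits = torch.stack(scores, dim=1)
assert candidate_logits.shape == (batch_size, candidate_length, vocab_size)
```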
src/transformers/generation/utils.py
@@ -4585,7 +4585,7 @@ class GenerationMixin:
         cur_len = input_ids.shape[-1]

         #  1. Fetch candidate sequences from a `CandidateGenerator`
-        candidate_input_ids = candidate_generator.get_candidates(input_ids)
+        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
         candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
         last_assistant_token_is_eos = (
             ~candidate_input_ids[:, -1]
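For orientation, the code around this hunk verifies the drafted tokens with one forward pass of the main model and keeps the longest matching prefix. A simplified, greedy-only sketch of that acceptance rule (an illustration of the idea, not the library's exact implementation, which also handles sampling and logits processors):

```python
import torch

def count_accepted(candidate_new_tokens: torch.LongTensor, main_logits: torch.FloatTensor) -> int:
    """Keep drafted tokens while they match the main model's greedy pick at
    the same position; the first mismatch stops acceptance."""
    selected = main_logits.argmax(dim=-1)                           # (1, candidate_length)
    matches = (candidate_new_tokens == selected).int().cumprod(dim=-1)
    return int(matches.sum())

drafted = torch.tensor([[5, 7, 9]])     # 3 tokens proposed by the assistant
logits = torch.zeros(1, 3, 10)          # one main-model forward over them
logits[0, 0, 5] = 1.0                   # main model agrees at position 0
logits[0, 1, 2] = 1.0                   # disagrees at position 1
logits[0, 2, 9] = 1.0                   # agreement here no longer counts
print(count_accepted(drafted, logits))  # 1
```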
tests/generation/test_utils.py
@@ -3128,21 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())

     def test_model_kwarg_assisted_decoding_encoder_decoder(self):
         """
         Tests that the following scenario is compatible with assisted generation:
         1. encoder-decoder main model
         2. encoder-decoder assistant model
         3. both have a custom input
         (e.g. Whisper)
         """
         # PT-only test: TF doesn't support assisted decoding yet.
         # Bart subclass with a kwarg that distorts the output
         class FakeBart(BartForConditionalGeneration):
-            def forward(self, input_ids, foo=False, **kwargs):
-                outs = super().forward(input_ids, **kwargs)
+            def forward(self, input_ids, past_key_values, foo=False, **kwargs):
+                outs = super().forward(input_ids, past_key_values=past_key_values, **kwargs)
                 if foo:
                     outs["logits"][:, :, :] = 0.0
                 return outs

             def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
                 kwargs["encoder_outputs"] = encoder_outputs
                 inputs = super().prepare_inputs_for_generation(*args, **kwargs)
                 inputs["foo"] = foo
                 return inputs
@@ -3160,17 +3165,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertEqual(outputs_normal.shape, (1, 20))

         # Should be different with foo
-        outputs_foo = model.generate(
-            input_ids,
-            foo=True,
-        )
+        outputs_foo = model.generate(input_ids, foo=True)
         with self.assertRaises(AssertionError):
             self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())

         # Assistant model
-        assistant = AutoModelForSeq2SeqLM.from_pretrained(
-            "hf-internal-testing/tiny-random-BartForConditionalGeneration"
-        ).to(torch_device)
+        assistant = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+            torch_device
+        )

         # If assisted generation passes model_kwargs correctly, should be same as previous
         outputs_assisted = model.generate(
@@ -3192,25 +3194,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

     def test_assisted_decoding_encoder_decoder_shared_encoder(self):
         """
         Tests that the following scenario is compatible with assisted generation:
         1. encoder-decoder main model
         2. decoder-only assistant model
         3. both have a custom input
         (e.g. DistilWhisper)
         """
         # PT-only test: TF doesn't support assisted decoding yet.
         # Bart subclass with a kwarg called foo that distorts the output
-        class FakeBart(BartForConditionalGeneration):
+        class FakeBartSeq2Seq(BartForConditionalGeneration):
             def forward(self, input_ids, foo=False, **kwargs):
                 outs = super().forward(input_ids, **kwargs)
                 if foo:
                     outs["logits"][:, :, :] = 0.0
                 return outs

             def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
                 kwargs["encoder_outputs"] = encoder_outputs
                 inputs = super().prepare_inputs_for_generation(*args, **kwargs)
                 inputs["foo"] = foo
                 return inputs

+        class FakeBartCausalLM(BartForCausalLM):
+            def forward(self, input_ids, attention_mask, past_key_values, foo=False, **kwargs):
+                outs = super().forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs)
+                if foo:
+                    outs["logits"][:, :, :] = 0.0
+                return outs
+
+            def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
+                kwargs["encoder_outputs"] = encoder_outputs
+                inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+                inputs["foo"] = foo
+                return inputs
+
-        model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+        model = FakeBartSeq2Seq.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
             torch_device
         )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
@@ -3229,9 +3249,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
         self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())

         # Assistant model
-        assistant = BartForCausalLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
-            torch_device
-        )
+        assistant = FakeBartCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-BartForConditionalGeneration"
+        ).to(torch_device)

         # If assisted generation passes model_kwargs correctly, should be same as previous
         outputs_assisted = model.generate(
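The scenario this last test exercises (an encoder-decoder main model with a decoder-only assistant, as in DistilWhisper) looks roughly like this at the call site; a sketch using the tests' tiny checkpoint, with generation settings chosen for illustration:

```python
# Sketch of the shared-encoder scenario covered by the test above: the main
# model is encoder-decoder, the draft model is decoder-only.
from transformers import AutoTokenizer, BartForCausalLM, BartForConditionalGeneration

checkpoint = "hf-internal-testing/tiny-random-BartForConditionalGeneration"
model = BartForConditionalGeneration.from_pretrained(checkpoint)
assistant = BartForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
outputs = model.generate(input_ids, assistant_model=assistant, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```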