Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9e5c28c5
Unverified
Commit
9e5c28c5
authored
Dec 14, 2023
by
Joao Gante
Committed by
GitHub
Dec 14, 2023
Browse files
Generate: assisted decoding now uses `generate` for the assistant (#28030)
generate refactor
parent
dde6c427
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
71 deletions
+67
-71
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+27
-51
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+1
-1
tests/generation/test_utils.py
tests/generation/test_utils.py
+39
-19
No files found.
src/transformers/generation/candidate_generator.py
View file @
9e5c28c5
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
import
copy
import
copy
import
warnings
import
warnings
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class
CandidateGenerator
:
class
CandidateGenerator
:
"""Abstract base class for all candidate generators that can be applied during assisted generation."""
"""Abstract base class for all candidate generators that can be applied during assisted generation."""
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
torch
.
LongTensor
:
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
Tuple
[
torch
.
LongTensor
,
Optional
[
torch
.
FloatTensor
]]
:
"""
"""
Fetches the candidates to be tried for the current input.
Fetches the candidates to be tried for the current input.
...
@@ -37,8 +37,9 @@ class CandidateGenerator:
...
@@ -37,8 +37,9 @@ class CandidateGenerator:
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
the model.
assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
"""
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"
{
self
.
__class__
}
is an abstract class. Only classes inheriting this class can call `get_candidates`."
f
"
{
self
.
__class__
}
is an abstract class. Only classes inheriting this class can call `get_candidates`."
...
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -152,7 +153,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
)
)
self
.
logits_processor
=
logits_processor
self
.
logits_processor
=
logits_processor
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
torch
.
LongTensor
:
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
Tuple
[
torch
.
LongTensor
,
Optional
[
torch
.
FloatTensor
]]
:
"""
"""
Fetches the candidates to be tried for the current input.
Fetches the candidates to be tried for the current input.
...
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -161,7 +162,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
Return:
Return:
`torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
`torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
vocabulary_size)` containing the logits associated to each candidate.
"""
"""
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
# (which implicitly contains the number of accepted candidates from the previous round)
# (which implicitly contains the number of accepted candidates from the previous round)
...
@@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -179,51 +182,24 @@ class AssistedCandidateGenerator(CandidateGenerator):
)
)
self
.
assistant_kwargs
=
_prepare_token_type_ids
(
self
.
assistant_kwargs
,
new_cur_len
)
self
.
assistant_kwargs
=
_prepare_token_type_ids
(
self
.
assistant_kwargs
,
new_cur_len
)
# 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()`
# 2. Forecast next N tokens using the assistant model.
# call if we decide to add `past_key_values` as a possible output of generate, as we need access to the
assistant_generation_kwargs
=
{
# assistant cache to secure strong speedups.
self
.
input_ids_key
:
input_ids
,
candidate_input_ids
=
input_ids
"do_sample"
:
False
,
for
_
in
range
(
int
(
self
.
num_assistant_tokens
)):
"num_beams"
:
1
,
# 2.1 prepare assistant model inputs
"max_new_tokens"
:
int
(
self
.
num_assistant_tokens
),
assistant_inputs
=
self
.
assistant_model
.
prepare_inputs_for_generation
(
"return_dict_in_generate"
:
True
,
candidate_input_ids
,
"output_scores"
:
True
,
**
self
.
assistant_kwargs
,
}
)
assistant_output
=
self
.
assistant_model
.
generate
(
**
assistant_generation_kwargs
,
**
self
.
assistant_kwargs
)
# 2.2. check if the input ids length is correct
# 3. Update variables for the next round of candidate generation
has_past_key_values
=
assistant_inputs
.
get
(
"past_key_values"
,
None
)
is
not
None
self
.
assistant_kwargs
[
"past_key_values"
]
=
assistant_output
.
past_key_values
if
has_past_key_values
and
assistant_inputs
[
self
.
input_ids_key
].
shape
[
-
1
]
not
in
(
1
,
2
):
raise
ValueError
(
"The length of the input ids in assistant inputs should be 1 or 2"
)
# 4. Prepare variables for output
candidate_logits
=
torch
.
stack
(
assistant_output
.
scores
,
dim
=
1
)
# 2.3. use the assistant model to obtain the next candidate logits
candidate_ids
=
assistant_output
.
sequences
assistant_model_outputs
=
self
.
assistant_model
(
**
assistant_inputs
)
return
candidate_ids
,
candidate_logits
# 2.4. greedily select the next candidate token
if
len
(
self
.
logits_processor
)
>
0
:
assistant_model_outputs
.
logits
[:,
-
1
,
:]
=
self
.
logits_processor
(
candidate_input_ids
,
assistant_model_outputs
.
logits
[:,
-
1
,
:]
)
new_token
=
assistant_model_outputs
.
logits
[:,
-
1
,
:].
argmax
(
dim
=-
1
)
candidate_input_ids
=
torch
.
cat
((
candidate_input_ids
,
new_token
[:,
None
]),
dim
=-
1
)
# 2.5. update assistant model inputs
if
self
.
assistant_kwargs
.
get
(
self
.
attention_key
,
None
)
is
not
None
:
mask
=
self
.
assistant_kwargs
[
self
.
attention_key
]
self
.
assistant_kwargs
[
self
.
attention_key
]
=
torch
.
cat
(
[
mask
,
mask
.
new_ones
((
mask
.
shape
[
0
],
1
))],
dim
=-
1
)
self
.
assistant_kwargs
[
"past_key_values"
]
=
assistant_model_outputs
.
past_key_values
# 2.6. stop assistant generation on EOS
if
self
.
eos_token_id_tensor
is
not
None
:
last_assistant_token_is_eos
=
new_token
.
tile
(
self
.
eos_token_id_tensor
.
shape
[
0
],
1
)
last_assistant_token_is_eos
=
(
~
last_assistant_token_is_eos
.
ne
(
self
.
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
).
bool
()
)
if
last_assistant_token_is_eos
:
break
return
candidate_input_ids
def
update_candidate_strategy
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
num_matches
:
int
):
def
update_candidate_strategy
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
num_matches
:
int
):
"""
"""
...
...
src/transformers/generation/utils.py
View file @
9e5c28c5
...
@@ -4585,7 +4585,7 @@ class GenerationMixin:
...
@@ -4585,7 +4585,7 @@ class GenerationMixin:
cur_len
=
input_ids
.
shape
[
-
1
]
cur_len
=
input_ids
.
shape
[
-
1
]
# 1. Fetch candidate sequences from a `CandidateGenerator`
# 1. Fetch candidate sequences from a `CandidateGenerator`
candidate_input_ids
=
candidate_generator
.
get_candidates
(
input_ids
)
candidate_input_ids
,
candidate_logits
=
candidate_generator
.
get_candidates
(
input_ids
)
candidate_length
=
candidate_input_ids
.
shape
[
1
]
-
input_ids
.
shape
[
1
]
candidate_length
=
candidate_input_ids
.
shape
[
1
]
-
input_ids
.
shape
[
1
]
last_assistant_token_is_eos
=
(
last_assistant_token_is_eos
=
(
~
candidate_input_ids
[:,
-
1
]
~
candidate_input_ids
[:,
-
1
]
...
...
tests/generation/test_utils.py
View file @
9e5c28c5
...
@@ -3128,21 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -3128,21 +3128,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs_assisted
.
tolist
(),
outputs_tti
.
tolist
())
self
.
assertListEqual
(
outputs_assisted
.
tolist
(),
outputs_tti
.
tolist
())
def
test_model_kwarg_assisted_decoding_encoder_decoder
(
self
):
def
test_model_kwarg_assisted_decoding_encoder_decoder
(
self
):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. encoder-decoder assistant model
3. both have a custom input
(e.g. Whisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg that distorts the output
# Bart subclass with a kwarg that distorts the output
class
FakeBart
(
BartForConditionalGeneration
):
class
FakeBart
(
BartForConditionalGeneration
):
def
forward
(
self
,
input_ids
,
foo
=
False
,
**
kwargs
):
def
forward
(
self
,
input_ids
,
past_key_values
,
foo
=
False
,
**
kwargs
):
outs
=
super
().
forward
(
input_ids
,
**
kwargs
)
outs
=
super
().
forward
(
input_ids
,
past_key_values
=
past_key_values
,
**
kwargs
)
if
foo
:
if
foo
:
outs
[
"logits"
][:,
:,
:]
=
0.0
outs
[
"logits"
][:,
:,
:]
=
0.0
return
outs
return
outs
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
[
"foo"
]
=
foo
inputs
[
"foo"
]
=
foo
return
inputs
return
inputs
...
@@ -3160,17 +3165,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -3160,17 +3165,14 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertEqual
(
outputs_normal
.
shape
,
(
1
,
20
))
self
.
assertEqual
(
outputs_normal
.
shape
,
(
1
,
20
))
# Should be different with foo
# Should be different with foo
outputs_foo
=
model
.
generate
(
outputs_foo
=
model
.
generate
(
input_ids
,
foo
=
True
)
input_ids
,
foo
=
True
,
)
with
self
.
assertRaises
(
AssertionError
):
with
self
.
assertRaises
(
AssertionError
):
self
.
assertListEqual
(
outputs_foo
.
tolist
(),
outputs_normal
.
tolist
())
self
.
assertListEqual
(
outputs_foo
.
tolist
(),
outputs_normal
.
tolist
())
# Assistant model
# Assistant model
assistant
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
assistant
=
FakeBart
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
torch_device
)
.
to
(
torch_device
)
)
# If assisted generation passes model_kwargs correctly, should be same as previous
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted
=
model
.
generate
(
outputs_assisted
=
model
.
generate
(
...
@@ -3192,25 +3194,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -3192,25 +3194,43 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs_assisted
.
tolist
(),
outputs_foo
.
tolist
())
self
.
assertListEqual
(
outputs_assisted
.
tolist
(),
outputs_foo
.
tolist
())
def
test_assisted_decoding_encoder_decoder_shared_encoder
(
self
):
def
test_assisted_decoding_encoder_decoder_shared_encoder
(
self
):
"""
Tests that the following scenario is compatible with assisted generation:
1. encoder-decoder main model
2. decoder-only assistant model
3. both have a custom input
(e.g. DistilWhisper)
"""
# PT-only test: TF doesn't support assisted decoding yet.
# PT-only test: TF doesn't support assisted decoding yet.
# Bart subclass with a kwarg called foo that distorts the output
# Bart subclass with a kwarg called foo that distorts the output
class
FakeBart
(
BartForConditionalGeneration
):
class
FakeBart
Seq2Seq
(
BartForConditionalGeneration
):
def
forward
(
self
,
input_ids
,
foo
=
False
,
**
kwargs
):
def
forward
(
self
,
input_ids
,
foo
=
False
,
**
kwargs
):
outs
=
super
().
forward
(
input_ids
,
**
kwargs
)
outs
=
super
().
forward
(
input_ids
,
**
kwargs
)
if
foo
:
if
foo
:
outs
[
"logits"
][:,
:,
:]
=
0.0
outs
[
"logits"
][:,
:,
:]
=
0.0
return
outs
return
outs
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
[
"foo"
]
=
foo
return
inputs
class
FakeBartCausalLM
(
BartForCausalLM
):
def
forward
(
self
,
input_ids
,
attention_mask
,
past_key_values
,
foo
=
False
,
**
kwargs
):
outs
=
super
().
forward
(
input_ids
,
attention_mask
,
past_key_values
=
past_key_values
,
**
kwargs
)
if
foo
:
outs
[
"logits"
][:,
:,
:]
=
0.0
return
outs
def
prepare_inputs_for_generation
(
self
,
*
args
,
foo
=
False
,
encoder_outputs
=
None
,
**
kwargs
):
kwargs
[
"encoder_outputs"
]
=
encoder_outputs
inputs
=
super
().
prepare_inputs_for_generation
(
*
args
,
**
kwargs
)
inputs
[
"foo"
]
=
foo
inputs
[
"foo"
]
=
foo
return
inputs
return
inputs
model
=
FakeBart
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
model
=
FakeBart
Seq2Seq
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
torch_device
torch_device
)
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
)
...
@@ -3229,9 +3249,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
...
@@ -3229,9 +3249,9 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self
.
assertListEqual
(
outputs_foo
.
tolist
(),
outputs_normal
.
tolist
())
self
.
assertListEqual
(
outputs_foo
.
tolist
(),
outputs_normal
.
tolist
())
# Assistant model
# Assistant model
assistant
=
Bart
For
CausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
).
to
(
assistant
=
Fake
BartCausalLM
.
from_pretrained
(
torch_device
"hf-internal-testing/tiny-random-BartForConditionalGeneration"
)
)
.
to
(
torch_device
)
# If assisted generation passes model_kwargs correctly, should be same as previous
# If assisted generation passes model_kwargs correctly, should be same as previous
outputs_assisted
=
model
.
generate
(
outputs_assisted
=
model
.
generate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment