chenpangpang / transformers · Commits · e4a97f82

Unverified commit e4a97f82, authored Apr 24, 2023 by Joao Gante, committed by GitHub on Apr 24, 2023

Generate: assisted generation with sample (take 2) (#22949)

* temperature controls speed

Parent: 7701716e
Showing 4 changed files with 149 additions and 54 deletions (+149 / -54):

  docs/source/en/generation_strategies.mdx              +29   -6
  src/transformers/generation/configuration_utils.py     +3   -1
  src/transformers/generation/utils.py                   +71  -41
  tests/generation/test_utils.py                         +46   -6
docs/source/en/generation_strategies.mdx

@@ -333,15 +333,16 @@ This guide illustrates the main parameters that enable various decoding strategies
 [`generate`] method, which gives you even further control over the [`generate`] method's behavior.
 For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
 
-### Assisted Generation
+### Assisted Decoding
 
-Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
-tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is
-supported, and doesn't support batched inputs.
+Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
+tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
+the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
+and sampling are supported with assisted decoding, and batched inputs are not supported.
 
-<!-- TODO: add link to the blog post about assisted generation when it exists -->
+<!-- TODO: add link to the blog post about assisted decoding when it exists -->
 
-To enable assisted generation, set the `assistant_model` argument with a model.
+To enable assisted decoding, set the `assistant_model` argument with a model.
 
 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -359,3 +360,25 @@ To enable assisted generation, set the `assistant_model` argument with a model.
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
+
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improve latency.
+<!-- TODO: link the blog post again to explain why the tradeoff exists -->
+
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+>>> prompt = "Alice and Bob"
+>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
+>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
+
+>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+>>> inputs = tokenizer(prompt, return_tensors="pt")
+
+>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
+>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
+>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
+>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+["Alice and Bob are sitting on the sofa. Alice says, 'I'm going to my room"]
+```
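The candidate-and-validate loop described above can also be sketched end to end outside of `generate()`. The following is a simplified illustration of the idea only, not the implementation added in this commit: it reuses the pythia checkpoints from the example, hard-codes greedy selection on both models, and keeps a fixed number of candidate tokens per iteration.

```python
# Simplified sketch of assisted decoding -- illustration only, not `GenerationMixin.assisted_decoding`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "EleutherAI/pythia-1.4b-deduped"            # main model (from the docs example)
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"  # smaller assistant with the same tokenizer

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

input_ids = tokenizer("Alice and Bob", return_tensors="pt").input_ids
num_candidate_tokens = 5  # how many tokens the assistant proposes per iteration

with torch.no_grad():
    for _ in range(4):  # a few assisted iterations
        # 1. the assistant greedily proposes a few candidate tokens
        candidate_ids = assistant.generate(input_ids, do_sample=False, max_new_tokens=num_candidate_tokens)
        candidate_length = candidate_ids.shape[1] - input_ids.shape[1]

        # 2. the main model scores the whole candidate sequence in a single forward pass
        logits = model(candidate_ids).logits
        main_choice = logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1)

        # 3. keep the candidate tokens up to (but excluding) the first disagreement
        candidate_new = candidate_ids[:, -candidate_length:]
        n_matches = int(((~(candidate_new == main_choice)).cumsum(dim=-1) < 1).sum())

        # 4. accept the matching prefix plus one token chosen by the main model itself
        next_token = logits[:, input_ids.shape[1] + n_matches - 1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, candidate_new[:, :n_matches], next_token], dim=-1)

print(tokenizer.batch_decode(input_ids, skip_special_tokens=True))
```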
src/transformers/generation/configuration_utils.py

@@ -54,8 +54,10 @@ class GenerationConfig(PushToHubMixin):
             `num_beams>1` and `num_beam_groups>1`
         - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
             `constraints!=None` or `force_words_ids!=None`
+        - *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
+            `assistant_model` is passed to `.generate()`
 
-    You do not need to call any of the above methods directly. Pass custom parameter values to 'generate'. To learn
+    You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
     more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
 
     Arg:
src/transformers/generation/utils.py

@@ -492,7 +492,7 @@ class GenerationMixin:
     def prepare_inputs_for_generation(self, *args, **kwargs):
         raise NotImplementedError(
-            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
+            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
         )
 
     def _prepare_model_inputs(
@@ -962,10 +962,10 @@ class GenerationMixin:
             object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
             raise ValueError(
                 f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
-                f" `generate`, but it has already been created with the values {default}. {default} has been"
+                f" `.generate()`, but it has already been created with the values {default}. {default} has been"
                 " created by passing the corresponding arguments to generate or by the model's config default"
                 f" values. If you just want to change the default values of {object_type} consider passing"
-                f" them as arguments to `generate` instead of using a custom {object_type}."
+                f" them as arguments to `.generate()` instead of using a custom {object_type}."
             )
     default_list.extend(custom_list)
     return default_list
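The error path above fires when a user-supplied object duplicates one that `generate()` already builds internally from its own arguments. A hedged illustration follows (not part of this diff; behavior as of transformers releases around this commit, and "distilgpt2" is simply a small public checkpoint chosen for the example):

```python
# Hedged illustration of when the message above is raised.
from transformers import AutoModelForCausalLM, AutoTokenizer, MaxLengthCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
inputs = tokenizer("Hello", return_tensors="pt")

try:
    model.generate(
        **inputs,
        max_length=20,  # generate() builds its own MaxLengthCriteria from this
        stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=20)]),  # duplicate custom criterion
    )
except ValueError as err:
    print(err)  # expected to resemble: "A custom stopping criteria ... has been passed to `.generate()` ..."
```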
@@ -1418,14 +1418,14 @@ class GenerationMixin:
             and not is_constraint_gen_mode
             and not is_contrastive_search_gen_mode
         )
-        is_assisted_greedy_gen_mode = False
+        is_assisted_gen_mode = False
         if assistant_model is not None:
-            if not is_greedy_gen_mode:
+            if not (is_greedy_gen_mode or is_sample_gen_mode):
                 raise ValueError(
-                    "You've set `assistant_model`, which triggers assisted generation. Currently, assisted generation "
-                    "is only supported with Greedy Search."
+                    "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
+                    "is only supported with Greedy Search and Sample."
                 )
-            is_assisted_greedy_gen_mode = True
+            is_assisted_gen_mode = True
 
         if generation_config.num_beam_groups > generation_config.num_beams:
             raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
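At the call site, the check above means `assistant_model` can only be combined with greedy search or multinomial sampling. A hedged sketch of what that implies for users (the "distilgpt2" checkpoint is an arbitrary small model chosen for illustration, not taken from this commit):

```python
# Illustration of the mode check above (sketch only).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
assistant = model  # the test added in this commit also uses the model as its own assistant
inputs = tokenizer("Hello", return_tensors="pt")

# accepted: assisted greedy search and assisted sampling
model.generate(**inputs, assistant_model=assistant, do_sample=False, max_new_tokens=5)
model.generate(**inputs, assistant_model=assistant, do_sample=True, max_new_tokens=5)

# rejected: beam search is not a supported assisted mode
try:
    model.generate(**inputs, assistant_model=assistant, num_beams=2, max_new_tokens=5)
except ValueError as err:
    print(err)  # expected to say assisted generate only supports Greedy Search and Sample
```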
@@ -1464,16 +1464,16 @@ class GenerationMixin:
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
         # 10. go into different generation modes
-        if is_assisted_greedy_gen_mode:
+        if is_assisted_gen_mode:
             if generation_config.num_return_sequences > 1:
                 raise ValueError(
-                    "num_return_sequences has to be 1 when doing assisted greedy search, "
+                    "num_return_sequences has to be 1 when doing assisted generate, "
                     f"but is {generation_config.num_return_sequences}."
                 )
             if batch_size > 1:
-                raise ValueError("Assisted generation is only supported for batch_size = 1")
+                raise ValueError("assisted generate is only supported for batch_size = 1")
             if not model_kwargs["use_cache"]:
-                raise ValueError("Assisted generation requires `use_cache=True`")
+                raise ValueError("assisted generate requires `use_cache=True`")
 
             # 11. If the assistant model is an encoder-decoder, prepare its encoder outputs
             if assistant_model.config.is_encoder_decoder:
@@ -1486,11 +1486,13 @@ class GenerationMixin:
                 )
                 model_kwargs["assistant_encoder_outputs"] = assistant_model_kwargs["encoder_outputs"]
 
-            # 12. run assisted greedy search
-            return self.assisted_greedy_search(
+            # 12. run assisted generate
+            return self.assisted_decoding(
                 input_ids,
                 assistant_model=assistant_model,
+                do_sample=generation_config.do_sample,
                 logits_processor=logits_processor,
+                logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
                 stopping_criteria=stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
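The `logits_warper` forwarded above when `do_sample` is set is a `LogitsProcessorList` of warpers derived from the generation config. A hand-built equivalent, for illustration only (the specific warpers and values are assumptions, not necessarily what `_get_logits_warper` produces for a given config):

```python
# Sketch of the kind of object passed as `logits_warper` above.
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

logits_warper = LogitsProcessorList(
    [
        TemperatureLogitsWarper(0.5),  # lower temperature sharpens the distribution
        TopKLogitsWarper(top_k=50),
        TopPLogitsWarper(top_p=0.95),
    ]
)

# warping a dummy score distribution over a 10-token vocabulary
input_ids = torch.tensor([[1, 2, 3]])
scores = torch.randn(1, 10)
warped = logits_warper(input_ids, scores)
print(warped.shape)  # torch.Size([1, 10]); positions outside top-k/top-p are set to -inf
```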
@@ -4059,11 +4061,13 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
-    def assisted_greedy_search(
+    def assisted_decoding(
         self,
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
+        do_sample: bool = False,
         logits_processor: Optional[LogitsProcessorList] = None,
+        logits_warper: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
@@ -4076,12 +4080,13 @@ class GenerationMixin:
         **model_kwargs,
     ):
         r"""
-        Generates sequences of token ids for models with a language modeling head using **greedy decoding**, assisted
-        by a smaller model. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+        Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
+        **sample** (depending on `do_sample`), assisted by a smaller model. Can be used for text-decoder, text-to-text,
+        speech-to-text, and vision-to-text models.
 
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.assisted_greedy_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin.assisted_decoding`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -4095,9 +4100,15 @@ class GenerationMixin:
                 same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
                 is much faster than running generation with the model you're calling generate from. As such, the
                 assistant model should be much smaller.
+            do_sample (`bool`, *optional*, defaults to `False`):
+                Whether or not to use sampling; use greedy decoding otherwise.
             logits_processor (`LogitsProcessorList`, *optional*):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
+            logits_warper (`LogitsProcessorList`, *optional*):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+                to warp the prediction score distribution of the language modeling head applied before multinomial
+                sampling at each generation step.
             stopping_criteria (`StoppingCriteriaList`, *optional*):
                 An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                 used to tell if the generation loop should stop.
@@ -4157,7 +4168,7 @@ class GenerationMixin:
         ...     ]
         ... )
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
-        >>> outputs = model.assisted_greedy_search(
+        >>> outputs = model.assisted_decoding(
         ...     input_ids,
         ...     assistant_model=assistant_model,
         ...     logits_processor=logits_processor,
@@ -4166,13 +4177,14 @@ class GenerationMixin:
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
-        # NOTE: the code here is copy/paste from greedy search, except when clearly stated in the comments
+        # NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments
 
         # Assistant: initialize assistant-related variables
         if not hasattr(assistant_model, "max_assistant_tokens"):
             assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
 
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
@@ -4285,6 +4297,8 @@ class GenerationMixin:
             # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
             # `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
 
             # 2.1. Run a forward pass on the candidate sequence
             if "past_key_values" in model_kwargs:
                 og_model_attn = torch.ones_like(candidate_input_ids)
                 og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
@@ -4320,17 +4334,28 @@ class GenerationMixin:
                 output_hidden_states=output_hidden_states,
             )
 
-            # 3. Obtain the argmax from the original model logits.
+            # 2.2. Process the new logits
             new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
             if len(logits_processor) > 0:
                 for i in range(candidate_length):
                     new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
-            max_logits = new_logits.argmax(dim=-1)[:, -candidate_length - 1 : -1]
+            if len(logits_warper) > 0:
+                for i in range(candidate_length):
+                    new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
+
+            # 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial
+            # sampling, otherwise use argmax.
+            if do_sample:
+                probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
+                sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
+                next_tokens = sampled_tokens[:, :-1]
+            else:
+                next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1)
 
             # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
             # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
             candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
-            n_matches = ((~(candidate_new_tokens == max_logits)).cumsum(dim=-1) < 1).sum()
+            n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum()
 
             # 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
             # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
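The `n_matches` expression above is compact; the following toy example (made-up token ids, not from the commit) shows how the cumulative-sum trick counts the matching prefix, and how, under sampling, the pre-sampled token at index `n_matches` then becomes the next token in the hunk that follows:

```python
# Toy illustration of the matching-prefix count used above (values are made up).
import torch

candidate_new_tokens = torch.tensor([[11, 27, 42, 8, 3]])   # tokens proposed by the assistant
next_tokens          = torch.tensor([[11, 27, 42, 9, 55]])  # tokens the main model would pick/sample

# True where the models disagree; cumsum stays 0 until the first disagreement,
# so "< 1" marks exactly the matching prefix and .sum() counts its length.
mismatch = ~(candidate_new_tokens == next_tokens)
n_matches = (mismatch.cumsum(dim=-1) < 1).sum()
print(int(n_matches))  # 3 -> the first three candidate tokens are accepted

# With sampling, the token drawn at position `n_matches` (here index 3 -> 9) is then used as the
# next token, mirroring `next_tokens = sampled_tokens[:, n_matches]` in the next hunk.
print(int(next_tokens[0, n_matches]))  # 9
```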
@@ -4360,12 +4385,17 @@ class GenerationMixin:
             next_token_scores = new_logits[:, n_matches, :]
 
             # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
-            # because of this step, assisted greedy search reduces to a normal greedy search if there is no match.
-            next_tokens = torch.argmax(next_token_scores, dim=-1)
+            # because of this step, assisted generation search reduces to a normal greedy search/sample if there is no
+            # match.
+            if do_sample:
+                probs = probs[:, n_matches, :]
+                next_tokens = sampled_tokens[:, n_matches]
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
 
-            # Assistant: main logic end; Compared to greedy search, the following (redundant) blocks were removed
-            # below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model cache
-            # update.
+            # Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were
+            # removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model
+            # cache update.
 
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
@@ -4378,20 +4408,18 @@ class GenerationMixin:
             if "past_key_values" not in model_kwargs:
                 last_matching_idx = new_cur_len - 1
-                prompt_length = cur_len
             else:
                 last_matching_idx = n_matches
-                prompt_length = 0
 
             if output_attentions:
                 if self.config.is_encoder_decoder:
                     cross_attentions = _split_model_outputs(
-                        cross_attentions, outputs.cross_attentions, prompt_length, last_matching_idx
+                        cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx
                     )
                     decoder_attentions = _split_model_outputs(
                         decoder_attentions,
                         outputs.decoder_attentions,
-                        prompt_length,
+                        cur_len,
                         last_matching_idx,
                         is_decoder_attention=True,
                     )
@@ -4399,18 +4427,18 @@ class GenerationMixin:
                     decoder_attentions = _split_model_outputs(
                         decoder_attentions,
                         outputs.attentions,
-                        prompt_length,
+                        cur_len,
                         last_matching_idx,
                         is_decoder_attention=True,
                     )
             if output_hidden_states:
                 if self.config.is_encoder_decoder:
                     decoder_hidden_states = _split_model_outputs(
-                        decoder_hidden_states, outputs.decoder_hidden_states, prompt_length, last_matching_idx
+                        decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx
                     )
                 else:
                     decoder_hidden_states = _split_model_outputs(
-                        decoder_hidden_states, outputs.hidden_states, prompt_length, last_matching_idx
+                        decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx
                     )
 
             # finished sentences should have their next token be a padding token
@@ -4503,24 +4531,26 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
     return past_key_values
 
 
-def _split_model_outputs(outputs, new_outputs, prompt_length, last_matching_idx, is_decoder_attention=False):
+def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False):
     """
     Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
     where each member corresponds to a single generated token.
     """
     # Retrocompatibility: in our generation functions, the first iteration includes the attention/hidden states for the
     # prompt.
-    if prompt_length > 0:
+    if len(outputs) == 0:
         new_tuple = ()
         for layer in new_outputs:
-            last_dim_size = prompt_length if is_decoder_attention else layer.shape[-1]
-            new_tuple += (layer[..., :prompt_length, :last_dim_size],)
+            last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1]
+            new_tuple += (layer[..., :previous_cur_len, :last_dim_size],)
         outputs += (new_tuple,)
+        last_matching_idx -= previous_cur_len
+        previous_cur_len += 1
 
-    for i in range(prompt_length, last_matching_idx + 1):
+    for i in range(last_matching_idx + 1):
         new_tuple = ()
         for layer in new_outputs:
-            last_dim_size = i + 1 if is_decoder_attention else layer.shape[-1]
+            last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1]
             new_tuple += (layer[..., i : i + 1, :last_dim_size],)
         outputs += (new_tuple,)
     return outputs
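`_split_model_outputs` re-splits attention/hidden-state tensors that cover several candidate tokens into one entry per generated token, matching the per-token output format of regular decoding. A simplified standalone sketch of that idea (not the library function; it omits the first-iteration prompt entry and the causal trimming of decoder self-attention):

```python
# Simplified illustration of the per-token splitting idea behind `_split_model_outputs`.
import torch

def split_per_token(new_layer_outputs, num_new_tokens):
    """new_layer_outputs: tuple of per-layer tensors shaped (batch, heads, num_new_tokens, key_len)."""
    per_token_outputs = ()
    for i in range(num_new_tokens):
        # one tuple of per-layer slices for generated token i
        per_token_outputs += (tuple(layer[..., i : i + 1, :] for layer in new_layer_outputs),)
    return per_token_outputs

# toy check: 2 layers, batch 1, 4 heads, 3 new tokens attending over 8 keys
layers = tuple(torch.rand(1, 4, 3, 8) for _ in range(2))
split = split_per_token(layers, num_new_tokens=3)
print(len(split), split[0][0].shape)  # 3 tokens, each with per-layer slices of shape (1, 4, 1, 8)
```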
tests/generation/test_utils.py

@@ -1457,22 +1457,22 @@ class GenerationTesterMixin:
         for output in (output_contrastive, output_generate):
             self._check_outputs(output, input_ids, model.config, use_cache=True)
 
-    def test_assisted_greedy_search_matches_greedy_search(self):
+    def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
         # It breaks the pattern in the tests above, for multiple reasons:
-        # - assisted_greedy_search, contrarily to the other methods, can't be called on its own (e.g. needs to
+        # - assisted_decoding, contrarily to the other methods, can't be called on its own (e.g. needs to
         #   prepare the assistant encoder outputs in the main generate body);
-        # - assisted_greedy_search does not support `use_cache = False`
-        # - assisted_greedy_search does not support `batch_size > 1`
+        # - assisted_decoding does not support `use_cache = False`
+        # - assisted_decoding does not support `batch_size > 1`
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
             if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return
-            # may fix in the future: the following models fail to pass this test, and need model-specific fixes
+            # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
             if any(
                 model_name in model_class.__name__.lower()
-                for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text"]
+                for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
             ):
                 return
@@ -1517,6 +1517,46 @@ class GenerationTesterMixin:
             for output in (output_greedy, output_assisted):
                 self._check_outputs(output, input_ids, model.config, use_cache=True)
 
+    def test_assisted_decoding_sample(self):
+        # Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output
+        # token. As such, this test only checks that the output format is correct.
+        for model_class in self.all_generative_model_classes:
+            # won't fix: FSMT and Reformer have a different cache variable type (and format).
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+                return
+            # may fix in the future: the following models fail with assisted decoding, and need model-specific fixes
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["bigbirdpegasus", "gptbigcode", "led", "mega", "speech2text", "git", "prophetnet"]
+            ):
+                return
+
+            # enable cache
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+
+            # NOTE: assisted generation only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                return
+
+            config.use_cache = True
+            config.is_decoder = True
+            model = model_class(config).to(torch_device).eval()
+            output_assisted = model.generate(
+                input_ids,
+                attention_mask=attention_mask,
+                max_length=max_length,
+                num_beams=1,
+                do_sample=True,
+                assistant_model=model,  # triggers assisted decoding
+                output_scores=True,
+                output_hidden_states=True,
+                output_attentions=True,
+                return_dict_in_generate=True,
+            )
+
+            self._check_outputs(output_assisted, input_ids, model.config, use_cache=True)
 
     def test_generate_with_head_masking(self):
         """Test designed for encoder-decoder models to ensure the attention head masking is used."""
         attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]