Unverified Commit 1fc4b2a1 authored by Joao Gante, committed by GitHub

TF: use the correct config with `(...)EncoderDecoder` models (#18097)

parent 49354097
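
In effect, the boolean options (`output_attentions`, `output_hidden_states`, `use_cache`, `return_dict`) of `TFEncoderDecoderModel` and `TFVisionEncoderDecoderModel` are now resolved by the inner encoder and decoder against their own configurations, not by the outer `EncoderDecoderConfig`. A minimal sketch of the resulting behavior, assuming a BERT2BERT composite built from `bert-base-uncased` TF weights (checkpoint and token ids are illustrative, not taken from this diff):

    import tensorflow as tf
    from transformers import TFEncoderDecoderModel

    model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased"
    )
    input_ids = tf.constant([[101, 7592, 2088, 102]])  # illustrative token ids

    model.config.output_attentions = True          # outer config: not consulted
    model.encoder.config.output_attentions = True  # inner configs: these win
    model.decoder.config.output_attentions = True

    outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
    assert outputs.encoder_attentions is not None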
@@ -403,8 +403,13 @@ def unpack_inputs(func):
         # move any arg into kwargs, if they exist
         fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))

-        # process the inputs and call the wrapped function
-        unpacked_inputs = input_processing(func, self.config, **fn_args_and_kwargs)
+        # Encoder Decoder models delegate the application of the configuration options to their inner models.
+        if "encoder_decoder" in str(self).lower():
+            config = None
+        else:
+            config = self.config
+
+        unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
         return func(self, **unpacked_inputs)

     # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
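
The dispatch relies on the instance's string representation: the default repr embeds the defining module path, which contains `encoder_decoder` for both `TFEncoderDecoderModel` and `TFVisionEncoderDecoderModel` (the class names alone, lowercased, would not match, since they lack the underscore). A quick check, runnable with only `transformers` and TF installed:

    from transformers import TFEncoderDecoderModel, TFVisionEncoderDecoderModel

    # str() of the class (and of its instances, via the default object repr)
    # includes the module path, e.g.
    # "<class 'transformers.models.encoder_decoder.modeling_tf_encoder_decoder.TFEncoderDecoderModel'>"
    assert "encoder_decoder" in str(TFEncoderDecoderModel).lower()
    assert "encoder_decoder" in str(TFVisionEncoderDecoderModel).lower()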
@@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs):
     if "kwargs" in output:
         del output["kwargs"]

-    boolean_dict = {
-        k: v
-        for k, v in output.items()
-        if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
-    }
-
-    output.update(
-        booleans_processing(
-            config=config,
-            **boolean_dict,
-        )
-    )
+    if config is not None:
+        boolean_dict = {
+            k: v
+            for k, v in output.items()
+            if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
+        }
+
+        output.update(
+            booleans_processing(
+                config=config,
+                **boolean_dict,
+            )
+        )

     return output
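
To make the new guard concrete: when `config` is `None`, the boolean keys are left exactly as the caller passed them (possibly `None`), and each inner model later resolves them against its own configuration. Below is a simplified stand-in for the boolean resolution that the guard skips; the real `booleans_processing` has additional graph-mode handling, so treat this only as a sketch of its call-time-wins, config-as-default behavior:

    def booleans_processing_sketch(config, **kwargs):
        # Call-time values win; the config supplies defaults for the rest.
        return {
            k: v if v is not None else getattr(config, k, None)
            for k, v in kwargs.items()
        }

    class DummyConfig:
        output_attentions = True
        use_cache = True

    resolved = booleans_processing_sketch(DummyConfig(), output_attentions=None, use_cache=False)
    assert resolved == {"output_attentions": True, "use_cache": False}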
...
@@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
             loss = self.hf_compute_loss(labels, logits)

-        past_key_values = None
-        if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1]
-        # The starting index of the remaining elements in `decoder_outputs`
-        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
-
-        if not decoder_inputs["return_dict"]:
+        if not return_dict:
+            past_key_values = None
+            if use_cache:
+                past_key_values = decoder_outputs[1]
+            # The starting index of the remaining elements in `decoder_outputs`
+            start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+
             if not isinstance(encoder_outputs, tuple):
                 encoder_outputs = encoder_outputs.to_tuple()
             output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs

@@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=decoder_outputs.logits,
-            past_key_values=past_key_values,
+            past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
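
Two things changed here: the tuple-building logic now runs only when `return_dict` is falsy, and the `TFSeq2SeqLMOutput` branch reads `past_key_values` straight from `decoder_outputs`, since the local variable now exists only in the tuple branch. The `start_index` arithmetic is easiest to see with toy values; this sketch assumes that the tuple form of `decoder_outputs` leads with the already-extracted elements and that `None` entries are filtered from the final output, as elsewhere in the file:

    # Toy stand-ins for the real tensors and cache objects.
    loss, logits, past_key_values = None, "logits", "past"

    # decoder_outputs as a tuple: (logits, past, hidden_states, attentions)
    decoder_outputs = ("logits", "past", "hidden", "attn")
    encoder_outputs = ("enc_hidden",)

    # Count the leading elements already pulled out of decoder_outputs.
    start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
    assert start_index == 2

    output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
    output = tuple(x for x in output if x is not None)  # assumed None-filtering step
    assert output == ("logits", "past", "hidden", "attn", "enc_hidden")

The same restructuring is applied to `TFVisionEncoderDecoderModel` below.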
...
@@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
             loss = self.hf_compute_loss(labels, logits)

-        past_key_values = None
-        if decoder_inputs["use_cache"]:
-            past_key_values = decoder_outputs[1]
-        # The starting index of the remaining elements in `decoder_outputs`
-        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
-
-        if not decoder_inputs["return_dict"]:
+        if not return_dict:
+            past_key_values = None
+            if use_cache:
+                past_key_values = decoder_outputs[1]
+            # The starting index of the remaining elements in `decoder_outputs`
+            start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
+
             if not isinstance(encoder_outputs, tuple):
                 encoder_outputs = encoder_outputs.to_tuple()
             output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs

@@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=decoder_outputs.logits,
-            past_key_values=past_key_values,
+            past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
             cross_attentions=decoder_outputs.cross_attentions,
...
@@ -351,6 +351,40 @@ class EncoderDecoderMixin:
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )

+    def _check_output_with_attentions(
+        self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+    ):
+        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
+        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
+        self.assertEqual(
+            encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
+        )
+
+        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
+        num_decoder_layers = (
+            decoder_config.num_decoder_layers
+            if hasattr(decoder_config, "num_decoder_layers")
+            else decoder_config.num_hidden_layers
+        )
+        self.assertEqual(len(decoder_attentions), num_decoder_layers)
+        self.assertEqual(
+            decoder_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
+        )
+
+        cross_attentions = outputs_encoder_decoder["cross_attentions"]
+        self.assertEqual(len(cross_attentions), num_decoder_layers)
+
+        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
+            1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
+        )
+        self.assertEqual(
+            cross_attentions[0].shape[-3:],
+            (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
+        )
+
     def check_encoder_decoder_model_output_attentions(
         self,
         config,

@@ -376,36 +410,58 @@ class EncoderDecoderMixin:
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,
         )
-
-        encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
-        self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
-        self.assertEqual(
-            encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
-        )
-
-        decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
-        num_decoder_layers = (
-            decoder_config.num_decoder_layers
-            if hasattr(decoder_config, "num_decoder_layers")
-            else decoder_config.num_hidden_layers
-        )
-        self.assertEqual(len(decoder_attentions), num_decoder_layers)
-        self.assertEqual(
-            decoder_attentions[0].shape[-3:],
-            (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
-        )
-
-        cross_attentions = outputs_encoder_decoder["cross_attentions"]
-        self.assertEqual(len(cross_attentions), num_decoder_layers)
-
-        cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
-            1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
-        )
-        self.assertEqual(
-            cross_attentions[0].shape[-3:],
-            (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
-        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
+    def check_encoder_decoder_model_output_attentions_from_config(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels,
+        **kwargs
+    ):
+        # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+        # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
+        # from the inner models' configurations.
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.config.output_attentions = True  # model config -> won't work
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+        self.assertTrue(
+            all(
+                key not in outputs_encoder_decoder
+                for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+            )
+        )
+
+        config.output_attentions = True  # inner model config -> will work
+        decoder_config.output_attentions = True
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.to(torch_device)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )

     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):

@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)

+    def test_encoder_decoder_model_output_attentions_from_config(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
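
The assertions in `_check_output_with_attentions` follow the standard attention layout `(batch_size, num_heads, query_len, key_len)`; slicing with `[-3:]` drops the batch dimension, and for ngram decoders such as ProphetNet the cross-attention query length is scaled by `1 + ngram`. A worked example with made-up sizes:

    import torch

    batch_size, num_heads, src_len, tgt_len = 2, 4, 13, 7

    # Shapes as produced with output_attentions=True (illustrative tensors).
    encoder_attention = torch.zeros(batch_size, num_heads, src_len, src_len)
    decoder_attention = torch.zeros(batch_size, num_heads, tgt_len, tgt_len)
    cross_attention = torch.zeros(batch_size, num_heads, tgt_len, src_len)

    assert encoder_attention.shape[-3:] == (num_heads, src_len, src_len)
    assert decoder_attention.shape[-3:] == (num_heads, tgt_len, tgt_len)
    assert cross_attention.shape[-3:] == (num_heads, tgt_len, src_len)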
...
@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
             outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
         )

-    def check_encoder_decoder_model_output_attentions(
-        self,
-        config,
-        input_ids,
-        attention_mask,
-        encoder_hidden_states,
-        decoder_config,
-        decoder_input_ids,
-        decoder_attention_mask,
-        **kwargs
+    def _check_output_with_attentions(
+        self, outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
     ):
-        # make the decoder inputs a different shape from the encoder inputs to harden the test
-        decoder_input_ids = decoder_input_ids[:, :-1]
-        decoder_attention_mask = decoder_attention_mask[:, :-1]
-        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
-        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
-        outputs_encoder_decoder = enc_dec_model(
-            input_ids=input_ids,
-            decoder_input_ids=decoder_input_ids,
-            attention_mask=attention_mask,
-            decoder_attention_mask=decoder_attention_mask,
-            output_attentions=True,
-            kwargs=kwargs,
-        )
-
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
         self.assertEqual(len(encoder_attentions), config.num_hidden_layers)

@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
             (decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
         )

+    def check_encoder_decoder_model_output_attentions(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        # make the decoder inputs a different shape from the encoder inputs to harden the test
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            output_attentions=True,
+            kwargs=kwargs,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
+    def check_encoder_decoder_model_output_attentions_from_config(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        # Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
+        # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
+        # from the inner models' configurations.
+        decoder_input_ids = decoder_input_ids[:, :-1]
+        decoder_attention_mask = decoder_attention_mask[:, :-1]
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        enc_dec_model.config.output_attentions = True  # model config -> won't work
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
+        )
+        self.assertTrue(
+            all(
+                key not in outputs_encoder_decoder
+                for key in ["encoder_attentions", "decoder_attentions", "cross_attentions"]
+            )
+        )
+
+        config.output_attentions = True  # inner model config -> will work
+        decoder_config.output_attentions = True
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
+        )
+        self._check_output_with_attentions(
+            outputs_encoder_decoder, config, input_ids, decoder_config, decoder_input_ids
+        )
+
     def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
         encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
         enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)

@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)

+    def test_encoder_decoder_model_output_attentions_from_config(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_output_attentions_from_config(**input_ids_dict)
+
     def test_encoder_decoder_model_generate(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)
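
Like the tests above, the easiest way to reproduce the from-config behavior locally is to build the composite model from inner models with the option set on their configs. A self-contained sketch with tiny, randomly initialized BERT configs (all sizes illustrative):

    import tensorflow as tf
    from transformers import BertConfig, TFBertModel, TFBertLMHeadModel, TFEncoderDecoderModel

    enc_config = BertConfig(
        hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
    )
    dec_config = BertConfig(
        hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
        is_decoder=True, add_cross_attention=True,
    )

    # After this change, the inner configs are what trigger attention outputs.
    enc_config.output_attentions = True
    dec_config.output_attentions = True

    model = TFEncoderDecoderModel(
        encoder=TFBertModel(enc_config), decoder=TFBertLMHeadModel(dec_config)
    )

    input_ids = tf.constant([[1, 2, 3, 4]])
    outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
    assert outputs.encoder_attentions is not None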
...