"tests/vscode:/vscode.git/clone" did not exist on "7b87ecb04712eed50793e65a2b39376f4570fcf2"
Unverified commit 4cdb67ca authored by Yih-Dar, committed by GitHub

Use cross_attention_hidden_size in Encoder-Decoder models (#14378)



* add cross_attention_hidden_size to text-2-text encoder-decoder models (PT/Flax)

* for TFEncoderDecoderModel

* add equivalence test for TFEncoderDecoderModel

* fix

* fix failed equivalence tests

* remove unused import

* add detailed comment

* Fix check_equivalence_tf_to_pt by using encoder/decoder

* cleaning

* Use cross_attention_hidden_size in speech-to-text

* clean fast init logging msg in encoder decoder models

* increase tol from 1e-5 to 1e-3 for tf test

* style

* style

* make sure projection layer can run

* remove type conversion + add check

* fix conflict (config.output_hidden_size)

* Remove TF -> PT in check_pt_tf_equivalence for TFEncoderDecoderModel
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 381b05a3
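Before the diff itself, here is a minimal standalone sketch of the decision this change introduces across the PyTorch, Flax and TF encoder-decoder classes. The hidden sizes (768 for the encoder, 1024 for the decoder) are hypothetical; the real models read these values from `config.encoder` and `config.decoder`.

encoder_hidden_size = 768            # hypothetical value, e.g. a BERT-base-sized encoder
decoder_hidden_size = 1024           # hypothetical value for a larger decoder
cross_attention_hidden_size = None   # the decoder config attribute consulted by this change

if cross_attention_hidden_size is not None:
    # the decoder handles the size mismatch itself, so it must match the encoder
    assert cross_attention_hidden_size == encoder_hidden_size
    needs_enc_to_dec_proj = False
else:
    # fall back to a learned projection only when the hidden sizes differ
    needs_enc_to_dec_proj = encoder_hidden_size != decoder_hidden_size

print(needs_enc_to_dec_proj)  # True for these sizes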
......@@ -18,6 +18,7 @@ import warnings
from typing import Optional
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig
......@@ -25,6 +26,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
from ...modeling_outputs import Seq2SeqLMOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
from .configuration_encoder_decoder import EncoderDecoderConfig
......@@ -181,13 +184,23 @@ class EncoderDecoderModel(PreTrainedModel):
encoder: Optional[PreTrainedModel] = None,
decoder: Optional[PreTrainedModel] = None,
):
assert config is not None or (
encoder is not None and decoder is not None
), "Either a configuration or an Encoder and a decoder has to be provided"
if config is None and (encoder is None or decoder is None):
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
if config is None:
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
# initialize with config
super().__init__(config)
......@@ -218,9 +231,17 @@ class EncoderDecoderModel(PreTrainedModel):
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
assert (
self.encoder.get_output_embeddings() is None
), "The encoder {} should not have a LM Head. Please use a model without LM Head"
# encoder outputs might need to be projected to different dimension for decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
if self.encoder.get_output_embeddings() is not None:
raise ValueError(
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
# tie encoder, decoder weights if config set accordingly
self.tie_weights()
......@@ -251,8 +272,12 @@ class EncoderDecoderModel(PreTrainedModel):
@classmethod
def from_pretrained(cls, *args, **kwargs):
# At the moment fast initialization is not supported
# for composite models
# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for EncoderDecoderModel. "
"Falling back to slow initialization..."
)
kwargs["_fast_init"] = False
return super().from_pretrained(*args, **kwargs)
......@@ -343,19 +368,18 @@ class EncoderDecoderModel(PreTrainedModel):
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
assert (
encoder_pretrained_model_name_or_path is not None
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
from ..auto.modeling_auto import AutoModel
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
from ..auto.configuration_auto import AutoConfig
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
......@@ -366,18 +390,20 @@ class EncoderDecoderModel(PreTrainedModel):
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
assert (
decoder_pretrained_model_name_or_path is not None
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
from ..auto.modeling_auto import AutoModelForCausalLM
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
from ..auto.configuration_auto import AutoConfig
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
......@@ -386,7 +412,11 @@ class EncoderDecoderModel(PreTrainedModel):
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
......@@ -464,6 +494,13 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
......
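For reference, a standalone PyTorch sketch of what the optional `enc_to_dec_proj` does to the encoder output at forward time, with made-up batch, sequence and hidden dimensions:

import torch
from torch import nn

enc_to_dec_proj = nn.Linear(768, 1024)          # encoder hidden size -> decoder hidden size
encoder_hidden_states = torch.randn(2, 7, 768)  # (batch, sequence, encoder hidden size)
projected = enc_to_dec_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([2, 7, 1024]), now usable for the decoder's cross-attention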
......@@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
from ...modeling_flax_utils import FlaxPreTrainedModel
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
from .configuration_encoder_decoder import EncoderDecoderConfig
......@@ -227,9 +229,25 @@ class FlaxEncoderDecoderModule(nn.Module):
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
# encoder outputs might need to be projected to different dimension for decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Dense(
self.decoder.config.hidden_size,
kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
dtype=self.dtype,
)
else:
self.enc_to_dec_proj = None
def _get_encoder_module(self):
return self.encoder
def _get_projection_module(self):
return self.enc_to_dec_proj
def _get_decoder_module(self):
return self.decoder
......@@ -256,11 +274,17 @@ class FlaxEncoderDecoderModule(nn.Module):
deterministic=deterministic,
)
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if self.enc_to_dec_proj is not None:
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -305,6 +329,15 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
if input_shape is None:
input_shape = ((1, 1), (1, 1))
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
......@@ -537,12 +570,22 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
def _decoder_forward(
module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
):
projection_module = module._get_projection_module()
decoder_module = module._get_decoder_module()
# optionally project encoder_hidden_states
if projection_module is not None:
encoder_hidden_states = projection_module(encoder_hidden_states)
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
encoder_hidden_states,
**kwargs,
)
......@@ -772,19 +815,18 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
assert (
encoder_pretrained_model_name_or_path is not None
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
from ..auto.modeling_flax_auto import FlaxAutoModel
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
from ..auto.configuration_auto import AutoConfig
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
......@@ -797,18 +839,20 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
assert (
decoder_pretrained_model_name_or_path is not None
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
from ..auto.configuration_auto import AutoConfig
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
......@@ -817,7 +861,11 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
......
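The Flax module builds the same bridge with flax.linen; a small standalone sketch, again with hypothetical sizes and a fixed 0.02 initializer range standing in for `initializer_range`:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Dense layer with a normal kernel initializer, mirroring the conditional enc_to_dec_proj above
proj = nn.Dense(1024, kernel_init=jax.nn.initializers.normal(0.02))
encoder_hidden_states = jnp.ones((2, 7, 768))
variables = proj.init(jax.random.PRNGKey(0), encoder_hidden_states)
projected = proj.apply(variables, encoder_hidden_states)
print(projected.shape)  # (2, 7, 1024)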
......@@ -23,12 +23,13 @@ import tensorflow as tf
from ...configuration_utils import PretrainedConfig
from ...file_utils import (
DUMMY_INPUTS,
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
from ...modeling_tf_utils import TFPreTrainedModel, input_processing
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
......@@ -168,12 +169,22 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
decoder: Optional[TFPreTrainedModel] = None,
):
if config is None and (encoder is None or decoder is None):
raise ValueError("Either a configuration or an encoder and a decoder has to be provided")
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
if config is None:
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
# initialize with config
super().__init__(config)
......@@ -200,8 +211,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
# encoder outputs might need to be projected to different dimension for decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = tf.keras.layers.Dense(
units=self.decoder.config.hidden_size,
kernel_initializer=get_initializer(config.encoder.initializer_range),
name="enc_to_dec_proj",
)
if self.encoder.get_output_embeddings() is not None:
raise ValueError("The encoder {} should not have a LM Head. Please use a model without LM Head")
raise ValueError(
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
@property
def dummy_inputs(self):
......@@ -355,16 +379,16 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder model. Cross-attention and casual mask are disabled."
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
......@@ -387,15 +411,18 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
......@@ -404,7 +431,11 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
kwargs_decoder["name"] = "decoder"
......@@ -485,6 +516,14 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# Let the user be responsible for the expected format.
if encoder_outputs is not None:
if return_dict and not isinstance(encoder_outputs, ModelOutput):
raise ValueError(
"If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of "
f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`."
)
if encoder_outputs is None:
encoder_processing_inputs = {
......@@ -518,6 +557,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
decoder_processing_inputs = {
"func": self.decoder.call,
"config": self.decoder.config,
......@@ -562,14 +608,6 @@ class TFEncoderDecoderModel(TFPreTrainedModel):
output = tuple([x for x in output if x is not None])
return output
# If the user passed a tuple for encoder_outputs, we wrap it in a TFBaseModelOutput when return_dict=True
if not isinstance(encoder_outputs, TFBaseModelOutput):
encoder_outputs = TFBaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
return TFSeq2SeqLMOutput(
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
......
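The TF counterpart uses a Keras Dense layer; as far as I understand, `get_initializer` amounts to a truncated-normal initializer parameterized by the config's `initializer_range`, so a rough standalone equivalent (0.02 assumed) looks like:

import tensorflow as tf

# the name= matters for how the projection weights are keyed in the saved checkpoint
enc_to_dec_proj = tf.keras.layers.Dense(
    units=1024,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    name="enc_to_dec_proj",
)
encoder_hidden_states = tf.ones((2, 7, 768))
projected = enc_to_dec_proj(encoder_hidden_states)
print(projected.shape)  # (2, 7, 1024)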
......@@ -195,6 +195,15 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
# initialize with config
# make sure input & output embeddings is not tied
config.tie_word_embeddings = False
......@@ -225,7 +234,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
# get encoder output hidden size
self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
if self.encoder_output_dim != self.decoder.config.hidden_size:
if (
self.encoder_output_dim != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
# encoder outputs might need to be projected to different dimension for decoder
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
......@@ -248,11 +260,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
@classmethod
def from_pretrained(cls, *args, **kwargs):
# At the moment fast initialization is not supported
# for composite models
# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for SpeechEncoderDecoderModel. Falling back to slow intialization..."
"Fast initialization is currently not supported for SpeechEncoderDecoderModel. "
"Falling back to slow initialization..."
)
kwargs["_fast_init"] = False
return super().from_pretrained(*args, **kwargs)
......@@ -346,13 +358,13 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. In this case make sure that `encoder_pretrained_model_name_or_path` defined"
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
......@@ -368,7 +380,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
......@@ -376,8 +389,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
......@@ -389,7 +403,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
......@@ -472,8 +487,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
encoder_hidden_states = encoder_outputs[0]
# project encoder_hidden_states
if self.encoder_output_dim != self.decoder.config.hidden_size:
# optionally project encoder_hidden_states
if (
self.encoder_output_dim != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
# compute correct encoder attention mask
......
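The speech model keys its size check on `output_hidden_size` rather than `hidden_size`, falling back when the attribute is absent; a tiny sketch with a stand-in config object and made-up values:

from types import SimpleNamespace

encoder_config = SimpleNamespace(hidden_size=1024, output_hidden_size=512)  # hypothetical values
encoder_output_dim = getattr(encoder_config, "output_hidden_size", encoder_config.hidden_size)
print(encoder_output_dim)  # 512; getattr drops back to hidden_size when output_hidden_size is not defined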
......@@ -29,6 +29,8 @@ from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_fo
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
from ...modeling_flax_utils import FlaxPreTrainedModel
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
......@@ -301,8 +303,8 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
f"it has to be equal to the encoder's `hidden_size`."
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
......@@ -781,19 +783,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined"
"to be defined."
)
from ..auto.modeling_flax_auto import FlaxAutoModel
if "config" not in kwargs_encoder:
from ..auto.configuration_auto import AutoConfig
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model from a decoder "
"model. Cross-attention and casual mask are disabled."
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
......@@ -811,17 +809,15 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
if "config" not in kwargs_decoder:
from ..auto.configuration_auto import AutoConfig
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention "
f"layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if "
f"{decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
......@@ -830,11 +826,11 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order "
f"to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the "
"attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to "
"`.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to "
"`.from_encoder_decoder_pretrained(...)`"
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
......
......@@ -178,8 +178,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
f"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
f"it has to be equal to the encoder's `hidden_size`."
"If `cross_attention_hidden_size` is specified in the decoder's configuration, "
"it has to be equal to the encoder's `hidden_size`. "
f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
)
......@@ -241,7 +241,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for VisionEncoderDecoderModel. Falling back to slow intialization..."
"Fast initialization is currently not supported for VisionEncoderDecoderModel. "
"Falling back to slow initialization..."
)
kwargs["_fast_init"] = False
return super().from_pretrained(*args, **kwargs)
......@@ -334,14 +335,13 @@ class VisionEncoderDecoderModel(PreTrainedModel):
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
f"No `encoder_model` is passed to kwargs: {kwargs_encoder}. "
f"In this case make sure that `encoder_pretrained_model_name_or_path` defined"
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
......@@ -357,16 +357,18 @@ class VisionEncoderDecoderModel(PreTrainedModel):
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model."
"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
"cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
......@@ -375,11 +377,11 @@ class VisionEncoderDecoderModel(PreTrainedModel):
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder."
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config`"
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` "
f"to `.from_encoder_decoder_pretrained(...)`"
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
......
......@@ -19,8 +19,8 @@ import unittest
import numpy as np
from transformers import is_flax_available
from transformers.testing_utils import require_flax, slow
from transformers import is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device
from .test_modeling_flax_bert import FlaxBertModelTester
from .test_modeling_flax_common import ids_tensor
......@@ -35,6 +35,15 @@ if is_flax_available():
FlaxEncoderDecoderModel,
FlaxGPT2LMHeadModel,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
if is_torch_available():
import torch
from transformers import EncoderDecoderModel
@require_flax
......@@ -234,6 +243,71 @@ class FlaxEncoderDecoderMixin:
generated_sequences = generated_output.sequences
self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = FlaxEncoderDecoderModel.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = EncoderDecoderModel.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
pt_model_loaded.eval()
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = EncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def check_equivalence_flax_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = EncoderDecoderModel(encoder_decoder_config)
fx_model = FlaxEncoderDecoderModel(encoder_decoder_config)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
self.check_pt_flax_equivalence(pt_model, fx_model, inputs_dict)
def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
......@@ -258,6 +332,44 @@ class FlaxEncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@is_pt_flax_cross_test
def test_pt_flax_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = np.concatenate(
[np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1
)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
# check `enc_to_dec_proj` work as expected
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
......
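Both new equivalence suites bottom out in the same numpy check; a distilled version is below (1e-5 is the Flax tolerance, while the TF test relaxes it to 1e-3 per the commit log):

import numpy as np

def assert_almost_equals(a: np.ndarray, b: np.ndarray, tol: float) -> None:
    # the largest element-wise absolute difference must stay under the tolerance
    diff = np.abs(a - b).max()
    assert diff <= tol, f"max difference {diff} exceeds {tol}"

a = np.zeros((2, 7, 99), dtype=np.float32)
assert_almost_equals(a, a + 1e-6, tol=1e-5)  # passes
# assert_almost_equals(a, a + 1e-4, tol=1e-5) would raise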
......@@ -31,6 +31,8 @@ from .test_modeling_tf_roberta import TFRobertaModelTester
if is_tf_available():
import tensorflow as tf
from transformers import (
AutoConfig,
AutoTokenizer,
......@@ -309,6 +311,90 @@ class TFEncoderDecoderMixin:
)
self.assertEqual(tuple(generated_output.shape.as_list()), (input_ids.shape[0],) + (decoder_config.max_length,))
def check_pt_tf_equivalence(self, pt_model, tf_model, inputs_dict):
pt_model.to(torch_device)
pt_model.eval()
# prepare inputs
tf_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.numpy()) for k, v in tf_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
tf_outputs = tf_model(**inputs_dict).to_tuple()
self.assertEqual(len(tf_outputs), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tf_outputs, pt_outputs):
self.assert_almost_equals(tf_output.numpy(), pt_output.numpy(), 1e-3)
# PT -> TF
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
tf_model_loaded = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
)
# This is only for copying some specific attributes of this particular model.
tf_model_loaded.config = pt_model.config
tf_outputs_loaded = tf_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(tf_outputs_loaded), len(pt_outputs), "Output lengths differ between TF and PyTorch")
for tf_output_loaded, pt_output in zip(tf_outputs_loaded, pt_outputs):
self.assert_almost_equals(tf_output_loaded.numpy(), pt_output.numpy(), 1e-3)
def check_equivalence_pt_to_tf(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
pt_model = EncoderDecoderModel(encoder_decoder_config)
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
pt_model.encoder.save_pretrained(encoder_tmp_dirname)
pt_model.decoder.save_pretrained(decoder_tmp_dirname)
tf_model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_pt=True, decoder_from_pt=True
)
# This is only for copying some specific attributes of this particular model.
tf_model.config = pt_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
def check_equivalence_tf_to_pt(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
# Using `_tf_model`, the test will fail, because the weights of `_tf_model` get extended before saving
# the encoder/decoder models.
# There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
# https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
# (the change in `src/transformers/modeling_tf_utils.py`)
_tf_model = TFEncoderDecoderModel(encoder_decoder_config)
# Make sure model is built
_tf_model(**inputs_dict)
# Using `tf_model` to pass the test.
encoder = _tf_model.encoder.__class__(encoder_decoder_config.encoder)
decoder = _tf_model.decoder.__class__(encoder_decoder_config.decoder)
# Make sure models are built
encoder(encoder.dummy_inputs)
decoder(decoder.dummy_inputs)
tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
tf_model.encoder.save_pretrained(encoder_tmp_dirname)
tf_model.decoder.save_pretrained(decoder_tmp_dirname)
pt_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_tmp_dirname, decoder_tmp_dirname, encoder_from_tf=True, decoder_from_tf=True
)
# This is only for copying some specific attributes of this particular model.
pt_model.config = tf_model.config
self.check_pt_tf_equivalence(pt_model, tf_model, inputs_dict)
def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model(**input_ids_dict)
......@@ -341,6 +427,65 @@ class TFEncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and tf is {diff} (>= {tol}).")
@is_pt_tf_cross_test
def test_pt_tf_equivalence(self):
config_inputs_dict = self.prepare_config_and_inputs()
# Keep only common arguments
arg_names = [
"config",
"input_ids",
"attention_mask",
"decoder_config",
"decoder_input_ids",
"decoder_attention_mask",
"encoder_hidden_states",
]
config_inputs_dict = {k: v for k, v in config_inputs_dict.items() if k in arg_names}
config = config_inputs_dict.pop("config")
decoder_config = config_inputs_dict.pop("decoder_config")
inputs_dict = config_inputs_dict
# `encoder_hidden_states` is not used in model call/forward
del inputs_dict["encoder_hidden_states"]
# Avoid the case where a sequence has no place to attend (after combined with the causal attention mask)
batch_size = inputs_dict["decoder_attention_mask"].shape[0]
inputs_dict["decoder_attention_mask"] = tf.constant(
np.concatenate([np.ones(shape=(batch_size, 1)), inputs_dict["decoder_attention_mask"][:, 1:]], axis=1)
)
# TF models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
decoder_config.use_cache = False
self.assertTrue(decoder_config.cross_attention_hidden_size is None)
# check without `enc_to_dec_proj` projection
self.assertTrue(config.hidden_size == decoder_config.hidden_size)
self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# This is not working, because pt/tf equivalence test for encoder-decoder use `from_encoder_decoder_pretrained`,
# which randomly initialize `enc_to_dec_proj`.
# # check `enc_to_dec_proj` work as expected
# decoder_config.hidden_size = decoder_config.hidden_size * 2
# self.assertTrue(config.hidden_size != decoder_config.hidden_size)
# self.check_equivalence_pt_to_tf(config, decoder_config, inputs_dict)
# self.check_equivalence_tf_to_pt(config, decoder_config, inputs_dict)
# Let's just check `enc_to_dec_proj` can run for now
decoder_config.hidden_size = decoder_config.hidden_size * 2
self.assertTrue(config.hidden_size != decoder_config.hidden_size)
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
model = TFEncoderDecoderModel(encoder_decoder_config)
model(**inputs_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
......
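The explicit "make sure models are built" calls in the TF test above matter because Keras creates variables lazily; a quick standalone illustration of why saving a never-called layer would have no weights to write out:

import tensorflow as tf

layer = tf.keras.layers.Dense(4)
print(len(layer.weights))   # 0 -- no variables exist before the layer has seen an input shape
layer(tf.ones((1, 3)))      # the first call builds the layer
print(len(layer.weights))   # 2 -- kernel and bias now exist and can be saved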