Unverified commit 34259366, authored by Patrick von Platen, committed by GitHub

[EncoderDecoderModel] add a `add_cross_attention` boolean to config (#6377)

* correct encoder decoder model

* Apply suggestions from code review

* apply Sylvain's suggestions
parent 06bc347c
@@ -58,6 +58,7 @@ class EncoderDecoderConfig(PretrainedConfig):
         >>> config_decoder = model.config.decoder
         >>> # set decoder config to causal lm
         >>> config_decoder.is_decoder = True
+        >>> config_decoder.add_cross_attention = True
         >>> # Saving the model, including its configuration
         >>> model.save_pretrained('my-model')
@@ -94,8 +95,9 @@ class EncoderDecoderConfig(PretrainedConfig):
         Returns:
             :class:`EncoderDecoderConfig`: An instance of a configuration object
         """
-        logger.info("Set `config.is_decoder=True` for decoder_config")
+        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
         decoder_config.is_decoder = True
+        decoder_config.add_cross_attention = True
        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
...
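
For reference, a minimal sketch (not part of the diff) of how the updated `from_encoder_decoder_configs` helper behaves after this change; `BertConfig` is just an arbitrary choice of encoder/decoder config:

    from transformers import BertConfig, EncoderDecoderConfig

    encoder_config = BertConfig()
    decoder_config = BertConfig()
    # the helper now flips both flags on the decoder config before bundling it
    config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

    assert config.decoder.is_decoder is True
    assert config.decoder.add_cross_attention is True
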
@@ -56,6 +56,8 @@ class PretrainedConfig(object):
             Whether the model is used as an encoder/decoder or not.
         is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether the model is used as decoder or not (in which case it's used as an encoder).
+        add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether cross-attention layers should be added to the model. Note that this option is only relevant for models that can be used as decoder models within the :class:`~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
         prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
             Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
             of heads to prune in said layer.
@@ -145,6 +147,7 @@ class PretrainedConfig(object):
         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
         self.is_decoder = kwargs.pop("is_decoder", False)
+        self.add_cross_attention = kwargs.pop("add_cross_attention", False)

         # Parameters for sequence generation
         self.max_length = kwargs.pop("max_length", 20)
...
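
Since `add_cross_attention` is popped from the kwargs in `PretrainedConfig.__init__`, any model-specific config that forwards its kwargs should accept it directly. A small sketch under that assumption:

    from transformers import BertConfig

    # all other values are left at their defaults
    config = BertConfig(is_decoder=True, add_cross_attention=True)
    assert config.is_decoder is True
    assert config.add_cross_attention is True
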
@@ -378,7 +378,9 @@ class BertLayer(nn.Module):
         super().__init__()
         self.attention = BertAttention(config)
         self.is_decoder = config.is_decoder
-        if self.is_decoder:
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
             self.crossattention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
@@ -399,6 +401,9 @@ class BertLayer(nn.Module):
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

         if self.is_decoder and encoder_hidden_states is not None:
+            assert hasattr(
+                self, "crossattention"
+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
             cross_attention_outputs = self.crossattention(
                 attention_output,
                 attention_mask,
@@ -695,8 +700,10 @@ class BertModel(BertPreTrainedModel):
     Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

     To behave as a decoder the model needs to be initialized with the
-    :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
-    :obj:`encoder_hidden_states` is expected as an input to the forward pass.
+    :obj:`is_decoder` argument of the configuration set to :obj:`True`.
+    To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an
+    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.

     .. _`Attention is all you need`:
         https://arxiv.org/abs/1706.03762
...
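
As a usage illustration of the documented decoder behaviour (a hedged sketch, not part of the diff; the tiny model sizes and token ids are arbitrary): a randomly initialized BertModel configured with both flags accepts `encoder_hidden_states` in its forward pass, whereas without `add_cross_attention=True` the new assertion above would fire.

    import torch
    from transformers import BertConfig, BertModel

    config = BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=64,
        is_decoder=True,
        add_cross_attention=True,
    )
    model = BertModel(config)
    model.eval()

    input_ids = torch.tensor([[1, 5, 7, 2]])
    # pretend these came from an encoder with the same hidden size
    encoder_hidden_states = torch.zeros(1, 6, config.hidden_size)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states)
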
@@ -168,17 +168,18 @@ class EncoderDecoderModel(PreTrainedModel):
                 from .configuration_auto import AutoConfig

                 decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
-                if decoder_config.is_decoder is False:
+                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                     )
                     decoder_config.is_decoder = True
+                    decoder_config.add_cross_attention = True

                 kwargs_decoder["config"] = decoder_config

-            if kwargs_decoder["config"].is_decoder is False:
+            if kwargs_decoder["config"].is_decoder is False or decoder_config.add_cross_attention is False:
                 logger.warning(
-                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                 )

            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
...
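
The behaviour the slow test further below exercises can also be sketched directly (this downloads the public `bert-base-uncased` checkpoints):

    from transformers import EncoderDecoderModel

    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    # both flags are now forced on the decoder config, so cross-attention layers exist
    assert model.config.decoder.is_decoder
    assert model.config.decoder.add_cross_attention
    assert hasattr(model.decoder.bert.encoder.layer[0], "crossattention")
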
@@ -176,6 +176,7 @@ class BertModelTester:
         encoder_hidden_states,
         encoder_attention_mask,
     ):
+        config.add_cross_attention = True
         model = BertModel(config)
         model.to(torch_device)
         model.eval()
@@ -235,6 +236,7 @@ class BertModelTester:
         encoder_hidden_states,
         encoder_attention_mask,
     ):
+        config.add_cross_attention = True
         model = BertLMHeadModel(config=config)
         model.to(torch_device)
         model.eval()
...
@@ -59,6 +59,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
             encoder_hidden_states,
             encoder_attention_mask,
         ) = decoder_config_and_inputs
+
+        # make sure that cross attention layers are added
+        decoder_config.add_cross_attention = True
         return {
             "config": config,
             "input_ids": input_ids,
@@ -119,6 +122,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         decoder_model = BertLMHeadModel(decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         self.assertTrue(enc_dec_model.config.decoder.is_decoder)
+        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
         self.assertTrue(enc_dec_model.config.is_encoder_decoder)
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
@@ -330,7 +334,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         self.assertIsNotNone(model)

     @slow
-    def test_real_bert_model_from_pretrained_has_cross_attention(self):
+    def test_real_bert_model_from_pretrained_add_cross_attention(self):
         model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
         self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
...