Unverified Commit 34259366 authored by Patrick von Platen, committed by GitHub

[EncoderDecoderModel] add a `add_cross_attention` boolean to config (#6377)

* correct encoder decoder model

* Apply suggestions from code review

* apply Sylvain's suggestions
parent 06bc347c
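A minimal sketch of how the new flag is meant to be used, assuming two default BERT configs combined into an encoder-decoder model (the variable names are illustrative, not part of the diff):

>>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
>>> # one config per side; the decoder must act as a causal LM and carry cross-attention layers
>>> config_encoder = BertConfig()
>>> config_decoder = BertConfig()
>>> config_decoder.is_decoder = True
>>> config_decoder.add_cross_attention = True  # the new flag introduced by this commit
>>> # combine and instantiate a randomly initialized encoder-decoder model
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> model = EncoderDecoderModel(config=config)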
@@ -58,6 +58,7 @@ class EncoderDecoderConfig(PretrainedConfig):
         >>> config_decoder = model.config.decoder
         >>> # set decoder config to causal lm
         >>> config_decoder.is_decoder = True
+        >>> config_decoder.add_cross_attention = True
         >>> # Saving the model, including its configuration
         >>> model.save_pretrained('my-model')
@@ -94,8 +95,9 @@ class EncoderDecoderConfig(PretrainedConfig):
         Returns:
             :class:`EncoderDecoderConfig`: An instance of a configuration object
         """
-        logger.info("Set `config.is_decoder=True` for decoder_config")
+        logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
         decoder_config.is_decoder = True
+        decoder_config.add_cross_attention = True
         return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
......
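For reference, a short sketch of what the updated `from_encoder_decoder_configs` helper now guarantees, assuming two default `BertConfig` objects:

>>> from transformers import BertConfig, EncoderDecoderConfig
>>> config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), BertConfig())
>>> # both flags are now forced on the decoder half of the combined config
>>> config.decoder.is_decoder, config.decoder.add_cross_attention
(True, True)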
@@ -56,6 +56,8 @@ class PretrainedConfig(object):
             Whether the model is used as an encoder/decoder or not.
         is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether the model is used as decoder or not (in which case it's used as an encoder).
+        add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether cross-attention layers should be added to the model. Note that this option is only relevant for models that can be used as decoder models within the :class:`~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
         prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
             Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
             of heads to prune in said layer.
@@ -145,6 +147,7 @@ class PretrainedConfig(object):
         # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
         self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
         self.is_decoder = kwargs.pop("is_decoder", False)
+        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
         # Parameters for sequence generation
         self.max_length = kwargs.pop("max_length", 20)
......
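Because the attribute is popped from the generic keyword arguments, it can be passed directly to any config class and defaults to `False`; a quick sketch:

>>> from transformers import BertConfig
>>> BertConfig().add_cross_attention  # default
False
>>> BertConfig(is_decoder=True, add_cross_attention=True).add_cross_attention
True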
@@ -378,7 +378,9 @@ class BertLayer(nn.Module):
         super().__init__()
         self.attention = BertAttention(config)
         self.is_decoder = config.is_decoder
-        if self.is_decoder:
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
             self.crossattention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
@@ -399,6 +401,9 @@ class BertLayer(nn.Module):
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
         if self.is_decoder and encoder_hidden_states is not None:
+            assert hasattr(
+                self, "crossattention"
+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
             cross_attention_outputs = self.crossattention(
                 attention_output,
                 attention_mask,
......@@ -695,8 +700,10 @@ class BertModel(BertPreTrainedModel):
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
To behave as an decoder the model needs to be initialized with the
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
:obj:`is_decoder` argument of the configuration set to :obj:`True`.
To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
argument and :obj:`add_cross_attention` set to :obj:`True`; an
:obj:`encoder_hidden_states` is then expected as an input to the forward pass.
.. _`Attention is all you need`:
https://arxiv.org/abs/1706.03762
......
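Taken together, the `BertLayer` changes mean a decoder-style `BertModel` only accepts `encoder_hidden_states` once both flags are set. A rough sketch with a tiny random input (shapes and token ids are illustrative):

>>> import torch
>>> from transformers import BertConfig, BertModel
>>> config = BertConfig(is_decoder=True, add_cross_attention=True)
>>> model = BertModel(config).eval()  # randomly initialized decoder with cross-attention layers
>>> input_ids = torch.tensor([[101, 7592, 102]])
>>> encoder_hidden_states = torch.rand(1, 3, config.hidden_size)
>>> # with add_cross_attention=False this call would now trip the new assert in BertLayer
>>> outputs = model(input_ids, encoder_hidden_states=encoder_hidden_states)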
@@ -168,17 +168,18 @@ class EncoderDecoderModel(PreTrainedModel):
                 from .configuration_auto import AutoConfig
                 decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
-                if decoder_config.is_decoder is False:
+                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                     )
                     decoder_config.is_decoder = True
+                    decoder_config.add_cross_attention = True
                 kwargs_decoder["config"] = decoder_config
-            if kwargs_decoder["config"].is_decoder is False:
+            if kwargs_decoder["config"].is_decoder is False or decoder_config.add_cross_attention is False:
                 logger.warning(
-                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                 )
             decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
......
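The slow test at the bottom of this diff exercises the same path; a hedged sketch of what `from_encoder_decoder_pretrained` now does with two public `bert-base-uncased` checkpoints (downloads weights; the cross-attention layers themselves are freshly, i.e. randomly, initialized):

>>> from transformers import EncoderDecoderModel
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
>>> # the decoder config is patched automatically before the decoder is loaded
>>> model.config.decoder.is_decoder, model.config.decoder.add_cross_attention
(True, True)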
@@ -176,6 +176,7 @@ class BertModelTester:
         encoder_hidden_states,
         encoder_attention_mask,
     ):
+        config.add_cross_attention = True
         model = BertModel(config)
         model.to(torch_device)
         model.eval()
@@ -235,6 +236,7 @@ class BertModelTester:
         encoder_hidden_states,
         encoder_attention_mask,
     ):
+        config.add_cross_attention = True
         model = BertLMHeadModel(config=config)
         model.to(torch_device)
         model.eval()
......
@@ -59,6 +59,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
             encoder_hidden_states,
             encoder_attention_mask,
         ) = decoder_config_and_inputs
+        # make sure that cross attention layers are added
+        decoder_config.add_cross_attention = True
         return {
             "config": config,
             "input_ids": input_ids,
@@ -119,6 +122,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         decoder_model = BertLMHeadModel(decoder_config)
         enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
         self.assertTrue(enc_dec_model.config.decoder.is_decoder)
+        self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
         self.assertTrue(enc_dec_model.config.is_encoder_decoder)
         enc_dec_model.to(torch_device)
         outputs_encoder_decoder = enc_dec_model(
@@ -330,7 +334,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
         self.assertIsNotNone(model)
     @slow
-    def test_real_bert_model_from_pretrained_has_cross_attention(self):
+    def test_real_bert_model_from_pretrained_add_cross_attention(self):
         model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
         self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
......
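Conversely, an encoder-only `BertModel` built without the flag keeps its old structure; a small sanity check (sketch) that the flag really gates the extra submodule:

>>> from transformers import BertConfig, BertModel
>>> hasattr(BertModel(BertConfig()).encoder.layer[0], "crossattention")
False
>>> hasattr(BertModel(BertConfig(is_decoder=True, add_cross_attention=True)).encoder.layer[0], "crossattention")
True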