Unverified Commit fe0b85e7 authored by Patrick von Platen, committed by GitHub

[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)



* start adding tie encoder to decoder functionality

* finish model tying

* make style

* Apply suggestions from code review

* fix t5 list including cross attention

* apply Sam's suggestions

* Update src/transformers/modeling_encoder_decoder.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent ab42d748
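For orientation, a minimal usage sketch of what this commit enables, assuming the change is applied. The checkpoint name and the decoder's `bert` attribute are illustrative assumptions, not part of the diff; any encoder/decoder pair whose parameters share the same names should behave the same way.

```python
from transformers import EncoderDecoderModel

# Passing the new, unprefixed tie_encoder_decoder flag stores it on the parent
# EncoderDecoderConfig and ties the encoder to the decoder's base model.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",  # encoder checkpoint (illustrative)
    "bert-base-uncased",  # decoder checkpoint (illustrative)
    tie_encoder_decoder=True,
)

# Encoder and decoder base model now share tensors, e.g. the word embeddings
# of both point to one and the same parameter.
enc_emb = model.encoder.embeddings.word_embeddings.weight
dec_emb = model.decoder.bert.embeddings.word_embeddings.weight  # "bert" is BERT's base_model_prefix
assert enc_emb.data_ptr() == dec_emb.data_ptr()
```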
......@@ -87,7 +87,7 @@ class EncoderDecoderConfig(PretrainedConfig):
@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
......@@ -99,7 +99,7 @@ class EncoderDecoderConfig(PretrainedConfig):
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
def to_dict(self):
"""
......
......@@ -58,6 +58,8 @@ class PretrainedConfig(object):
Whether the model is used as decoder or not (in which case it's used as an encoder).
add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether cross-attention layers should be added to the model. Note that this option is only relevant for models that can be used as decoder models within the :class:`~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder and decoder model to have exactly the same parameter names.
prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
of heads to prune in said layer.
......@@ -153,6 +155,7 @@ class PretrainedConfig(object):
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
self.is_decoder = kwargs.pop("is_decoder", False)
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
# Parameters for sequence generation
self.max_length = kwargs.pop("max_length", 20)
......
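A small illustration of the new configuration attribute added above; the concrete config class is arbitrary, since any :class:`PretrainedConfig` subclass forwards unknown kwargs to the base constructor, where ``tie_encoder_decoder`` is popped.

```python
from transformers import BertConfig

# Arbitrary example config: the flag defaults to False, is stored like any
# other PretrainedConfig attribute, and survives serialization via to_dict().
config = BertConfig(tie_encoder_decoder=True)
assert config.tie_encoder_decoder is True
assert config.to_dict()["tie_encoder_decoder"] is True
assert BertConfig().tie_encoder_decoder is False
```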
......@@ -71,9 +71,17 @@ class EncoderDecoderModel(PreTrainedModel):
self.encoder.get_output_embeddings() is None
), "The encoder {} should not have a LM Head. Please use a model without LM Head"
# tie encoder, decoder weights if config set accordingly
self.tie_weights()
def tie_weights(self):
# for now no weights tying in encoder-decoder
pass
# tie encoder & decoder if needed
if self.config.tie_encoder_decoder:
# tie encoder and decoder base model
decoder_base_model_prefix = self.decoder.base_model_prefix
self._tie_encoder_decoder_weights(
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
)
def get_encoder(self):
return self.encoder
......@@ -122,7 +130,11 @@ class EncoderDecoderModel(PreTrainedModel):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
kwargs: (`optional`) Remaining dictionary of keyword arguments.
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attentions=True``).
- To update the encoder configuration, use the prefix `encoder_` for each configuration parameter.
- To update the decoder configuration, use the prefix `decoder_` for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
Examples::
......@@ -144,6 +156,12 @@ class EncoderDecoderModel(PreTrainedModel):
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove encoder, decoder kwargs from kwargs
for key in kwargs_encoder.keys():
del kwargs["encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
......@@ -184,7 +202,9 @@ class EncoderDecoderModel(PreTrainedModel):
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
return cls(encoder=encoder, decoder=decoder)
# instantiate config with corresponding kwargs
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
return cls(encoder=encoder, decoder=decoder, config=config)
def forward(
self,
......
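To make the kwargs handling above concrete, a hedged sketch of how prefixed and unprefixed kwargs are routed after this change; checkpoint names and flag values are placeholders.

```python
from transformers import EncoderDecoderModel

# "encoder_"-prefixed kwargs have the prefix stripped and go to the encoder's
# from_pretrained call ("decoder_" works analogously), while unprefixed kwargs
# now reach the parent config via from_encoder_decoder_configs(..., **kwargs).
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",
    "bert-base-uncased",
    encoder_output_attentions=True,  # -> model.encoder.config.output_attentions
    tie_encoder_decoder=True,        # -> model.config.tie_encoder_decoder
)

assert model.encoder.config.output_attentions is True
assert model.config.tie_encoder_decoder is True
# The prefixed kwargs are deleted before the parent config is built, so they
# do not end up as stray attributes on EncoderDecoderConfig.
assert "encoder_output_attentions" not in model.config.to_dict()
```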
......@@ -887,10 +887,12 @@ class T5Model(T5PreTrainedModel):
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
self.decoder = T5Stack(decoder_config, self.shared)
self.init_weights()
......@@ -1040,10 +1042,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_config = copy.deepcopy(config)
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
self.decoder = T5Stack(decoder_config, self.shared)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
......
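As a quick sanity check of the T5 change above (hyperparameters are arbitrary): each T5Stack now receives a deep-copied config with ``is_encoder_decoder=False``, so only the top-level model performs encoder/decoder tying in ``tie_weights()``, never the stacks themselves.

```python
from transformers import T5Config, T5Model

# Small, arbitrary hyperparameters chosen only to keep the model tiny.
config = T5Config(d_model=64, d_ff=128, num_layers=2, num_heads=4)
model = T5Model(config)

# The parent config keeps is_encoder_decoder=True, while each stack sees a
# copy with the flag switched off (and use_cache disabled for the encoder).
assert model.config.is_encoder_decoder is True
assert model.encoder.config.is_encoder_decoder is False
assert model.decoder.config.is_encoder_decoder is False
assert model.encoder.config.use_cache is False
```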
......@@ -416,6 +416,77 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
if output_embeddings is not None:
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
@staticmethod
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
uninitialized_encoder_weights: List[str] = []
assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal."
def tie_encoder_to_decoder_recursively(
decoder_pointer: nn.Module,
encoder_pointer: nn.Module,
module_name: str,
uninitialized_encoder_weights: List[str],
depth=0,
):
assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module
), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
if hasattr(decoder_pointer, "weight"):
assert hasattr(encoder_pointer, "weight")
encoder_pointer.weight = decoder_pointer.weight
if hasattr(decoder_pointer, "bias"):
assert hasattr(encoder_pointer, "bias")
encoder_pointer.bias = decoder_pointer.bias
return
encoder_modules = encoder_pointer._modules
decoder_modules = decoder_pointer._modules
if len(decoder_modules) > 0:
assert (
len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
encoder_layer_pos = 0
for name, module in decoder_modules.items():
if name.isdigit():
encoder_name = str(int(name) + encoder_layer_pos)
decoder_name = name
if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])):
# this can happen if the name corresponds to a position in a module list of layers
# in this case the decoder has added a cross-attention layer that the encoder does not have
# thus skip this step and subtract one layer position from the encoder
encoder_layer_pos -= 1
continue
elif name not in encoder_modules:
continue
elif depth > 500:
raise ValueError(
"Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
)
else:
decoder_name = encoder_name = name
tie_encoder_to_decoder_recursively(
decoder_modules[decoder_name],
encoder_modules[encoder_name],
module_name + "/" + name,
uninitialized_encoder_weights,
depth=depth + 1,
)
all_encoder_weights.remove(module_name + "/" + encoder_name)
uninitialized_encoder_weights += list(all_encoder_weights)
# tie weights recursively
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
if len(uninitialized_encoder_weights) > 0:
logger.warning(
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
)
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
""" Tie or clone module weights depending of whether we are using TorchScript or not
"""
......@@ -894,7 +965,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
model.__class__.__name__, "\n\t".join(error_msgs)
)
)
model.tie_weights() # make sure token embedding weights are still tied if needed
# make sure token embedding weights are still tied if needed
model.tie_weights()
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
......
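To illustrate what the recursive tying above does in practice, a minimal T5 sketch with arbitrary small hyperparameters; the module paths reflect T5's block/layer layout. Decoder sub-modules are walked by name, digit-named layers are matched positionally, and the extra cross-attention layer in every decoder block is skipped rather than tied.

```python
from transformers import T5Config, T5Model

config = T5Config(d_model=64, d_ff=128, num_layers=2, num_heads=4,
                  tie_encoder_decoder=True)  # the new flag
model = T5Model(config)

# Self-attention weights are shared between the stacks: encoder layer 0 and
# decoder layer 0 reference the same tensor after tying.
enc_q = model.encoder.block[0].layer[0].SelfAttention.q.weight
dec_q = model.decoder.block[0].layer[0].SelfAttention.q.weight
assert enc_q.data_ptr() == dec_q.data_ptr()

# The decoder's cross-attention (layer index 1 in each block) has no encoder
# counterpart, so it keeps its own, untied parameters.
cross_q = model.decoder.block[0].layer[1].EncDecAttention.q.weight
assert cross_q.data_ptr() != enc_q.data_ptr()
```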
......@@ -268,6 +268,88 @@ class EncoderDecoderMixin:
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def create_and_check_encoder_decoder_shared_weights(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels,
**kwargs
):
torch.manual_seed(0)
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.to(torch_device)
model.eval()
# loading the state dict copies the weights but does not tie them
decoder_state_dict = model.decoder._modules[model.decoder.base_model_prefix].state_dict()
model.encoder.load_state_dict(decoder_state_dict, strict=False)
torch.manual_seed(0)
tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(config, decoder_config)
config = EncoderDecoderConfig.from_encoder_decoder_configs(
tied_encoder_model.config, tied_decoder_model.config, tie_encoder_decoder=True
)
tied_model = EncoderDecoderModel(encoder=tied_encoder_model, decoder=tied_decoder_model, config=config)
tied_model.to(torch_device)
tied_model.eval()
model_result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that the tied model has fewer parameters
self.assertLess(sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()))
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
# check that outputs are equal
self.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
# check that outputs after saving and loading are equal
with tempfile.TemporaryDirectory() as tmpdirname:
tied_model.save_pretrained(tmpdirname)
tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
tied_model.to(torch_device)
tied_model.eval()
# check that the tied model has fewer parameters
self.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that outputs are equal
self.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model(**input_ids_dict)
......@@ -296,6 +378,10 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
def test_encoder_decoder_model_shared_weights(self):
input_ids_dict = self.prepare_config_and_inputs()
self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
......@@ -480,3 +566,6 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
def test_encoder_decoder_model_shared_weights(self):
pass
......@@ -14,6 +14,8 @@
# limitations under the License.
import copy
import tempfile
import unittest
from transformers import is_torch_available
......@@ -130,7 +132,7 @@ class T5ModelTester:
# all items after square
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
def create_and_check_t5_model(
def create_and_check_model(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config)
......@@ -156,7 +158,7 @@ class T5ModelTester:
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
def create_and_check_t5_with_lm_head(
def create_and_check_with_lm_head(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
......@@ -170,7 +172,7 @@ class T5ModelTester:
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
self.parent.assertEqual(outputs["loss"].size(), ())
def create_and_check_t5_decoder_model_past(
def create_and_check_decoder_model_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).get_decoder().to(torch_device).eval()
......@@ -201,7 +203,7 @@ class T5ModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_t5_decoder_model_attention_mask_past(
def create_and_check_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).get_decoder()
......@@ -245,7 +247,7 @@ class T5ModelTester:
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_t5_and_check_t5_generate_with_past_key_value_states(
def create_and_check_generate_with_past_key_value_states(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
......@@ -257,13 +259,83 @@ class T5ModelTester:
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def create_and_check_t5_model_fp16_forward(
def create_and_check_model_fp16_forward(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
model = T5Model(config=config).to(torch_device).half().eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_encoder_decoder_shared_weights(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
for model_class in [T5Model, T5ForConditionalGeneration]:
torch.manual_seed(0)
model = model_class(config=config).to(torch_device).eval()
# loading the state dict copies the weights but does not tie them
model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
torch.manual_seed(0)
tied_config = copy.deepcopy(config)
tied_config.tie_encoder_decoder = True
tied_model = model_class(config=tied_config).to(torch_device).eval()
model_result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that the tied model has fewer parameters
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
# check that outputs after saving and loading are equal
with tempfile.TemporaryDirectory() as tmpdirname:
tied_model.save_pretrained(tmpdirname)
tied_model = model_class.from_pretrained(tmpdirname)
tied_model.to(torch_device)
tied_model.eval()
# check that the tied model has fewer parameters
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx],
tied_model_result[0][0, :, random_slice_idx],
atol=1e-4,
)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,) = config_and_inputs
......@@ -299,30 +371,34 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
def test_t5_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
def test_t5_decoder_model_past(self):
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_t5_decoder_model_past_with_attn_mask(self):
def test_decoder_model_past_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_t5_generate_with_past_key_value_states(self):
def test_generate_with_past_key_value_states(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
def test_encoder_decoder_shared_weights(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Can't do half precision")
def test_t5_model_fp16_forward(self):
def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
......@@ -331,8 +407,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(model)
def test_export_to_onnx(self):
import tempfile
config_and_inputs = self.model_tester.prepare_config_and_inputs()
model = T5Model(config_and_inputs[0]).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
......