Unverified commit fe0b85e7 authored by Patrick von Platen, committed by GitHub

[EncoderDecoder] Add functionality to tie encoder decoder weights (#6538)



* start adding tie encoder to decoder functionality

* finish model tying

* make style

* Apply suggestions from code review

* fix t5 list including cross attention

* apply sams suggestions

* Update src/transformers/modeling_encoder_decoder.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add max depth break point
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent ab42d748
@@ -87,7 +87,7 @@ class EncoderDecoderConfig(PretrainedConfig):
    @classmethod
    def from_encoder_decoder_configs(
-        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig
+        cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
        r"""
        Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
@@ -99,7 +99,7 @@ class EncoderDecoderConfig(PretrainedConfig):
        decoder_config.is_decoder = True
        decoder_config.add_cross_attention = True

-        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
+        return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)

    def to_dict(self):
        """
...
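With ``**kwargs`` now forwarded into the combined config, extra options such as ``tie_encoder_decoder`` can be set when merging the two sub-configs. A minimal sketch of the new call (the concrete config classes are only illustrative):

```python
from transformers import BertConfig, EncoderDecoderConfig

# two identical configs so that encoder and decoder weights could later be tied
encoder_config = BertConfig()
decoder_config = BertConfig()

# un-prefixed kwargs such as tie_encoder_decoder end up on the EncoderDecoderConfig
config = EncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config, decoder_config, tie_encoder_decoder=True
)
assert config.tie_encoder_decoder
```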
@@ -58,6 +58,8 @@ class PretrainedConfig(object):
            Whether the model is used as decoder or not (in which case it's used as an encoder).
        add_cross_attention (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether cross-attention layers should be added to the model. Note, this option is only relevant for models that can be used as decoder models within the `:class:~transformers.EncoderDecoderModel` class, which consists of all models in ``AUTO_MODELS_FOR_CAUSAL_LM``.
+        tie_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder and decoder model to have the exact same parameter names.
        prune_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
            Pruned heads of the model. The keys are the selected layer indices and the associated values, the list
            of heads to prune in said layer.
@@ -153,6 +155,7 @@ class PretrainedConfig(object):
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)
        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
+        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
...
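Because ``tie_encoder_decoder`` is popped in ``PretrainedConfig.__init__``, it behaves like any other config attribute and is serialized with the rest of the configuration. A small sketch, with ``BertConfig`` chosen only as an example:

```python
from transformers import BertConfig

config = BertConfig(tie_encoder_decoder=True)
print(config.tie_encoder_decoder)                  # True
print("tie_encoder_decoder" in config.to_dict())   # saved and reloaded with the config
```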
@@ -71,9 +71,17 @@ class EncoderDecoderModel(PreTrainedModel):
            self.encoder.get_output_embeddings() is None
        ), "The encoder {} should not have a LM Head. Please use a model without LM Head"

+        # tie encoder, decoder weights if config set accordingly
+        self.tie_weights()
+
    def tie_weights(self):
-        # for now no weights tying in encoder-decoder
-        pass
+        # tie encoder & decoder if needed
+        if self.config.tie_encoder_decoder:
+            # tie encoder and decoder base model
+            decoder_base_model_prefix = self.decoder.base_model_prefix
+            self._tie_encoder_decoder_weights(
+                self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
+            )

    def get_encoder(self):
        return self.encoder
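Construction therefore triggers the tying whenever the combined config carries ``tie_encoder_decoder=True``. A hedged usage sketch for a shared "bert2bert"-style model, assuming encoder and decoder use the same architecture and parameter names (the checkpoints are only illustrative):

```python
from transformers import EncoderDecoderModel

# the un-prefixed kwarg lands on the EncoderDecoderConfig, so __init__ calls tie_weights()
shared = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased", tie_encoder_decoder=True
)
untied = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

# shared tensors are only counted once, so the tied model reports fewer parameters
print(sum(p.numel() for p in shared.parameters()))
print(sum(p.numel() for p in untied.parameters()))
```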
@@ -122,7 +130,11 @@ class EncoderDecoderModel(PreTrainedModel):
                All remaning positional arguments will be passed to the underlying model's ``__init__`` method
            kwargs: (`optional`) Remaining dictionary of keyword arguments.
-                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
+                Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``).
+                - To update the encoder configuration, use the prefix `encoder_` for each configuration parameter
+                - To update the decoder configuration, use the prefix `decoder_` for each configuration parameter
+                - To update the parent model configuration, do not use a prefix for each configuration parameter
+                Behave differently depending on whether a :obj:`config` is provided or automatically loaded.

        Examples::
@@ -144,6 +156,12 @@ class EncoderDecoderModel(PreTrainedModel):
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

+        # remove encoder, decoder kwargs from kwargs
+        for key in kwargs_encoder.keys():
+            del kwargs["encoder_" + key]
+        for key in kwargs_decoder.keys():
+            del kwargs["decoder_" + key]
+
        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
@@ -184,7 +202,9 @@ class EncoderDecoderModel(PreTrainedModel):
        decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

-        return cls(encoder=encoder, decoder=decoder)
+        # instantiate config with corresponding kwargs
+        config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+        return cls(encoder=encoder, decoder=decoder, config=config)

    def forward(
        self,
...
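A hedged sketch of the prefix convention described in the docstring above; the checkpoints and the particular overridden options are illustrative only, and assume the sub-model configs accept them:

```python
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",
    "bert-base-uncased",
    encoder_hidden_dropout_prob=0.2,           # "encoder_" prefix -> applied to the encoder
    decoder_attention_probs_dropout_prob=0.2,  # "decoder_" prefix -> applied to the decoder
    tie_encoder_decoder=True,                  # no prefix -> forwarded to the shared EncoderDecoderConfig
)

print(model.encoder.config.hidden_dropout_prob)           # 0.2
print(model.decoder.config.attention_probs_dropout_prob)  # 0.2
print(model.config.tie_encoder_decoder)                   # True
```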
@@ -887,10 +887,12 @@ class T5Model(T5PreTrainedModel):
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
        self.decoder = T5Stack(decoder_config, self.shared)

        self.init_weights()
@@ -1040,10 +1042,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
...
@@ -416,6 +416,77 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

+        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
+            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
+
+    @staticmethod
+    def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
+        uninitialized_encoder_weights: List[str] = []
+        assert decoder.__class__ == encoder.__class__, f"{decoder.__class__} and {encoder.__class__} have to be equal."
+
+        def tie_encoder_to_decoder_recursively(
+            decoder_pointer: nn.Module,
+            encoder_pointer: nn.Module,
+            module_name: str,
+            uninitialized_encoder_weights: List[str],
+            depth=0,
+        ):
+            assert isinstance(decoder_pointer, nn.Module) and isinstance(
+                encoder_pointer, nn.Module
+            ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
+            if hasattr(decoder_pointer, "weight"):
+                assert hasattr(encoder_pointer, "weight")
+                encoder_pointer.weight = decoder_pointer.weight
+                if hasattr(decoder_pointer, "bias"):
+                    assert hasattr(encoder_pointer, "bias")
+                    encoder_pointer.bias = decoder_pointer.bias
+                return
+
+            encoder_modules = encoder_pointer._modules
+            decoder_modules = decoder_pointer._modules
+            if len(decoder_modules) > 0:
+                assert (
+                    len(encoder_modules) > 0
+                ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+
+                all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
+                encoder_layer_pos = 0
+                for name, module in decoder_modules.items():
+                    if name.isdigit():
+                        encoder_name = str(int(name) + encoder_layer_pos)
+                        decoder_name = name
+                        if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])):
+                            # this can happen if the name corresponds to the position in a module list of layers;
+                            # in this case the decoder has added a cross-attention layer that the encoder does not have,
+                            # so skip this step and subtract one layer position from the encoder
+                            encoder_layer_pos -= 1
+                            continue
+                    elif name not in encoder_modules:
+                        continue
+                    elif depth > 500:
+                        raise ValueError(
+                            "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
+                        )
+                    else:
+                        decoder_name = encoder_name = name
+                    tie_encoder_to_decoder_recursively(
+                        decoder_modules[decoder_name],
+                        encoder_modules[encoder_name],
+                        module_name + "/" + name,
+                        uninitialized_encoder_weights,
+                        depth=depth + 1,
+                    )
+                    all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+                uninitialized_encoder_weights += list(all_encoder_weights)
+
+        # tie weights recursively
+        tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
+        if len(uninitialized_encoder_weights) > 0:
+            logger.warning(
+                f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
+            )
    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """ Tie or clone module weights depending of whether we are using TorchScript or not
        """
@@ -894,7 +965,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
-        model.tie_weights()  # make sure token embedding weights are still tied if needed
+        # make sure token embedding weights are still tied if needed
+        model.tie_weights()

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
...
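To check that ``_tie_encoder_decoder_weights`` makes the encoder re-use the decoder's tensors rather than copying them, one can compare storage pointers. A small sketch under the same bert2bert assumption as earlier (the checkpoint name is illustrative):

```python
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased", tie_encoder_decoder=True
)

# the encoder was tied to the decoder's base model (resolved via base_model_prefix)
decoder_base = getattr(model.decoder, model.decoder.base_model_prefix)
enc_params = dict(model.encoder.named_parameters())
dec_params = dict(decoder_base.named_parameters())

tied = [
    name
    for name, param in enc_params.items()
    if name in dec_params and param.data_ptr() == dec_params[name].data_ptr()
]
print(f"{len(tied)} of {len(enc_params)} encoder parameters share storage with the decoder")
```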
@@ -268,6 +268,88 @@ class EncoderDecoderMixin:
        )
        self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

+    def create_and_check_encoder_decoder_shared_weights(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        labels,
+        **kwargs
+    ):
+        torch.manual_seed(0)
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.to(torch_device)
+        model.eval()
+        # loading a state dict copies the weights but does not tie them
+        decoder_state_dict = model.decoder._modules[model.decoder.base_model_prefix].state_dict()
+        model.encoder.load_state_dict(decoder_state_dict, strict=False)
+
+        torch.manual_seed(0)
+        tied_encoder_model, tied_decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        config = EncoderDecoderConfig.from_encoder_decoder_configs(
+            tied_encoder_model.config, tied_decoder_model.config, tie_encoder_decoder=True
+        )
+        tied_model = EncoderDecoderModel(encoder=tied_encoder_model, decoder=tied_decoder_model, config=config)
+        tied_model.to(torch_device)
+        tied_model.eval()
+
+        model_result = model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+
+        tied_model_result = tied_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+
+        # check that the tied model has fewer parameters
+        self.assertLess(sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()))
+        random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+        # check that outputs are equal
+        self.assertTrue(
+            torch.allclose(
+                model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
+            )
+        )
+
+        # check that outputs after saving and loading are equal
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tied_model.save_pretrained(tmpdirname)
+            tied_model = EncoderDecoderModel.from_pretrained(tmpdirname)
+            tied_model.to(torch_device)
+            tied_model.eval()
+
+            # check that the reloaded tied model still has fewer parameters
+            self.assertLess(
+                sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
+            )
+            random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+            tied_model_result = tied_model(
+                input_ids=input_ids,
+                decoder_input_ids=decoder_input_ids,
+                attention_mask=attention_mask,
+                decoder_attention_mask=decoder_attention_mask,
+            )
+
+            # check that outputs are equal
+            self.assertTrue(
+                torch.allclose(
+                    model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
+                )
+            )
    def test_encoder_decoder_model(self):
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_encoder_decoder_model(**input_ids_dict)
@@ -296,6 +378,10 @@ class EncoderDecoderMixin:
        input_ids_dict = self.prepare_config_and_inputs()
        self.check_encoder_decoder_model_generate(**input_ids_dict)

+    def test_encoder_decoder_model_shared_weights(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
+
    @slow
    def test_real_model_save_load_from_pretrained(self):
        model_2 = self.get_pretrained_model()
@@ -480,3 +566,6 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    def get_pretrained_model(self):
        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
+
+    def test_encoder_decoder_model_shared_weights(self):
+        pass
@@ -14,6 +14,8 @@
# limitations under the License.

+import copy
+import tempfile
import unittest

from transformers import is_torch_available
@@ -130,7 +132,7 @@ class T5ModelTester:
        # all items after square
        self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())

-    def create_and_check_t5_model(
+    def create_and_check_model(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config)
@@ -156,7 +158,7 @@ class T5ModelTester:
        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
        self.parent.assertEqual(len(decoder_past[1][0]), 4)

-    def create_and_check_t5_with_lm_head(
+    def create_and_check_with_lm_head(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
@@ -170,7 +172,7 @@ class T5ModelTester:
        self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
        self.parent.assertEqual(outputs["loss"].size(), ())

-    def create_and_check_t5_decoder_model_past(
+    def create_and_check_decoder_model_past(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config).get_decoder().to(torch_device).eval()
@@ -201,7 +203,7 @@ class T5ModelTester:
        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

-    def create_and_check_t5_decoder_model_attention_mask_past(
+    def create_and_check_decoder_model_attention_mask_past(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config).get_decoder()
@@ -245,7 +247,7 @@ class T5ModelTester:
        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

-    def create_t5_and_check_t5_generate_with_past_key_value_states(
+    def create_and_check_generate_with_past_key_value_states(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
@@ -257,13 +259,83 @@ class T5ModelTester:
        output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
        self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))

-    def create_and_check_t5_model_fp16_forward(
+    def create_and_check_model_fp16_forward(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config).to(torch_device).half().eval()
        output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"]
        self.parent.assertFalse(torch.isnan(output).any().item())
+    def create_and_check_encoder_decoder_shared_weights(
+        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
+    ):
+        for model_class in [T5Model, T5ForConditionalGeneration]:
+            torch.manual_seed(0)
+            model = model_class(config=config).to(torch_device).eval()
+            # loading a state dict copies the weights but does not tie them
+            model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
+
+            torch.manual_seed(0)
+            tied_config = copy.deepcopy(config)
+            tied_config.tie_encoder_decoder = True
+            tied_model = model_class(config=tied_config).to(torch_device).eval()
+
+            model_result = model(
+                input_ids=input_ids,
+                decoder_input_ids=decoder_input_ids,
+                attention_mask=attention_mask,
+                decoder_attention_mask=decoder_attention_mask,
+            )
+
+            tied_model_result = tied_model(
+                input_ids=input_ids,
+                decoder_input_ids=decoder_input_ids,
+                attention_mask=attention_mask,
+                decoder_attention_mask=decoder_attention_mask,
+            )
+
+            # check that the tied model has fewer parameters
+            self.parent.assertLess(
+                sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
+            )
+            random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+            # check that outputs are equal
+            self.parent.assertTrue(
+                torch.allclose(
+                    model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
+                )
+            )
+
+            # check that outputs after saving and loading are equal
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                tied_model.save_pretrained(tmpdirname)
+                tied_model = model_class.from_pretrained(tmpdirname)
+                tied_model.to(torch_device)
+                tied_model.eval()
+
+                # check that the reloaded tied model still has fewer parameters
+                self.parent.assertLess(
+                    sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
+                )
+                random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
+
+                tied_model_result = tied_model(
+                    input_ids=input_ids,
+                    decoder_input_ids=decoder_input_ids,
+                    attention_mask=attention_mask,
+                    decoder_attention_mask=decoder_attention_mask,
+                )
+
+                # check that outputs are equal
+                self.parent.assertTrue(
+                    torch.allclose(
+                        model_result[0][0, :, random_slice_idx],
+                        tied_model_result[0][0, :, random_slice_idx],
+                        atol=1e-4,
+                    )
+                )
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,) = config_and_inputs
@@ -299,30 +371,34 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)

-    def test_t5_model(self):
+    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_t5_model(*config_and_inputs)
+        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_with_lm_head(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
+        self.model_tester.create_and_check_with_lm_head(*config_and_inputs)

-    def test_t5_decoder_model_past(self):
+    def test_decoder_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
+        self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)

-    def test_t5_decoder_model_past_with_attn_mask(self):
+    def test_decoder_model_past_with_attn_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
+        self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)

-    def test_t5_generate_with_past_key_value_states(self):
+    def test_generate_with_past_key_value_states(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
+        self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)

+    def test_encoder_decoder_shared_weights(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
+
    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
-    def test_t5_model_fp16_forward(self):
+    def test_model_fp16_forward(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
+        self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
@@ -331,8 +407,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        self.assertIsNotNone(model)

    def test_export_to_onnx(self):
-        import tempfile
-
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        model = T5Model(config_and_inputs[0]).to(torch_device)
        with tempfile.TemporaryDirectory() as tmpdirname:
...