"vscode:/vscode.git/clone" did not exist on "25b0463d0ba3fcbcf7fff8aa4027a2d8e959364b"
Commit 6c2ee16c authored by LysandreJik's avatar LysandreJik
Browse files

Test suite testing the tie_weights function as well as the resize_token_embeddings function.

Patched an issue relating to the tied weights I had introduced with the TorchScript addition.
Byte order mark management in TSV glue reading.
parent bd404735
......@@ -78,7 +78,7 @@ class DataProcessor(object):
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding="utf-8") as f:
with open(input_file, "r", encoding="utf-8-sig") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
......
......@@ -762,7 +762,7 @@ class BertForPreTraining(BertPreTrainedModel):
if self.config.torchscript:
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
else:
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, head_mask=None):
......@@ -868,7 +868,7 @@ class BertForMaskedLM(BertPreTrainedModel):
if self.config.torchscript:
self.cls.predictions.decoder.weight = nn.Parameter(input_embeddings.clone())
else:
self.cls.predictions.decoder.weight = input_embeddings # Tied weights
self.cls.predictions.decoder = self.bert.embeddings.word_embeddings # Tied weights
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
"""
......
......@@ -566,7 +566,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self.lm_head = self.transformer.wte # Tied weights
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
"""
......@@ -662,7 +662,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self.lm_head = self.transformer.wte # Tied weights
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, past=None, head_mask=None):
......
......@@ -587,7 +587,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self.lm_head = self.transformer.tokens_embed # Tied weights
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
"""
......@@ -700,7 +700,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
if self.config.torchscript:
self.lm_head.weight = nn.Parameter(input_embeddings.clone())
else:
self.lm_head.weight = input_embeddings # Tied weights
self.lm_head = self.transformer.tokens_embed # Tied weights
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, head_mask=None):
......
......@@ -29,6 +29,7 @@ import torch
from pytorch_transformers import PretrainedConfig, PreTrainedModel
from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers.modeling_gpt2 import GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
def _config_zero_init(config):
......@@ -470,6 +471,79 @@ class ModelUtilsTest(unittest.TestCase):
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
def test_resize_tokens_embeddings(self):
    """Resizing the token embeddings updates config/matrix size and preserves existing rows."""
    logging.basicConfig(level=logging.INFO)
    for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        config = BertConfig.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name)
        original_vocab_size = config.vocab_size

        # Snapshot the embedding matrix before any resizing.
        snapshot = model.embeddings.word_embeddings.weight.clone()

        # Growing the vocabulary must update both the config and the matrix.
        model.resize_token_embeddings(original_vocab_size + 10)
        self.assertEqual(model.config.vocab_size, original_vocab_size + 10)
        self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], snapshot.shape[0] + 10)

        # Shrinking back must restore the original dimensions.
        model.resize_token_embeddings(original_vocab_size)
        self.assertEqual(model.config.vocab_size, original_vocab_size)
        self.assertEqual(model.embeddings.word_embeddings.weight.shape[0], snapshot.shape[0])

        # The grow/shrink round trip must not have touched the surviving rows.
        rows_unchanged = all(
            not (row_before.data.ne(row_after.data).sum() > 0)
            for row_before, row_after in zip(snapshot, model.embeddings.word_embeddings.weight)
        )
        self.assertTrue(rows_unchanged)
def test_tie_model_weights(self):
    """Verify that the lm_head stays tied to the input embeddings.

    An in-place change to either side of the tie, and a call to
    resize_token_embeddings, must be reflected in both the input
    embedding layer and the decoding (lm_head) layer.
    """
    logging.basicConfig(level=logging.INFO)

    def check_same_values(layer_1, layer_2):
        # Element-wise comparison of the two weight matrices.
        equal = True
        for p1, p2 in zip(layer_1.weight, layer_2.weight):
            if p1.data.ne(p2.data).sum() > 0:
                equal = False
        return equal

    for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        config = GPT2Config.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)

        # Get the embeddings and decoding layer.
        embeddings = model.transformer.wte
        decoding = model.lm_head

        # Bug fix: the original used assertTrue(a, b); assertTrue treats its
        # second argument as the failure *message*, so the shape checks always
        # passed. assertEqual actually compares the two shapes.
        self.assertEqual(embeddings.weight.shape, decoding.weight.shape)
        self.assertTrue(check_same_values(embeddings, decoding))

        # An in-place change to the embeddings must show up in the decoder.
        embeddings.weight.data.div_(2)
        self.assertEqual(embeddings.weight.shape, decoding.weight.shape)
        self.assertTrue(check_same_values(embeddings, decoding))

        # An in-place change to the decoder must show up in the embeddings.
        decoding.weight.data.div_(4)
        self.assertEqual(embeddings.weight.shape, decoding.weight.shape)
        self.assertTrue(check_same_values(embeddings, decoding))

        # The weights must remain tied after a vocabulary resize.
        model.resize_token_embeddings(config.vocab_size + 10)
        # NOTE(review): resize_token_embeddings may replace the embedding
        # module and re-tie; re-fetch both layers so we compare the live
        # modules rather than stale references (a no-op if nothing changed).
        embeddings = model.transformer.wte
        decoding = model.lm_head
        decoding.weight.data.mul_(20)
        self.assertEqual(embeddings.weight.shape, decoding.weight.shape)
        self.assertTrue(check_same_values(embeddings, decoding))
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment