Commit 100e3b6f authored by Lysandre, committed by Lysandre Debut

Bias should be resized with the weights

Created a link between the linear layer's bias and the module's `bias` attribute. This changes nothing for the user or for the conversion scripts, but it allows the `resize_token_embeddings` method to resize the bias along with the decoder weights.

Added a test.
parent 6c32d8bb
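The change ties the decoder's bias to the head module's `bias` attribute so that a resize touches both. Below is a minimal, self-contained sketch of that idea; the `TinyLMHead` class and the pad-based resize are illustrative assumptions for the demo, not the library's actual `resize_token_embeddings` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLMHead(nn.Module):
    """Toy LM head mirroring the pattern in this patch (illustration only)."""

    def __init__(self, hidden_size=8, vocab_size=10):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Same trick as the patch: both names now reference a single Parameter,
        # so whatever resizes the decoder's bias also resizes `self.bias`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        # The decoder applies the shared bias itself.
        return self.decoder(hidden_states)


head = TinyLMHead()

# Hypothetical resize of the output layer to a larger vocabulary:
# zero-pad the decoder weight and bias in place.
new_vocab_size = 13
extra = new_vocab_size - head.decoder.out_features
head.decoder.weight.data = F.pad(head.decoder.weight.data, (0, 0, 0, extra))
head.decoder.bias.data = F.pad(head.decoder.bias.data, (0, extra))
head.decoder.out_features = new_vocab_size

# Because the Parameter is shared, the head-level bias was resized too,
# and a forward pass still works with the enlarged vocabulary.
print(head.bias.shape)                # torch.Size([13])
print(head(torch.randn(1, 8)).shape)  # torch.Size([1, 13])
```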
@@ -579,6 +579,9 @@ class AlbertMLMHead(nn.Module):
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
        self.activation = ACT2FN[config.hidden_act]

+       # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+       self.decoder.bias = self.bias
+
    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
...
@@ -481,6 +481,9 @@ class BertLMPredictionHead(nn.Module):
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

+       # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+       self.decoder.bias = self.bias
+
    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
...
@@ -306,6 +306,9 @@ class RobertaLMHead(nn.Module):
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

+       # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+       self.decoder.bias = self.bias
+
    def forward(self, features, **kwargs):
        x = self.dense(features)
        x = gelu(x)
...
@@ -487,6 +487,8 @@ class ModelTesterMixin:
        self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
        # Check that it actually resizes the embeddings matrix
        self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
+       # Check that the model can still do a forward pass successfully (every parameter should be resized)
+       model(**inputs_dict)

        # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
        model_embed = model.resize_token_embeddings(model_vocab_size - 15)
@@ -494,6 +496,11 @@ class ModelTesterMixin:
        # Check that it actually resizes the embeddings matrix
        self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
+       # Check that the model can still do a forward pass successfully (every parameter should be resized)
+       # Input ids should be clamped to the maximum size of the vocabulary
+       inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
+       model(**inputs_dict)
+
        # Check that adding and removing tokens has not modified the first part of the embedding matrix.
        models_equal = True
        for p1, p2 in zip(cloned_embeddings, model_embed.weight):
...
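For context, here is a hedged usage sketch of what the new test exercises: resize a model's token embeddings, then run a forward pass, clamping the input ids into the reduced vocabulary first. The checkpoint name and tensor shapes are arbitrary choices for illustration.

```python
import torch
from transformers import BertForMaskedLM

# Assumes the `bert-base-uncased` checkpoint can be downloaded or is cached locally.
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
old_vocab_size = model.config.vocab_size

# Shrink the vocabulary; with the bias tied to the decoder, the LM-head bias
# is resized along with the embedding and decoder weights.
model.resize_token_embeddings(old_vocab_size - 15)

# Clamp the ids into the reduced vocabulary before the forward pass,
# mirroring what the new test does.
input_ids = torch.randint(0, old_vocab_size, (1, 8))
input_ids.clamp_(max=old_vocab_size - 15 - 1)
outputs = model(input_ids)
```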