Unverified Commit a81cf9ee authored by Hovnatan Karapetyan, committed by GitHub

Fix 29807, sinusoidal positional encodings overwritten by post_init() (#29813)

* Check for requires_grad when initializing weights

* Add unit test

* Move sinusoidal positional encoding generation after post_init()

* Add modules to the skip-init list

* Move create_sinusoidal_embeddings to _init_weights
parent cefb819f
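The underlying bug: `PreTrainedModel.post_init()` walks every submodule through `_init_weights`, which re-initializes each `nn.Embedding` with random values. Because the sinusoidal table used to be written in `Embeddings.__init__`, i.e. before `post_init()` ran, it was silently overwritten. A minimal sketch of the failure mode, using hypothetical toy classes (`ToyEmbeddings`, `init_weights`) rather than the transformers API:

import torch
import torch.nn as nn

# Hypothetical toy reproduction of the bug; ToyEmbeddings and init_weights
# stand in for transformers' Embeddings and DistilBertPreTrainedModel._init_weights.
class ToyEmbeddings(nn.Module):
    def __init__(self, n_pos: int, dim: int):
        super().__init__()
        self.position_embeddings = nn.Embedding(n_pos, dim)
        # Old pattern: write the handcrafted table in __init__ ...
        with torch.no_grad():
            self.position_embeddings.weight.zero_()  # stand-in for the sinusoidal fill

def init_weights(module: nn.Module) -> None:
    # ... then a post_init()-style pass re-initializes every nn.Embedding,
    # clobbering whatever __init__ wrote.
    if isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)

toy = ToyEmbeddings(4, 8)
toy.apply(init_weights)  # roughly what post_init() does under the hood
print(toy.position_embeddings.weight.abs().sum())  # non-zero: the table is gone

The fix below therefore deletes the table generation from `__init__` and performs it inside `_init_weights` itself, so that the initialization pass installs the sinusoids rather than destroying them.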
src/transformers/models/distilbert/modeling_distilbert.py
@@ -106,10 +106,6 @@ class Embeddings(nn.Module):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
-        if config.sinusoidal_pos_embds:
-            create_sinusoidal_embeddings(
-                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-            )
 
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
@@ -634,6 +630,10 @@ class DistilBertPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
+            create_sinusoidal_embeddings(
+                self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight
+            )
 
 
 DISTILBERT_START_DOCSTRING = r"""
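For reference, the table follows the fixed sinusoidal formula from "Attention Is All You Need". A sketch of what `create_sinusoidal_embeddings` computes; the real helper in `modeling_distilbert.py` fills a given tensor in place, whereas this standalone version returns a new one:

import numpy as np
import torch

# Sketch of the standard sinusoidal position table: even dimensions get sin,
# odd dimensions get cos, with wavelengths spaced geometrically by 10000^(2i/dim).
def sinusoidal_table(n_pos: int, dim: int) -> torch.Tensor:
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out = torch.zeros(n_pos, dim)
    out[:, 0::2] = torch.from_numpy(np.sin(position_enc[:, 0::2])).float()  # even dims: sin
    out[:, 1::2] = torch.from_numpy(np.cos(position_enc[:, 1::2])).float()  # odd dims: cos
    return out

Generating the table inside `_init_weights` means `post_init()` itself writes the sinusoids as part of initialization, so there is no later init pass left to overwrite them.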
tests/models/distilbert/test_modeling_distilbert.py
@@ -37,6 +37,7 @@ if is_torch_available():
         DistilBertForTokenClassification,
         DistilBertModel,
     )
+    from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings
 
 
 class DistilBertModelTester(object):
@@ -238,6 +239,15 @@ class DistilBertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
 
+    def test_distilbert_model_with_sinusoidal_encodings(self):
+        config = DistilBertConfig(sinusoidal_pos_embds=True)
+        model = DistilBertModel(config=config)
+        sinusoidal_pos_embds = torch.empty((config.max_position_embeddings, config.dim), dtype=torch.float32)
+        _create_sinusoidal_embeddings(config.max_position_embeddings, config.dim, sinusoidal_pos_embds)
+        self.model_tester.parent.assertTrue(
+            torch.equal(model.embeddings.position_embeddings.weight, sinusoidal_pos_embds)
+        )
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
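The regression is also easy to check outside the test suite. A hypothetical quick check (assuming a transformers install that includes this fix): at position 0 the table must be the sin(0)/cos(0) pattern, i.e. [0, 1, 0, 1, ...], immediately after the model is constructed:

import torch
from transformers import DistilBertConfig, DistilBertModel

config = DistilBertConfig(sinusoidal_pos_embds=True)
model = DistilBertModel(config)  # __init__ ends with post_init(), which now fills the table
w = model.embeddings.position_embeddings.weight
assert torch.allclose(w[0, 0::2], torch.zeros_like(w[0, 0::2]))  # sin(0) == 0
assert torch.allclose(w[0, 1::2], torch.ones_like(w[0, 1::2]))   # cos(0) == 1

Within the repo, `pytest tests/models/distilbert/test_modeling_distilbert.py -k sinusoidal` runs just the new test.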