Commit a8ad8304 authored by VictorSanh

fix bugs

parent 60c984da
@@ -60,7 +60,7 @@ class DilBertConfig(PretrainedConfig):
                  attention_dropout=0.1,
                  activation='gelu',
                  initializer_range=0.02,
-                 tie_weights=True,
+                 tie_weights_=True,
                  **kwargs):
         super(DilBertConfig, self).__init__(**kwargs)
@@ -82,7 +82,7 @@ class DilBertConfig(PretrainedConfig):
             self.attention_dropout = attention_dropout
             self.activation = activation
             self.initializer_range = initializer_range
-            self.tie_weights = tie_weights
+            self.tie_weights_ = tie_weights_
         else:
             raise ValueError("First argument must be either a vocabulary size (int)"
                              "or the path to a pretrained model config file (str)")
@@ -274,13 +274,15 @@ class TransformerBlock(nn.Module):
         sa_output = self.attention(query=x, key=x, value=x, mask=attn_mask)
         if self.output_attentions:
             sa_output, sa_weights = sa_output      # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
+        else:
+            sa_output = sa_output[0]
         sa_output = self.sa_layer_norm(sa_output + x)          # (bs, seq_length, dim)

         # Feed Forward Network
         ffn_output = self.ffn(sa_output)                             # (bs, seq_length, dim)
         ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

-        output = (ffn_output)
+        output = (ffn_output,)
         if self.output_attentions:
             output = (sa_weights,) + output
         return output
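
A standalone sketch of the two fixes in this hunk (toy values, not the library's classes): `(x)` without a trailing comma is not a tuple, so the later `(sa_weights,) + output` concatenation only works with `(ffn_output,)`; and if the attention sublayer returns a tuple even when attention weights are not requested, the new `else` branch is needed to unwrap the hidden states.

# 1) Only the trailing comma builds a one-element tuple.
ffn_output = "ffn"
assert (ffn_output) == "ffn"                               # parentheses alone do nothing
assert (ffn_output,) == ("ffn",)                           # trailing comma makes a tuple
assert ("weights",) + (ffn_output,) == ("weights", "ffn")  # so prepending works

# 2) Unwrap the attention output when weights are not returned.
def fake_attention(output_attentions):
    # stand-in for self.attention: always returns a tuple
    return ("context", "weights") if output_attentions else ("context",)

output_attentions = False
sa_output = fake_attention(output_attentions)
if output_attentions:
    sa_output, sa_weights = sa_output
else:
    sa_output = sa_output[0]   # the added `else` branch in this hunk
assert sa_output == "context"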
@@ -468,36 +470,36 @@ class DilBertForMaskedLM(DilBertPreTrainedModel):
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states

-        self.encoder = DilBertModel(config)
+        self.dilbert = DilBertModel(config)
         self.vocab_transform = nn.Linear(config.dim, config.dim)
         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

         self.apply(self.init_weights)
-        self.tie_weights_()
+        self.tie_weights()

         self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

-    def tie_weights_(self):
+    def tie_weights(self):
         """
         Tying the weights of the vocabulary projection to the base token embeddings.
         """
-        if self.config.tie_weights:
-            self.vocab_projector.weight = self.encoder.embeddings.word_embeddings.weight
+        if self.config.tie_weights_:
+            self.vocab_projector.weight = self.dilbert.embeddings.word_embeddings.weight

     def forward(self,
                 input_ids: torch.tensor,
                 attention_mask: torch.tensor = None,
                 masked_lm_labels: torch.tensor = None):
-        tfmr_output = self.encoder(input_ids=input_ids,
-                                   attention_mask=attention_mask)
-        hidden_states = tfmr_output[0]                               # (bs, seq_length, dim)
+        dlbrt_output = self.dilbert(input_ids=input_ids,
+                                    attention_mask=attention_mask)
+        hidden_states = dlbrt_output[0]                              # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)       # (bs, seq_length, dim)
         prediction_logits = gelu(prediction_logits)                   # (bs, seq_length, dim)
         prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
         prediction_logits = self.vocab_projector(prediction_logits)   # (bs, seq_length, vocab_size)

-        outputs = (prediction_logits, ) + tfmr_output[2:]
+        outputs = (prediction_logits, ) + dlbrt_output[2:]
         if masked_lm_labels is not None:
             mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)),
                                          masked_lm_labels.view(-1))
...
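
A minimal sketch of the weight tying performed by the renamed `tie_weights` method above, with made-up sizes (vocab_size=30, dim=8); the real model assigns `dilbert.embeddings.word_embeddings.weight` to `vocab_projector.weight` in the same way, so both modules share one parameter.

import torch.nn as nn

word_embeddings = nn.Embedding(30, 8)   # weight shape: (vocab_size, dim)
vocab_projector = nn.Linear(8, 30)      # weight shape: (vocab_size, dim) as well

# Share the same Parameter object between the embedding and the output projection.
vocab_projector.weight = word_embeddings.weight
assert vocab_projector.weight is word_embeddings.weight
assert vocab_projector.weight.shape == (30, 8)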