Commit a725db4f authored by thomwolf

fixing BertForQuestionAnswering loss computation

parent bb5ce67a
@@ -384,16 +384,16 @@ class BertForSequenceClassification(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)

-        def init_weights(m):
-            if isinstance(m, (nn.Linear, nn.Embedding)):
+        def init_weights(module):
+            if isinstance(module, (nn.Linear, nn.Embedding)):
                 # Slightly different from the TF version which uses truncated_normal for initialization
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                m.weight.data.normal_(config.initializer_range)
-            elif isinstance(m, BERTLayerNorm):
-                m.beta.data.normal_(config.initializer_range)
-                m.gamma.data.normal_(config.initializer_range)
-            if isinstance(m, nn.Linear):
-                m.bias.data.zero_()
+                module.weight.data.normal_(config.initializer_range)
+            elif isinstance(module, BERTLayerNorm):
+                module.beta.data.normal_(config.initializer_range)
+                module.gamma.data.normal_(config.initializer_range)
+            if isinstance(module, nn.Linear):
+                module.bias.data.zero_()
         self.apply(init_weights)

     def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
@@ -434,16 +434,16 @@ class BertForQuestionAnswering(nn.Module):
         # self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)

-        def init_weights(m):
-            if isinstance(m, (nn.Linear, nn.Embedding)):
+        def init_weights(module):
+            if isinstance(module, (nn.Linear, nn.Embedding)):
                 # Slightly different from the TF version which uses truncated_normal for initialization
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                m.weight.data.normal_(config.initializer_range)
-            elif isinstance(m, BERTLayerNorm):
-                m.beta.data.normal_(config.initializer_range)
-                m.gamma.data.normal_(config.initializer_range)
-            if isinstance(m, nn.Linear):
-                m.bias.data.zero_()
+                module.weight.data.normal_(config.initializer_range)
+            elif isinstance(module, BERTLayerNorm):
+                module.beta.data.normal_(config.initializer_range)
+                module.gamma.data.normal_(config.initializer_range)
+            if isinstance(module, nn.Linear):
+                module.bias.data.zero_()
         self.apply(init_weights)

     def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
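For reference, the `init_weights(module)` callback renamed in the two hunks above is passed to `nn.Module.apply`, which calls it on the model and on every submodule recursively. Below is a minimal, self-contained sketch of that pattern; the toy layer sizes and the explicit mean/std arguments are illustrative only and not taken from this repository:

import torch.nn as nn

# Toy stand-in for the heads above; the layer sizes are hypothetical.
model = nn.Sequential(
    nn.Embedding(100, 8),
    nn.Linear(8, 2),
)

def init_weights(module):
    # nn.Module.apply() invokes this once per submodule (and once for the root).
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

model.apply(init_weights)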
@@ -451,21 +451,13 @@ class BertForQuestionAnswering(nn.Module):
         sequence_output = all_encoder_layers[-1]
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1)
+        end_logits = end_logits.squeeze(-1)

         if start_positions is not None and end_positions is not None:
-            batch_size, seq_length = input_ids.size()
-
-            def compute_loss(logits, positions):
-                max_position = positions.max().item()
-                one_hot = torch.FloatTensor(batch_size, max(max_position, seq_length) +1).zero_()
-                one_hot = one_hot.scatter_(1, positions.cpu(), 1) # Do this on CPU
-                one_hot = one_hot[:, :seq_length].to(input_ids.device)
-                log_probs = nn.functional.log_softmax(logits, dim = -1).view(batch_size, seq_length)
-                loss = -torch.mean(torch.sum(one_hot*log_probs), dim = -1)
-                return loss
-
-            start_loss = compute_loss(start_logits, start_positions)
-            end_loss = compute_loss(end_logits, end_positions)
+            loss_fct = CrossEntropyLoss()
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
             return total_loss, (start_logits, end_logits)
         else:
...
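The loss change above swaps the hand-rolled one-hot/log-softmax helper for torch.nn.CrossEntropyLoss, applied separately to the start and end logits after they are squeezed to shape [batch, seq_length]. A minimal sketch of the new computation on dummy tensors (the batch size, sequence length, and random inputs are made up for illustration):

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length = 3, 16                                 # hypothetical shapes
start_logits = torch.randn(batch_size, seq_length)             # [batch, seq] after squeeze(-1)
end_logits = torch.randn(batch_size, seq_length)
start_positions = torch.randint(0, seq_length, (batch_size,))  # gold start token indices
end_positions = torch.randint(0, seq_length, (batch_size,))    # gold end token indices

loss_fct = CrossEntropyLoss()                                  # log_softmax + NLL, averaged over the batch
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
print(total_loss.item())

CrossEntropyLoss treats each token position as a class, so it computes the mean negative log-likelihood of the gold start and end positions, which is what the removed compute_loss helper was attempting with the explicit one-hot and log_softmax.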