Commit 870d7163 authored by thomwolf's avatar thomwolf
Browse files

fixing target size in crossentropy losses

parent 982339d8
......@@ -678,8 +678,8 @@ class BertForPreTraining(PreTrainedBertModel):
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
else:
......@@ -741,7 +741,7 @@ class BertForMaskedLM(PreTrainedBertModel):
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
return masked_lm_loss
else:
return prediction_scores
......@@ -803,7 +803,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
return next_sentence_loss
else:
return seq_relationship_score
......@@ -856,6 +856,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
"""
def __init__(self, config, num_labels=2):
super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
......@@ -868,7 +869,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits, labels)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss, logits
else:
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment