"examples/vscode:/vscode.git/clone" did not exist on "f6c0680d36236bd149e68ed2ee640acbcd2f09ef"
Unverified Commit bd746326 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #251 from Iwontbecreative/active_loss_tok_classif

Only keep the active part mof the loss for token classification
parents fd223374 f3bda235
......@@ -1025,7 +1025,14 @@ class BertForTokenClassification(PreTrainedBertModel):
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment