Unverified commit 554d333e, authored by Yih-Dar, committed by GitHub

Fix loss calculation in TFXXXForTokenClassification models (#15294)



* Fix loss calculation in TFFunnelForTokenClassification

* revert the change in TFFunnelForTokenClassification

* fix FunnelForTokenClassification loss

* fix other TokenClassification loss

* fix more

* fix more

* add num_labels to ElectraForTokenClassification

* revert the change to research projects
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 44c7857b
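
A minimal sketch (not from this diff) of the behaviour the simplified loss relies on: `CrossEntropyLoss` skips positions whose label equals its default `ignore_index` of `-100`, which is the value typically assigned to padded and special tokens during label alignment, so the removed `attention_mask`-based masking is not needed.

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(2, 5, num_labels)  # (batch_size, seq_len, num_labels)
labels = torch.tensor(
    [
        [0, 2, 1, -100, -100],  # -100 marks padding / special tokens
        [1, 0, -100, -100, -100],
    ]
)

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
# Positions labelled -100 contribute nothing to the loss, so no explicit
# attention_mask handling is required here.
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
print(loss)
```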
@@ -1413,16 +1413,7 @@ class RemBertForTokenClassification(RemBertPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[2:]
@@ -1414,16 +1414,7 @@ class RobertaForTokenClassification(RobertaPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[2:]
@@ -1472,16 +1472,7 @@ class RoFormerForTokenClassification(RoFormerPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -984,16 +984,7 @@ class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[2:]
@@ -1169,16 +1169,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1680,16 +1680,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1455,16 +1455,7 @@ class {{cookiecutter.camelcase_modelname}}ForTokenClassification({{cookiecutter.
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[1:]
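
For reference, non-target positions are usually given the label `-100` at preprocessing time, which is why the simplified loss already skips them. A hypothetical alignment helper (names are illustrative, not from this repository) in the style of the token-classification examples:

```python
# Hypothetical helper (not part of this commit): map word-level labels onto
# sub-tokens, assigning -100 to special tokens and to non-first sub-tokens so
# that CrossEntropyLoss ignores them.
def align_labels_with_tokens(word_labels, word_ids):
    aligned = []
    previous_word = None
    for word_id in word_ids:
        if word_id is None:              # special or padding token
            aligned.append(-100)
        elif word_id != previous_word:   # first sub-token of a word
            aligned.append(word_labels[word_id])
        else:                            # later sub-tokens of the same word
            aligned.append(-100)
        previous_word = word_id
    return aligned


print(align_labels_with_tokens([1, 0], [None, 0, 0, 1, None, None]))
# [-100, 1, -100, 0, -100, -100]
```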