Unverified Commit e8ce63ff authored by srush, committed by GitHub

Change masking to direct labeling for TPU support. (#2982)

* change masking to direct labelings

* fix black

* switch to ignore index

* .

* fix black
parent 7a7ee28c
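
Why this change helps on TPU (a summary for context; not part of the commit message): XLA compiles one graph per distinct tensor shape, and boolean-mask indexing such as logits.view(-1, self.num_labels)[active_loss] returns a tensor whose first dimension depends on how many tokens are active, so every new padding pattern can trigger a recompilation. Rewriting the masking with torch.where keeps the flattened shape static and instead relabels padded positions with the loss function's ignore_index, which CrossEntropyLoss skips. A toy sketch of the shape difference, using standalone tensors not taken from this diff:

    import torch

    num_labels = 3
    logits = torch.randn(2, 4, num_labels)        # (batch, seq_len, num_labels)
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])

    active_loss = attention_mask.view(-1) == 1

    # Boolean indexing: size depends on the mask contents (5 active tokens here),
    # so the shape changes from batch to batch.
    print(logits.view(-1, num_labels)[active_loss].shape)  # torch.Size([5, 3])

    # Direct labeling keeps the statically known flattened shape.
    print(logits.view(-1, num_labels).shape)                # torch.Size([8, 3])
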
@@ -1382,8 +1382,10 @@ class BertForTokenClassification(BertPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
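
The two formulations are numerically equivalent under CrossEntropyLoss's default reduction='mean', which averages only over targets not equal to ignore_index (-100 by default). A standalone check with random tensors, not part of the commit:

    import torch
    from torch.nn import CrossEntropyLoss

    num_labels = 3
    logits = torch.randn(2, 4, num_labels)
    labels = torch.randint(0, num_labels, (2, 4))
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])

    loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
    active_loss = attention_mask.view(-1) == 1

    # Old approach: gather only the active positions (dynamic shape).
    masked_loss = loss_fct(logits.view(-1, num_labels)[active_loss],
                           labels.view(-1)[active_loss])

    # New approach: full static shape, padding relabeled to ignore_index.
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    direct_loss = loss_fct(logits.view(-1, num_labels), active_labels)

    assert torch.allclose(masked_loss, direct_loss)
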
@@ -818,8 +818,10 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -542,13 +542,16 @@ class RobertaForTokenClassification(BertPreTrainedModel):
         logits = self.classifier(sequence_output)

         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -1264,8 +1264,10 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
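
A related convention (background, not in this diff): token-classification pipelines of this era commonly assigned the loss's ignore_index directly to padding and continuation sub-tokens when building features, so those positions drop out of the loss even before the attention_mask branch runs. A small runnable sketch of that labeling step, with hypothetical inputs:

    from torch.nn import CrossEntropyLoss

    pad_token_label_id = CrossEntropyLoss().ignore_index  # -100 by default

    # Hypothetical example: word-level labels spread over sub-tokens; only the
    # first sub-token of each word keeps a real label, the rest are ignored.
    word_labels = [2, 0, 1]          # label ids, one per word
    subtokens_per_word = [1, 3, 2]   # how many pieces the tokenizer split each word into
    label_ids = []
    for lab, n in zip(word_labels, subtokens_per_word):
        label_ids.append(lab)
        label_ids.extend([pad_token_label_id] * (n - 1))
    print(label_ids)  # [2, 0, -100, -100, 1, -100]
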