"tests/test_tokenization_fast.py" did not exist on "ba8c4d0ac04acfcdbdeaed954f698d6d5ec3e532"
Unverified Commit eae7a96b authored by Ibraheem Moosa, committed by GitHub

Optimize Token Classification models for TPU (#13096)

* Optimize Token Classification models for TPU

As per the XLA documentation, XLA cannot handle masked indexing well. For that reason, the token classification
models for BERT and several other architectures use an implementation based on `torch.where`, which
works well on TPU.

The ALBERT token classification model still uses masked indexing, which causes performance issues
on TPU. This PR fixes the issue by following the BERT implementation.
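
For illustration only (not part of this commit), the sketch below contrasts the two ways of computing the token
classification loss as free-standing functions; the function names and argument shapes are hypothetical stand-ins
for the model code:

# Minimal sketch contrasting masked indexing with the torch.where-based loss.
# `logits` is (batch, seq_len, num_labels); `labels` and `attention_mask` are (batch, seq_len).
import torch
from torch.nn import CrossEntropyLoss


def loss_with_masked_indexing(logits, labels, attention_mask, num_labels):
    # Boolean indexing yields tensors whose size depends on the mask contents,
    # so XLA must recompile whenever the number of active tokens changes.
    loss_fct = CrossEntropyLoss()
    active_loss = attention_mask.view(-1) == 1
    active_logits = logits.view(-1, num_labels)[active_loss]
    active_labels = labels.view(-1)[active_loss]
    return loss_fct(active_logits, active_labels)


def loss_with_torch_where(logits, labels, attention_mask, num_labels):
    # torch.where keeps all shapes static: padded positions are mapped to
    # ignore_index (-100 by default) and dropped from the loss, so the
    # compiled XLA graph can be reused across batches.
    loss_fct = CrossEntropyLoss()
    active_loss = attention_mask.view(-1) == 1
    active_logits = logits.view(-1, num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    return loss_fct(active_logits, active_labels)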

* Same fix for ELECTRA

* Same fix for LayoutLM
parent e02ed0ee
@@ -1150,8 +1150,10 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -1259,8 +1259,10 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
             # Only keep active parts of the loss
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.config.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.config.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
@@ -1173,8 +1173,10 @@ class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
             if attention_mask is not None:
                 active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))