Commit 2ed65b68 authored by Naman Goyal, committed by Facebook Github Bot
Browse files

fixed corner case in mlm criterion when all tokens get masked

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/869

Reviewed By: myleott

Differential Revision: D17531776

Pulled By: myleott

fbshipit-source-id: 349c9449a0a7db5d3bb8449561302d4220cfa60c
parent 3f4fc501
......@@ -31,8 +31,17 @@ class MaskedLmLoss(FairseqCriterion):
"""
# compute MLM loss
masked_tokens = sample['target'].ne(self.padding_idx)
sample_size = masked_tokens.int().sum().item()
# (Rare case) When all tokens are masked, the model results in empty
# tensor and gives CUDA error.
if sample_size == 0:
masked_tokens = None
logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
targets = model.get_targets(sample, [logits])
if sample_size != 0:
targets = targets[masked_tokens]
loss = F.nll_loss(
......@@ -45,9 +54,6 @@ class MaskedLmLoss(FairseqCriterion):
reduction='sum',
ignore_index=self.padding_idx,
)
sample_size = masked_tokens.int().sum().item()
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(loss.data) if reduce else loss.data,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment