Commit 02b74c58 authored by Jiatao Gu, committed by Facebook Github Bot
Browse files

Fix the random mask function for the CMLM model

Summary: The original implementation of the random mask differed from what the paper describes.

Reviewed By: kahne

Differential Revision: D17652564

fbshipit-source-id: 238a9158041b3ff2482ee50ce6151c3f77f0b2c1
parent cce92bdd
......@@ -502,7 +502,7 @@ def add_generation_args(parser):
group.add_argument('--print-step', action='store_true')
# arguments for iterative refinement generator
group.add_argument('---iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
help='if > 0.0, it penalized early-stopping in decoding.')
group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N',
help='maximum iterations for iterative refinement.')
......
......@@ -5,6 +5,7 @@
import torch
from fairseq.utils import new_arange
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset
......@@ -87,14 +88,19 @@ class TranslationLevenshteinTask(TranslationTask):
eos = self.tgt_dict.eos()
unk = self.tgt_dict.unk()
target_mask = target_tokens.eq(bos) | target_tokens.eq(
eos) | target_tokens.eq(pad)
target_masks = target_tokens.ne(pad) & \
target_tokens.ne(bos) & \
target_tokens.ne(eos)
target_score = target_tokens.clone().float().uniform_()
target_score.masked_fill_(target_mask, 1.0)
target_score.masked_fill_(~target_masks, 2.0)
target_length = target_masks.sum(1).float()
target_length = target_length * target_length.clone().uniform_()
target_length = target_length + 1 # make sure to mask at least one token.
_, target_rank = target_score.sort(1)
target_cutoff = new_arange(target_rank) < target_length[:, None].long()
prev_target_tokens = target_tokens.masked_fill(
target_score < target_score.new_zeros(target_score.size(0),
1).uniform_(), unk)
target_cutoff.scatter(1, target_rank, target_cutoff), unk)
return prev_target_tokens
def _full_mask(target_tokens):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment