Commit e8d609a8 authored by Myle Ott, committed by Facebook Github Bot

Add new Masked LM task + criterion

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/761

Differential Revision: D16421335

Pulled By: myleott

fbshipit-source-id: 257d92c2b90361147642e2baa38486b4d18f6297
parent 654affc0
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('masked_lm')
class MaskedLmLoss(FairseqCriterion):
    """
    Implementation for the loss used in masked language model (MLM) training.
    """

    def __init__(self, args, task):
        super().__init__(args, task)
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # compute MLM loss
        logits = model(**sample['net_input'], last_state_only=True)[0]
        targets = model.get_targets(sample, [logits])
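        # (explanatory note, not in the original diff) the log-softmax below is
        # computed in float32 so the summed loss stays numerically stable when
        # training in fp16; ignore_index skips padded target positions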
        loss = F.nll_loss(
            F.log_softmax(
                logits.view(-1, logits.size(-1)),
                dim=-1,
                dtype=torch.float32,
            ),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

        sample_size = targets.ne(self.padding_idx).int().sum().item()

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        agg_output = {
            'loss': loss / sample_size / math.log(2),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return agg_output
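Note on the aggregation above (not part of the diff itself): dividing the summed natural-log loss by sample_size and by math.log(2) gives bits per masked token, so a validation perplexity can be recovered from the aggregated output with a one-liner such as:

    def masked_lm_perplexity(agg_output):
        # agg_output is the dict returned by aggregate_logging_outputs();
        # its 'loss' entry is already expressed in bits per masked token
        return 2 ** agg_output['loss']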
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import os
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import (
    ConcatDataset,
    data_utils,
    Dictionary,
    encoders,
    IdDataset,
    indexed_dataset,
    MaskTokensDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PadDataset,
    PrependTokenDataset,
    SortDataset,
    TokenBlockDataset,
)
from fairseq.tasks import FairseqTask, register_task
@register_task('masked_lm')
class MaskedLMTask(FairseqTask):
    """Task for training masked language models (e.g., BERT, RoBERTa)."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner')
        parser.add_argument('--sample-break-mode', default='complete',
                            choices=['none', 'complete', 'complete_doc', 'eos'],
                            help='If omitted or "none", fills each sample with tokens-per-sample '
                                 'tokens. If set to "complete", splits samples only at the end '
                                 'of sentence, but may include multiple sentences per sample. '
                                 '"complete_doc" is similar but respects doc boundaries. '
                                 'If set to "eos", includes only one sentence per sample.')
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments '
                                 'per sample for BERT dataset')
        parser.add_argument('--mask-prob', default=0.15, type=float,
                            help='probability of replacing a token with mask')
        parser.add_argument('--leave-unmasked-prob', default=0.1, type=float,
                            help='probability that a masked token is unmasked')
        parser.add_argument('--random-token-prob', default=0.1, type=float,
                            help='probability of replacing a token with a random token')
        parser.add_argument('--freq-weighted-replacement', action='store_true',
                            help='sample random replacement words based on word frequencies')
        parser.add_argument('--mask-whole-words', default=False, action='store_true',
                            help='mask whole words; you may also want to set --bpe')
    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        # add mask token
        self.mask_idx = dictionary.add_symbol('<mask>')

    @classmethod
    def setup_task(cls, args, **kwargs):
        paths = args.data.split(':')
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
        print('| dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)
    def load_dataset(self, split, epoch=0, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.source_dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample - 1,  # one less for <s>
            pad=self.source_dictionary.pad(),
            eos=self.source_dictionary.eos(),
            break_mode=self.args.sample_break_mode,
        )

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())

        # create masked input and targets
        mask_whole_words = None
        if self.args.mask_whole_words:
            bpe = encoders.build_bpe(self.args)
            if bpe is not None:

                def is_beginning_of_word(i):
                    if i < self.source_dictionary.nspecial:
                        # special elements are always considered beginnings
                        return True
                    tok = self.source_dictionary[i]
                    if tok.startswith('madeupword'):
                        return True
                    try:
                        return bpe.is_beginning_of_word(tok)
                    except ValueError:
                        return True

                mask_whole_words = torch.ByteTensor(list(
                    map(is_beginning_of_word, range(len(self.source_dictionary)))
                ))
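        # Note on the defaults (explanatory comment, not in the original diff):
        # with mask_prob=0.15, leave_unmasked_prob=0.1 and random_token_prob=0.1,
        # the masking follows the BERT recipe: 15% of tokens are selected as
        # prediction targets; of those, 80% become <mask>, 10% are swapped for
        # a random vocabulary token, and 10% keep their original token.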
        src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
            dataset,
            self.source_dictionary,
            pad_idx=self.source_dictionary.pad(),
            mask_idx=self.mask_idx,
            seed=self.args.seed,
            mask_prob=self.args.mask_prob,
            leave_unmasked_prob=self.args.leave_unmasked_prob,
            random_token_prob=self.args.random_token_prob,
            freq_weighted_replacement=self.args.freq_weighted_replacement,
            mask_whole_words=mask_whole_words,
        )

        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(src_dataset))
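        # Explanatory note (not part of the original diff): assuming SortDataset
        # follows numpy lexsort semantics (the last sort key is primary), samples
        # are ordered primarily by length, with the epoch-seeded permutation above
        # breaking ties, so batches stay tightly packed while their internal
        # ordering changes when the split is reloaded for a new epoch.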
        self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': PadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                            left_pad=False,
                        ),
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': PadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                        left_pad=False,
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )
    def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
        if self.args.also_lowercase_words:
            raise NotImplementedError
        src_dataset = PadDataset(
            TokenBlockDataset(
                src_tokens,
                src_lengths,
                self.args.tokens_per_sample - 1,  # one less for <s>
                pad=self.source_dictionary.pad(),
                eos=self.source_dictionary.eos(),
                break_mode='eos',
            ),
            pad_idx=self.source_dictionary.pad(),
            left_pad=False,
        )
        src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
        src_dataset = NestedDictionaryDataset(
            {
                'id': IdDataset(),
                'net_input': {
                    'src_tokens': src_dataset,
                    'src_lengths': NumelDataset(src_dataset, reduce=False),
                },
            },
            sizes=src_lengths,
        )
        if sort:
            src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
        return src_dataset
    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary

    def get_average_masked_score(self, model, src_tokens, mask, **net_input):
        """Mask a set of tokens and return their average score."""
        masked_tokens = src_tokens.clone()
        masked_tokens[mask.byte()] = self.mask_idx
        net_output = model(src_tokens=masked_tokens, **net_input, last_state_only=True)
        lprobs = F.log_softmax(net_output[0], dim=-1, dtype=torch.float32)
        lprobs = lprobs.gather(-1, src_tokens.unsqueeze(-1)).squeeze(-1)
        mask = mask.type_as(lprobs)
        score = (lprobs * mask).sum(dim=-1) / mask.sum(dim=-1)
        return score
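    # Usage sketch (hypothetical, not part of this diff): score a candidate
    # span by masking its positions in a single forward pass and averaging the
    # model's log-probability of the original tokens at those positions:
    #
    #   mask = torch.zeros_like(src_tokens)
    #   mask[:, span_start:span_end] = 1  # positions of the candidate tokens
    #   scores = task.get_average_masked_score(
    #       model, src_tokens, mask, src_lengths=src_lengths,
    #   )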