Commit 32335404 authored by Naman Goyal, committed by Facebook GitHub Bot

added multilingual masked LM training (#849)

Summary:
Multilingual RoBERTa training is working with aconneau's XLM data.

Two pieces remaining:

1) `XLM` limits each batch to samples from the same language. I am not 100% sure of the reason for that, but it should be easy to implement: we could add a `batch_by_size_and_language` function in place of the default `batch_by_size` (see the first sketch after this list). If it's not critical, I would prefer to leave it out, since that keeps the code clean and simple.

2) `sample_ratio` in `ConcatDataset` only works with `int` values, tiling the datasets according to the ratio. Currently I am handling it by rounding the ratio to the first decimal place and then multiplying by 10 (see the second sketch below). We can see whether such simple heuristics are good enough; there are other options (we can talk about them offline).
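
For (1), here is a minimal sketch of what a `batch_by_size_and_language` helper could look like. It assumes fairseq's `data_utils.batch_by_size` with its usual `(indices, num_tokens_fn, max_tokens, max_sentences, required_batch_size_multiple)` signature and a hypothetical `lang_of_index` lookup; none of this is part of the diff below.

```python
from collections import defaultdict

import numpy as np

from fairseq.data import data_utils


def batch_by_size_and_language(indices, lang_of_index, num_tokens_fn,
                               max_tokens=None, max_sentences=None,
                               required_batch_size_multiple=1):
    # Bucket sample indices by language, then reuse the existing batch_by_size
    # logic within each bucket so that every batch only contains samples from
    # a single language.
    by_lang = defaultdict(list)
    for idx in indices:
        by_lang[lang_of_index(idx)].append(idx)

    batches = []
    for lang in sorted(by_lang):
        lang_indices = np.array(by_lang[lang], dtype=np.int64)
        batches.extend(data_utils.batch_by_size(
            lang_indices,
            num_tokens_fn,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        ))
    return batches
```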
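For (2), the rounding heuristic amounts to something like the following; `ratios_to_tile_counts` is a hypothetical helper used only for illustration.

```python
def ratios_to_tile_counts(sample_ratios):
    # ConcatDataset can only repeat (tile) each dataset an integer number of
    # times, so approximate a float ratio by keeping one decimal place and
    # scaling everything by 10.
    return [int(round(ratio * 10)) for ratio in sample_ratios]


print(ratios_to_tile_counts([1.0, 2.37, 0.5]))  # -> [10, 24, 5]
```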
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/849

Differential Revision: D17162460

fbshipit-source-id: d967f3d872f7a1f0aa4ea418bd362b68af9e432f
parent a8a85c26
......@@ -120,10 +120,10 @@ def load_checkpoint(args, trainer):
if extra_state is not None and not args.reset_dataloader:
# restore iterator from checkpoint
itr_state = extra_state['train_iterator']
epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'])
epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'], load_dataset=True)
epoch_itr.load_state_dict(itr_state)
else:
epoch_itr = trainer.get_train_iterator(epoch=0)
epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True)
trainer.lr_step(epoch_itr.epoch)
......
......@@ -32,6 +32,7 @@ from .prepend_dataset import PrependDataset
from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sharded_dataset import ShardedDataset
from .sort_dataset import SortDataset
......@@ -77,13 +78,14 @@ __all__ = [
'NoisingDataset',
'NumelDataset',
'NumSamplesDataset',
"OffsetTokensDataset",
'OffsetTokensDataset',
'PadDataset',
'PrependDataset',
'PrependTokenDataset',
'ReplaceDataset',
'FileAudioDataset',
"RawLabelDataset",
'RawLabelDataset',
'ResamplingDataset',
'RightPadDataset',
'RoundRobinZipDatasets',
'ShardedDataset',
......@@ -94,6 +96,6 @@ __all__ = [
'TokenBlockDataset',
'TransformEosDataset',
'TransformEosLangPairDataset',
"TruncateDataset",
'TruncateDataset',
'TruncatedDictionary',
]
......@@ -70,9 +70,15 @@ class ConcatDataset(FairseqDataset):
@property
def sizes(self):
return np.concatenate(
[np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)]
)
_dataset_sizes = []
for ds, sr in zip(self.datasets, self.sample_ratios):
if isinstance(ds.sizes, np.ndarray):
_dataset_sizes.append(np.tile(ds.sizes, sr))
else:
# Only support underlying datasets whose sizes list holds a single size array.
assert isinstance(ds.sizes, list)
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
return np.concatenate(_dataset_sizes)
@property
def supports_prefetch(self):
......
......@@ -79,6 +79,8 @@ class ResamplingDataset(BaseWrapperDataset):
@property
def sizes(self):
if isinstance(self.dataset.sizes, list):
return [s[self._cur_indices.array] for s in self.dataset.sizes]
return self.dataset.sizes[self._cur_indices.array]
def num_tokens(self, index):
......
......@@ -291,7 +291,8 @@ class RobertaEncoder(FairseqDecoder):
def extract_features(self, src_tokens, return_all_hiddens=False, **unused):
inner_states, _ = self.sentence_encoder(
src_tokens, last_state_only=not return_all_hiddens,
src_tokens,
last_state_only=not return_all_hiddens,
)
features = inner_states[-1]
return features, {'inner_states': inner_states if return_all_hiddens else None}
......@@ -332,3 +333,13 @@ def roberta_large_architecture(args):
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
base_architecture(args)
@register_model_architecture('roberta', 'xlm')
def xlm_architecture(args):
args.encoder_layers = getattr(args, 'encoder_layers', 16)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1280)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1280*4)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
base_architecture(args)
......@@ -24,6 +24,7 @@ class FairseqTask(object):
def __init__(self, args):
self.args = args
self.datasets = {}
self.epoch_iter = None
@classmethod
def load_dictionary(cls, filename):
......@@ -124,6 +125,12 @@ class FairseqTask(object):
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
# For the default fairseq task, return the same iterator across epochs,
# since the datasets are not dynamic. This can be overridden in
# task-specific settings.
if self.epoch_iter is not None:
return self.epoch_iter
assert isinstance(dataset, FairseqDataset)
# initialize the dataset with the correct starting epoch
......@@ -146,7 +153,7 @@ class FairseqTask(object):
)
# return a reusable, sharded iterator
return iterators.EpochBatchIterator(
self.epoch_iter = iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=batch_sampler,
......@@ -156,6 +163,7 @@ class FairseqTask(object):
num_workers=num_workers,
epoch=epoch,
)
return self.epoch_iter
def build_model(self, args):
"""
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import torch
from fairseq.data import (
data_utils,
Dictionary,
encoders,
ConcatDataset,
IdDataset,
MaskTokensDataset,
NestedDictionaryDataset,
NumelDataset,
NumSamplesDataset,
PadDataset,
PrependTokenDataset,
RawLabelDataset,
ResamplingDataset,
SortDataset,
TokenBlockDataset,
)
from fairseq.tasks import FairseqTask, register_task
@register_task('multilingual_masked_lm')
class MultiLingualMaskedLMTask(FairseqTask):
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='colon-separated list of data directories, \
which will be iterated over across epochs in a round-robin manner')
parser.add_argument('--sample-break-mode', default='complete',
choices=['none', 'complete', 'complete_doc', 'eos'],
help='If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
'of sentence, but may include multiple sentences per sample. '
'"complete_doc" is similar but respects doc boundaries. '
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments '
'per sample for BERT dataset')
parser.add_argument('--mask-prob', default=0.15, type=float,
help='probability of replacing a token with mask')
parser.add_argument('--leave-unmasked-prob', default=0.1, type=float,
help='probability that a masked token is unmasked')
parser.add_argument('--random-token-prob', default=0.1, type=float,
help='probability of replacing a token with a random token')
parser.add_argument('--freq-weighted-replacement', action='store_true',
help='sample random replacement words based on word frequencies')
parser.add_argument('--mask-whole-words', default=False, action='store_true',
help='mask whole words; you may also want to set --bpe')
parser.add_argument('--multilang-sampling-alpha', type=float, default=1.0,
help='smoothing alpha for sampling ratios across multiple datasets')
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
# add mask token
self.mask_idx = dictionary.add_symbol('<mask>')
@classmethod
def setup_task(cls, args, **kwargs):
paths = args.data.split(':')
assert len(paths) > 0
dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def _get_whole_word_mask(self):
# create masked input and targets
if self.args.mask_whole_words:
bpe = encoders.build_bpe(self.args)
if bpe is not None:
def is_beginning_of_word(i):
if i < self.source_dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = self.source_dictionary[i]
if tok.startswith('madeupword'):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(list(
map(is_beginning_of_word, range(len(self.source_dictionary)))
))
else:
mask_whole_words = None
return mask_whole_words
def _get_sample_prob(self, dataset_lens):
"""
Get smoothed sampling probability by language. This helps low-resource
languages by up-sampling them.
"""
prob = dataset_lens / dataset_lens.sum()
smoothed_prob = prob ** self.args.multilang_sampling_alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def load_dataset(self, split, epoch=0, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
languages = [
name for name in os.listdir(data_path)
if os.path.isdir(os.path.join(data_path, name))
]
print("| Training on {0} languages: {1}".format(len(languages), languages))
print("| Language to id mapping: ", {
lang: id for id, lang in enumerate(languages)
}
)
mask_whole_words = self._get_whole_word_mask()
lang_datasets = []
for lang_id, language in enumerate(languages):
split_path = os.path.join(data_path, language, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.source_dictionary,
self.args.dataset_impl,
combine=combine,
)
if dataset is None:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.args.tokens_per_sample - 1, # one less for <s>
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode=self.args.sample_break_mode,
)
print('| loaded {} blocks from: {}'.format(len(dataset), split_path))
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
dataset,
self.source_dictionary,
pad_idx=self.source_dictionary.pad(),
mask_idx=self.mask_idx,
seed=self.args.seed,
mask_prob=self.args.mask_prob,
leave_unmasked_prob=self.args.leave_unmasked_prob,
random_token_prob=self.args.random_token_prob,
freq_weighted_replacement=self.args.freq_weighted_replacement,
mask_whole_words=mask_whole_words,
)
lang_dataset = NestedDictionaryDataset(
{
'net_input': {
'src_tokens': PadDataset(
src_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
'src_lengths': NumelDataset(src_dataset, reduce=False),
},
'target': PadDataset(
tgt_dataset,
pad_idx=self.source_dictionary.pad(),
left_pad=False,
),
'nsentences': NumSamplesDataset(),
'ntokens': NumelDataset(src_dataset, reduce=True),
'lang_id': RawLabelDataset([lang_id] * src_dataset.sizes.shape[0]),
},
sizes=[src_dataset.sizes],
)
lang_datasets.append(lang_dataset)
if split == self.args.train_subset:
# For train subset, additionally up or down sample languages.
dataset_lengths = np.array(
[len(d) for d in lang_datasets],
dtype=float,
)
sample_probs = self._get_sample_prob(dataset_lengths)
print("| Sample probability by language: ", {
lang: "{0:.4f}".format(sample_probs[id])
for id, lang in enumerate(languages)
}
)
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
print("| Up/Down Sampling ratio by language: ", {
lang: "{0:.2f}".format(size_ratio[id])
for id, lang in enumerate(languages)
}
)
resampled_lang_datasets = [
ResamplingDataset(
lang_datasets[i],
size_ratio=size_ratio[i],
seed=self.args.seed,
epoch=epoch,
replace=size_ratio[i] >= 1.0,
)
for i, d in enumerate(lang_datasets)
]
dataset = ConcatDataset(resampled_lang_datasets)
else:
dataset = ConcatDataset(lang_datasets)
lang_splits = [split]
for lang_id, lang_dataset in enumerate(lang_datasets):
split_name = split + '_' + languages[lang_id]
lang_splits.append(split_name)
self.datasets[split_name] = lang_dataset
# [TODO]: This is hacky for now to print validation ppl for each
# language individually. We may need task API changes to allow it
# in a more generic way.
if split in self.args.valid_subset:
self.args.valid_subset = self.args.valid_subset.replace(
split, ','.join(lang_splits)
)
with data_utils.numpy_seed(self.args.seed + epoch):
shuffle = np.random.permutation(len(dataset))
self.datasets[split] = SortDataset(
dataset,
sort_order=[
shuffle,
dataset.sizes,
],
)
def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
src_dataset = PadDataset(
TokenBlockDataset(
src_tokens,
src_lengths,
self.args.tokens_per_sample - 1, # one less for <s>
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode='eos',
),
pad_idx=self.source_dictionary.pad(),
left_pad=False,
)
src_dataset = PrependTokenDataset(src_dataset, self.source_dictionary.bos())
src_dataset = NestedDictionaryDataset(
{
'id': IdDataset(),
'net_input': {
'src_tokens': src_dataset,
'src_lengths': NumelDataset(src_dataset, reduce=False),
},
},
sizes=src_lengths,
)
if sort:
src_dataset = SortDataset(src_dataset, sort_order=[src_lengths])
return src_dataset
def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
):
# Recreate the epoch iterator every epoch because the underlying
# datasets are dynamic due to sampling.
self.epoch_iter = None
return super().get_batch_iterator(
dataset, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, required_batch_size_multiple,
seed, num_shards, shard_id, num_workers, epoch,
)
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
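
# Illustration only (not part of this commit): how _get_sample_prob() and the
# size_ratio computation in load_dataset() behave for a toy corpus of two
# languages with 1,000,000 and 10,000 blocks. The alpha value of 0.5 is an
# assumption for the example; --multilang-sampling-alpha defaults to 1.0,
# which keeps the raw proportions (both size ratios equal to 1.0).
import numpy as np
dataset_lengths = np.array([1000000., 10000.])
alpha = 0.5
prob = dataset_lengths / dataset_lengths.sum()
smoothed_prob = (prob ** alpha) / (prob ** alpha).sum()
size_ratio = (smoothed_prob * dataset_lengths.sum()) / dataset_lengths
print(np.round(smoothed_prob, 4))  # [0.9091 0.0909]
print(np.round(size_ratio, 2))     # [0.92 9.18]
# A smaller alpha up-samples the low-resource language (~9.2x) and slightly
# down-samples the high-resource one (~0.92x).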
......@@ -225,8 +225,9 @@ class Trainer(object):
return extra_state
def get_train_iterator(self, epoch, combine=True):
def get_train_iterator(self, epoch, combine=True, load_dataset=True):
"""Return an EpochBatchIterator over the training set for a given epoch."""
if load_dataset:
print('| loading train data for epoch {}'.format(epoch))
self.task.load_dataset(self.args.train_subset, epoch=epoch, combine=combine)
return self.task.get_batch_iterator(
......
......@@ -92,9 +92,9 @@ def main(args, init_distributed=False):
if epoch_itr.epoch % args.save_interval == 0:
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
if ':' in getattr(args, 'data', ''):
reload_dataset = ':' in getattr(args, 'data', '')
# sharded data: get train iterator for next epoch
epoch_itr = trainer.get_train_iterator(epoch_itr.epoch)
epoch_itr = trainer.get_train_iterator(epoch_itr.epoch, load_dataset=reload_dataset)
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
......