Commit e1ffea87 authored by Naman Goyal, committed by Facebook Github Bot

added masked_lm task (#697)



Summary:
Co-authored-by: jingfeidu <jingfeidu@fb.com>

1) Added a `masked_lm` task for BERT-like training. The code is mostly taken from jingfeidu's implementation.

2) Added a `has_eos` option to `block_pair_dataset` (exposed as `doc_break_size` in the diff below) for working with datasets that have been preprocessed to include `eos`.

Depends on: https://github.com/pytorch/fairseq/pull/696
Pull Request resolved: https://github.com/pytorch/fairseq/pull/697

Differential Revision: D15214050

fbshipit-source-id: c179ce2d70e59d2ddc941b13ceda99d929878931
parent 817fccf5
@@ -32,6 +32,8 @@ class BlockPairDataset(FairseqDataset):
            doc: respect document boundaries and each part of the pair should belong to one document
            none: don't respect any boundary and cut tokens evenly
        short_seq_prob: probability for generating shorter block pairs
        doc_break_size: size of the empty line separating documents. Typically 1 if
            the sentences have eos, 0 otherwise.
    """
    def __init__(
@@ -42,6 +44,7 @@ class BlockPairDataset(FairseqDataset):
        block_size,
        break_mode="doc",
        short_seq_prob=0.1,
        doc_break_size=1,
    ):
        super().__init__()
        self.dataset = dataset
@@ -60,8 +63,12 @@ class BlockPairDataset(FairseqDataset):
        if break_mode == "doc":
            cur_doc = []
            for sent_id, sz in enumerate(sizes):
                assert doc_break_size == 0 or sz != 0, (
                    "when doc_break_size is non-zero, we expect documents to be "
                    "separated by a blank line with a single eos."
                )
                # empty line as document separator
-               if sz == 0:
+               if sz == doc_break_size:
                    if len(cur_doc) == 0:
                        continue
                    self.block_indices.append(cur_doc)
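To make the `doc_break_size` change above concrete, here is a small illustrative sketch of the document-grouping logic (a simplified stand-in, not the code from this diff; the helper name and the sample sizes are made up):

# sizes[i] = token count of sentence i in the binarized corpus. When every line
# was preprocessed with a trailing eos, the "blank" separator line still holds
# one token (the eos), so doc_break_size=1; without eos it would be 0.
def group_into_documents(sizes, doc_break_size=1):
    docs, cur_doc = [], []
    for sent_id, sz in enumerate(sizes):
        if sz == doc_break_size:   # empty line acting as document separator
            if cur_doc:
                docs.append(cur_doc)
                cur_doc = []
        else:
            cur_doc.append(sent_id)
    if cur_doc:
        docs.append(cur_doc)
    return docs

# e.g. group_into_documents([12, 7, 1, 9, 4], doc_break_size=1) -> [[0, 1], [3, 4]]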
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os

from fairseq import tokenizer
from fairseq.data import (
    ConcatDataset,
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    data_utils,
)
from fairseq.data import Dictionary
from fairseq.data.block_pair_dataset import BlockPairDataset
from fairseq.data.masked_lm_dataset import MaskedLMDataset
from fairseq.data.masked_lm_dictionary import BertDictionary

from . import FairseqTask, register_task

@register_task('masked_lm')
class MaskedLMTask(FairseqTask):
    """
    Task for training Masked LM (BERT) model.

    Args:
        dictionary (Dictionary): the dictionary for the input of the task
    """
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner')
        parser.add_argument('--tokens-per-sample', default=512, type=int,
                            help='max number of total tokens over all segments'
                                 ' per sample for BERT dataset')
        parser.add_argument('--raw-text', default=False, action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--break-mode', default="doc", type=str, help='mode for breaking sentence')
        parser.add_argument('--lazy-load', action='store_true', help='load the dataset lazily')
    def __init__(self, args, dictionary):
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

    @classmethod
    def load_dictionary(cls, filename):
        return BertDictionary.load(filename)

    @classmethod
    def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
        d = BertDictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
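        # threshold/nwords control vocabulary truncation; padding_factor pads the
        # final vocabulary size up to a multiple of 8, which helps fp16 kernels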
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

    @property
    def target_dictionary(self):
        return self.dictionary
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task.
        """
        paths = args.data.split(':')
        assert len(paths) > 0
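        # BertDictionary extends the regular fairseq dictionary loaded from dict.txt
        # with the special symbols (mask, cls, sep) that MaskedLMDataset relies on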
        dictionary = BertDictionary.load(os.path.join(paths[0], 'dict.txt'))
        print('| dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)
    def load_dataset(self, split, epoch=0, combine=False):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        print("| data_path", data_path)
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)

            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
                if self.args.lazy_load:
                    ds = IndexedDataset(path, fix_lua_indexing=True)
                else:
                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
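            # seed numpy deterministically (per shard) while building block pairs
            # so that dataset construction is reproducible across runs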
            with data_utils.numpy_seed(self.seed + k):
                loaded_datasets.append(
                    BlockPairDataset(
                        ds,
                        self.dictionary,
                        ds.sizes,
                        self.args.tokens_per_sample,
                        break_mode=self.args.break_mode,
                        doc_break_size=1,
                    ))

            print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

            if not combine:
                break

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        self.datasets[split] = MaskedLMDataset(
            dataset=dataset,
            sizes=sizes,
            vocab=self.dictionary,
            pad_idx=self.dictionary.pad(),
            mask_idx=self.dictionary.mask(),
            classif_token_idx=self.dictionary.cls(),
            sep_token_idx=self.dictionary.sep(),
            shuffle=True,
            seed=self.seed,
        )
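For reference, a minimal sketch of driving the new task programmatically. The data path and argument values are placeholders, and the Namespace attributes simply mirror the options added by add_args above; this is illustrative, not part of the commit.

from argparse import Namespace

from fairseq import tasks

# placeholder arguments; 'data-bin/corpus' is assumed to be the output of
# binarizing an eos-terminated corpus with blank lines between documents
args = Namespace(
    task='masked_lm',
    data='data-bin/corpus',
    seed=1,
    raw_text=False,
    lazy_load=False,
    tokens_per_sample=512,
    break_mode='doc',
)

task = tasks.setup_task(args)       # loads data-bin/corpus/dict.txt as a BertDictionary
task.load_dataset('train')          # builds BlockPairDataset + MaskedLMDataset
sample = task.dataset('train')[0]   # one masked LM / sentence-pair example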