Commit 8776928c authored by Kartikay Khandelwal, committed by Facebook Github Bot

Open Source MLM Implementation in Fairseq (#635)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/635

Adding a task and relevant models, datasets and criteria needed for training Cross-lingual Language Models similar to Masked Language Model used in XLM (Lample and Conneau, 2019 - https://arxiv.org/abs/1901.07291).

Reviewed By: liezl200

Differential Revision: D14943776

fbshipit-source-id: 3e416a730303d1dd4f5b92550c78db989be27073
parent 303b95ce
# Cross-Lingual Language Model Pre-training
Below are details on training Cross-Lingual Language Models (XLM) in Fairseq, similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf). The current implementation only supports the Masked Language Model (MLM) objective from that paper.
## Downloading and Tokenizing Monolingual Data
Pointers to the monolingual Wikipedia data used for training the XLM-style MLM model, as well as details on processing it (tokenization and BPE), can be found in the [XLM Github Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data).
For the code snippets in the later sections to work, we assume the following (a small sanity check for this layout is sketched right after the list):
- Processed data is in the folder: monolingual_data/processed
- Each language has 3 files, one each for train, validation and test. For example, English has the following files:
train.en, valid.en, test.en
- We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr)
- The vocabulary file is monolingual_data/processed/vocab_mlm
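Before running the preprocessing below, it can help to confirm that the layout matches these assumptions. The following is a minimal sanity-check sketch; the paths, languages and split names are simply the ones assumed above, so adjust them if your data lives elsewhere.
```
import os

# Layout assumed in this README; adjust if your processed data lives elsewhere.
data_dir = "monolingual_data/processed"
langs = ["ar", "de", "en", "hi", "fr"]
splits = ["train", "valid", "test"]

assert os.path.isfile(os.path.join(data_dir, "vocab_mlm")), "missing vocab_mlm"
for lang in langs:
    for split in splits:
        path = os.path.join(data_dir, "{}.{}".format(split, lang))
        assert os.path.isfile(path), "missing {}".format(path)
print("All expected monolingual files are present.")
```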
## Fairseq Pre-processing and Binarization
Pre-process and binarize the data with the MaskedLMDictionary and the cross_lingual_lm task:
```
# Ensure the output directory exists
mkdir -p monolingual_data/fairseq_processed

for lg in ar de en hi fr
do
  fairseq-preprocess \
    --task cross_lingual_lm \
    --srcdict monolingual_data/processed/vocab_mlm \
    --only-source \
    --trainpref monolingual_data/processed/train \
    --validpref monolingual_data/processed/valid \
    --testpref monolingual_data/processed/test \
    --destdir monolingual_data/fairseq_processed \
    --workers 20 \
    --source-lang $lg

  # Since we only have a source language, the output files have "None" as
  # the target language. Rename them to drop it.
  for stage in train valid test
  do
    mv monolingual_data/fairseq_processed/$stage.$lg-None.$lg.bin monolingual_data/fairseq_processed/$stage.$lg.bin
    mv monolingual_data/fairseq_processed/$stage.$lg-None.$lg.idx monolingual_data/fairseq_processed/$stage.$lg.idx
  done
done
```
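After binarization and renaming, each language/split pair should be loadable from monolingual_data/fairseq_processed. The snippet below is a hedged check using the same IndexedDataset class the cross_lingual_lm task uses internally (see the task code further down in this commit). Note as well that the task's setup_task loads its dictionary from a dict.txt file in the training data directory, so you may need to copy vocab_mlm there under that name.
```
from fairseq.data import IndexedDataset

# Hedged example: the prefix must point at the renamed files, i.e.
# train.en.bin / train.en.idx without the "-None" infix.
prefix = "monolingual_data/fairseq_processed/train.en"
assert IndexedDataset.exists(prefix), "binarized files not found for " + prefix
ds = IndexedDataset(prefix, fix_lua_indexing=True)
print("train.en: {} sentences".format(len(ds)))
```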
## Train a Cross-lingual Language Model similar to the XLM MLM model
Use the following command to train the model on 5 languages.
```
fairseq-train \
  --task cross_lingual_lm monolingual_data/fairseq_processed \
  --save-dir checkpoints/mlm \
  --max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
  --arch xlm_base \
  --optimizer adam --lr-scheduler reduce_lr_on_plateau \
  --lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
  --dropout 0.1 \
  --criterion masked_lm_loss \
  --max-tokens 2048 --tokens-per-sample 256 --no-bias-kv --attention-dropout 0.1 \
  --lazy-load --seed 0 \
  --masked-lm-only \
  --monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
  --ddp-backend=no_c10d
```
Some notes:
- Setting --tokens-per-sample greater than 256 can cause OOM (out-of-memory) issues. Since the MLM task packs streams of text into blocks, this parameter usually does not need much tuning.
- The evaluation workflow for computing MLM perplexity on test data is still in progress (a rough interim sketch is given right after this list).
- Fine-tuning this model on a downstream task is not currently supported.
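Until the dedicated evaluation workflow lands, a rough interim option is to read the perplexity off the logged loss: the masked_lm_loss criterion (included further down in this commit) divides the per-token loss by math.log(2) in aggregate_logging_outputs, so the logged lm_loss/nll_loss is in bits per token and perplexity is 2 raised to that value. A minimal sketch, assuming you read the logged value yourself:
```
import math

def mlm_perplexity(nll_loss_bits):
    # masked_lm_loss logs the per-token loss in base 2 (see its
    # aggregate_logging_outputs), so perplexity = 2 ** loss.
    return math.pow(2, nll_loss_bits)

print(mlm_perplexity(2.5))  # a logged loss of 2.5 bits -> perplexity ~5.66
```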
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch.nn.functional as F
from fairseq import utils
from . import FairseqCriterion, register_criterion
def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
"""
Function to compute the cross entropy loss. The default value of
ignore_index is the same as the default value for F.cross_entropy in
pytorch.
"""
assert logits.size(0) == targets.size(-1), \
"Logits and Targets tensor shapes don't match up"
loss = F.cross_entropy(
logits,
targets,
reduction="sum",
ignore_index=ignore_index,
)
return loss
@register_criterion('masked_lm_loss')
class MaskedLmLoss(FairseqCriterion):
"""
Implementation for the loss used in masked language model (MLM) training.
This optionally also computes the next sentence prediction (NSP) loss and
adds it to the overall loss based on the specified args. There are three
cases to consider:
1) Generic MLM training without NSP loss. In this case sentence_targets
and sentence_logits are both None.
2) BERT training without NSP loss. In this case sentence_targets is
not None but sentence_logits is None and we should not be computing
a sentence level loss.
3) BERT training with NSP loss. In this case both sentence_targets and
sentence_logits are not None and we should be computing a sentence
level loss. The weight of the sentence level loss is specified as
an argument.
"""
def __init__(self, args, task):
super().__init__(args, task)
@staticmethod
def add_args(parser):
"""Args for MaskedLM Loss"""
# Default for masked_lm_only is False so as to not break BERT training
parser.add_argument('--masked-lm-only', default=False,
action='store_true', help='compute MLM loss only')
parser.add_argument('--nsp-loss-weight', default=1.0, type=float,
help='weight for next sentence prediction'
' loss (default 1)')
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
lm_logits, output_metadata = model(**sample["net_input"])
# reshape lm_logits from (N,T,C) to (N*T,C)
lm_logits = lm_logits.view(-1, lm_logits.size(-1))
lm_targets = sample['lm_target'].view(-1)
lm_loss = compute_cross_entropy_loss(
lm_logits, lm_targets, self.padding_idx)
# compute the number of tokens for which loss is computed. This is used
# to normalize the loss
ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
loss = lm_loss / ntokens
nsentences = sample['nsentences']
# nsentences = 0
# Compute sentence loss if masked_lm_only is False
sentence_loss = None
if not self.args.masked_lm_only:
sentence_logits = output_metadata['sentence_logits']
sentence_targets = sample['sentence_target'].view(-1)
# This needs to be recomputed due to some differences between
# TokenBlock and BlockPair dataset. This can be resolved with a
# refactor of BERTModel which we will do in the future.
# TODO: Remove this after refactor of BERTModel
nsentences = sentence_targets.size(0)
# Check for logits being none which can happen when remove_heads
# is set to true in the BERT model. Ideally we should set
# masked_lm_only to true in this case, but that requires some
# refactor in the BERT model.
if sentence_logits is not None:
sentence_loss = compute_cross_entropy_loss(
sentence_logits, sentence_targets)
loss += self.args.nsp_loss_weight * (sentence_loss / nsentences)
# NOTE: as we are summing up per token mlm loss and per sentence nsp loss
# we don't need to use sample_size as denominator for the gradient
# here sample_size is just used for logging
sample_size = 1
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'lm_loss': utils.item(lm_loss.data) if reduce else lm_loss.data,
# sentence loss is not always computed
'sentence_loss': (
(
utils.item(sentence_loss.data) if reduce
else sentence_loss.data
) if sentence_loss is not None else 0.0
),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
lm_loss_sum = sum(log.get('lm_loss', 0) for log in logging_outputs)
sentence_loss_sum = sum(
log.get('sentence_loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_loss = sum(log.get('loss', 0) for log in logging_outputs)
agg_output = {
'loss': agg_loss / sample_size / math.log(2),
'lm_loss': lm_loss_sum / ntokens / math.log(2),
'sentence_loss': sentence_loss_sum / nsentences / math.log(2),
'nll_loss': lm_loss_sum / ntokens / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
return agg_output
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import numpy as np
import torch
from typing import Dict, List, Tuple, Union
from . import FairseqDataset, data_utils
from fairseq.data import Dictionary
from fairseq.data.fb_block_pair_dataset import BlockPairDataset
from fairseq.data.token_block_dataset import TokenBlockDataset
class MaskedLMDataset(FairseqDataset):
"""
A wrapper Dataset for masked language modelling. The dataset
    wraps around TokenBlockDataset or BlockPairDataset and creates a batch
where the input blocks are masked according to the specified masking
probability. Additionally the batch can also contain sentence level targets
if this is specified.
Args:
dataset: Dataset which generates blocks of data. Only BlockPairDataset
and TokenBlockDataset are supported.
sizes: Sentence lengths
vocab: Dictionary with the vocabulary and special tokens.
pad_idx: Id of padding token in dictionary
mask_idx: Id of mask token in dictionary
classif_token_idx: Id of classification token in dictionary. This is the
token associated with the sentence embedding (Eg: CLS for BERT)
sep_token_idx: Id of separator token in dictionary
(Eg: SEP in BERT)
seed: Seed for random number generator for reproducibility.
shuffle: Shuffle the elements before batching.
has_pairs: Specifies whether the underlying dataset
generates a pair of blocks along with a sentence_target or not.
Setting it to True assumes that the underlying dataset generates a
label for the pair of sentences which is surfaced as
sentence_target. The default value assumes a single block with no
sentence target.
segment_id: An optional segment id for filling in the segment labels
when we are in the single block setting (Eg: XLM). Default is 0.
masking_ratio: specifies what percentage of the blocks should be masked.
masking_prob: specifies the probability of a given token being
replaced with the "MASK" token.
random_token_prob: specifies the probability of a given token being
replaced by a random token from the vocabulary.
unchanged_prob: specifies the probability of keeping a given
token unchanged.
"""
def __init__(
self,
dataset: FairseqDataset,
sizes: np.ndarray,
vocab: Dictionary,
pad_idx: int,
mask_idx: int,
classif_token_idx: int,
sep_token_idx: int,
seed: int = 1,
shuffle: bool = True,
has_pairs: bool = True,
segment_id: int = 0,
masking_ratio: float = 0.15,
masking_prob: float = 0.8,
random_token_prob: float = 0.1
):
# Make sure the input datasets are the ones supported
assert (
isinstance(dataset, TokenBlockDataset) or
isinstance(dataset, BlockPairDataset)
), "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset"
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = vocab
self.pad_idx = pad_idx
self.mask_idx = mask_idx
self.classif_token_idx = classif_token_idx
self.sep_token_idx = sep_token_idx
self.shuffle = shuffle
self.seed = seed
self.has_pairs = has_pairs
self.segment_id = segment_id
self.masking_ratio = masking_ratio
self.masking_prob = masking_prob
self.random_token_prob = random_token_prob
# If we have only one block then sizes needs to be updated to include
# the classification token
if not has_pairs:
self.sizes = self.sizes + 1
def __getitem__(
self,
index: int
):
# if has_pairs, then expect 2 blocks and a sentence target
if self.has_pairs:
(block_one, block_two, sentence_target) = self.dataset[index]
else:
block_one = self.dataset[index]
return {
"id": index,
"block_one": block_one,
"block_two": block_two if self.has_pairs else None,
"sentence_target": sentence_target if self.has_pairs else None,
}
def __len__(self):
return len(self.dataset)
def _mask_block(
self,
sentence: np.ndarray,
mask_idx: int,
pad_idx: int,
dictionary_token_range: Tuple,
masking_ratio: float = 0.15,
masking_prob: float = 0.8,
random_token_prob: float = 0.1
):
"""
Mask tokens for Masked Language Model training
Samples mask_ratio tokens that will be predicted by LM.
        Note: This function may not be efficient enough since we have multiple
        conversions between np and torch; these can be replaced with torch
        operators later.
Args:
sentence: 1d tensor to be masked
mask_idx: index to use for masking the sentence
pad_idx: index to use for masking the target for tokens we aren't
predicting
dictionary_token_range: range of indices in dictionary which can
be used for random word replacement
(e.g. without special characters)
masking_ratio: specifies what percentage of the blocks should be
masked.
masking_prob: specifies the probability of a given token being
replaced with the "MASK" token.
random_token_prob: specifies the probability of a given token being
replaced by a random token from the vocabulary
Return:
masked_sent: masked sentence
target: target with words which we are not predicting replaced
by pad_idx
"""
masked_sent = np.copy(sentence)
sent_length = len(sentence)
mask_num = math.ceil(sent_length * masking_ratio)
mask = np.random.choice(sent_length, mask_num)
target = np.copy(sentence)
for i in range(sent_length):
if i in mask:
rand = np.random.random()
# replace with mask if probability is less than masking_prob
# (Eg: 0.8)
if rand < masking_prob:
masked_sent[i] = mask_idx
# replace with random token if probability is less than
# masking_prob + random_token_prob (Eg: 0.9)
elif rand < (masking_prob + random_token_prob):
# sample random token from dictionary
masked_sent[i] = (
np.random.randint(
dictionary_token_range[0], dictionary_token_range[1]
)
)
else:
target[i] = pad_idx
return masked_sent, target
def _collate(
self,
samples: List[Dict],
pad_idx: int,
eos_idx: int
):
"""
Does the heavy lifting for creating a batch from the input list of
examples. The logic is as follows:
1. Mask the input blocks. In case has_pair is True then we have 2
blocks to mask.
2. Prepend the first masked block tensor with the special token
used as sentence embedding. Eg: CLS in BERT. This happens
irrespective of the value of has_pair.
3. If has_pair is True, then append the first masked block with the
special separator token (eg: SEP for BERT) and compute segment
label accordingly. In this case, also append the second masked
block with this special separator token and compute its segment
label.
4. For the targets tensor, prepend and append with padding index
accordingly.
5. Concatenate all tensors.
"""
if len(samples) == 0:
return {}
# To ensure determinism, we reset the state of the PRNG after every
# batch based on the seed and the first id of the batch. This ensures
# that across epochs we get the same mask for the same example. This
# is needed for reproducibility and is how BERT does masking
        # TODO: Can we add determinism without this constraint?
with data_utils.numpy_seed(self.seed + samples[0]["id"]):
for s in samples:
# token range is needed for replacing with random token during
# masking
token_range = (self.vocab.nspecial, len(self.vocab))
# mask according to specified probabilities.
masked_blk_one, masked_tgt_one = self._mask_block(
s["block_one"], self.mask_idx, self.pad_idx, token_range)
tokens = np.concatenate([
[self.classif_token_idx], masked_blk_one
])
targets = np.concatenate([[self.pad_idx], masked_tgt_one])
segments = np.ones(len(tokens)) * self.segment_id
# if has_pairs is True then we need to add the SEP token to both
# the blocks after masking and re-compute segments based on the new
# lengths.
if self.has_pairs:
tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
targets_one = np.concatenate([targets, [self.pad_idx]])
masked_blk_two, masked_tgt_two = self._mask_block(
s["block_two"], self.mask_idx, self.pad_idx, token_range)
tokens_two = np.concatenate(
[masked_blk_two, [self.sep_token_idx]])
targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])
# block + 1 sep + 1 special (CLS)
segments_one = np.zeros(len(tokens_one))
# block + 1 sep
segments_two = np.ones(len(tokens_two))
tokens = np.concatenate([tokens_one, tokens_two])
targets = np.concatenate([targets_one, targets_two])
segments = np.concatenate([segments_one, segments_two])
s["source"] = torch.LongTensor(tokens)
s["segment_labels"] = torch.LongTensor(segments)
s["lm_target"] = torch.LongTensor(targets)
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad=False
)
return {
"id": torch.LongTensor([s["id"] for s in samples]),
"ntokens": sum(len(s["source"]) for s in samples),
"net_input": {
"tokens": merge("source"),
"segment_labels": merge("segment_labels"),
},
"lm_target": merge("lm_target"),
"sentence_target": torch.LongTensor(
[s["sentence_target"] for s in samples]
) if self.has_pairs else None,
"nsentences": len(samples),
}
def collater(
self,
samples: List[Dict]
):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch of data
"""
return self._collate(samples, self.vocab.pad(), self.vocab.eos())
def get_dummy_batch(
self,
num_tokens: int,
max_positions: Union[float, int],
tgt_len: int = 12
):
"""
Return a dummy batch with a given number of tokens.
"""
if isinstance(max_positions, float) or isinstance(max_positions, int):
tgt_len = min(tgt_len, max_positions)
source = self.vocab.dummy_sentence(tgt_len)
sentence_target = 0
bsz = num_tokens // tgt_len
return self.collater(
[
{
"id": i,
"block_one": source,
"block_two": source if self.has_pairs else None,
"sentence_target": sentence_target if self.has_pairs else None,
}
for i in range(bsz)
]
)
def num_tokens(
self,
index: int
):
"""
Return the number of tokens in a sample. This value is used to
enforce max-tokens during batching.
"""
return self.sizes[index]
def size(
self,
index: int
):
"""
Return an example's size as a float or tuple. This value is used when
filtering a dataset with max-positions.
"""
return self.sizes[index]
def ordered_indices(self):
"""
Return an ordered list of indices. Batches will be constructed based
on this order.
"""
if self.shuffle:
return np.random.permutation(len(self))
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def prefetch(self, indices):
self.dataset.prefetch(indices)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data import Dictionary
class MaskedLMDictionary(Dictionary):
"""
Dictionary for Masked Language Modelling tasks. This extends Dictionary by
adding the mask symbol.
"""
def __init__(
self,
pad='<pad>',
eos='</s>',
unk='<unk>',
mask='<mask>',
):
super().__init__(pad, eos, unk)
self.mask_word = mask
self.mask_index = self.add_symbol(mask)
self.nspecial = len(self.symbols)
def mask(self):
"""Helper to get index of mask symbol"""
return self.mask_index
class BertDictionary(MaskedLMDictionary):
"""
Dictionary for BERT task. This extends MaskedLMDictionary by adding support
for cls and sep symbols.
"""
def __init__(
self,
pad='<pad>',
eos='</s>',
unk='<unk>',
mask='<mask>',
cls='<cls>',
sep='<sep>'
):
super().__init__(pad, eos, unk, mask)
self.cls_word = cls
self.sep_word = sep
self.cls_index = self.add_symbol(cls)
self.sep_index = self.add_symbol(sep)
self.nspecial = len(self.symbols)
def cls(self):
"""Helper to get index of cls symbol"""
return self.cls_index
def sep(self):
"""Helper to get index of sep symbol"""
return self.sep_index
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
from typing import Dict, List
import numpy as np
from . import FairseqDataset
class MultiCorpusSampledDataset(FairseqDataset):
"""
Stores multiple instances of FairseqDataset together and in every iteration
    creates a batch by first sampling a dataset according to a specified
probability distribution and then getting instances from that dataset.
Args:
datasets: an OrderedDict of FairseqDataset instances.
sampling_dist: the sampling distribution used to select the dataset
from which the batch is created in a given iteration.
default_key: string which specifies the default key to be used for
generating dummy batches etc.
"""
def __init__(
self,
datasets: Dict[str, FairseqDataset],
sampling_dist: str = 'uniform',
default_key: str = ''
):
super().__init__()
assert isinstance(datasets, OrderedDict)
assert default_key in datasets
self.datasets = datasets
self.sampling_dist = sampling_dist
self.default_key = default_key
self.total_num_instances = 0
for _, dataset in datasets.items():
assert isinstance(dataset, FairseqDataset)
self.total_num_instances += dataset.__len__()
self._ordered_indices = None
def __len__(self):
"""
Length of this dataset is the sum of individual datasets
"""
return self.total_num_instances
def ordered_indices(self):
"""
Ordered indices for batching. Here we call the underlying
dataset's ordered_indices() so that we get the same random ordering
as we would have from using the underlying dataset directly.
"""
if self._ordered_indices is None:
self._ordered_indices = OrderedDict(
[
(
key, dataset.ordered_indices()
)
for key, dataset in self.datasets.items()
]
)
return np.arange(len(self))
def _map_index_to_dataset(
self,
key: int,
index: int
):
"""
Different underlying datasets have different lengths. In order to ensure
we are not accessing an index outside the range of the current dataset
size, we wrap around. This function should be called after we have
created an ordering for this and all underlying datasets.
"""
assert self._ordered_indices is not None, \
'Must call MultiCorpusSampledDataset.ordered_indices() first'
mapped_index = index % len(self.datasets[key])
return self._ordered_indices[key][mapped_index]
def __getitem__(
self,
index: int
):
"""
Get the item associated with index from each underlying dataset.
Since index is in the range of [0, TotalNumInstances], we need to
map the index to the dataset before retrieving the item.
"""
return OrderedDict(
[
(
key, dataset[self._map_index_to_dataset(key, index)]
)
for key, dataset in self.datasets.items()
]
)
def collater(
self,
samples: List[Dict]
):
"""
Generate a mini-batch for this dataset.
To convert this into a regular mini-batch we use the following
logic:
1. Select a dataset using the specified probability distribution.
2. Call the collater function of the selected dataset.
"""
if len(samples) == 0:
return None
if self.sampling_dist == 'uniform':
candidates = list(self.datasets.keys())
selected_key = np.random.choice(candidates, 1).item()
selected_samples = [
sample[selected_key]
for sample in samples
]
return self.datasets[selected_key].collater(selected_samples)
else:
raise NotImplementedError(
"Specified sampling is currently not Implemented."
)
def get_dummy_batch(
self,
num_tokens: int,
max_positions: int,
):
"""
Return a dummy batch with a given number of tokens. Assumes that the
max_positions specified is the same for all underlying datasets.
"""
return self.datasets[self.default_key].get_dummy_batch(
num_tokens, max_positions)
def num_tokens(
self,
index: int
):
"""
Return an example's length (number of tokens), used for batching. Here
we return the max across all examples at index across all underlying
datasets.
"""
return max(
dataset.num_tokens(self._map_index_to_dataset(key, index))
for key, dataset in self.datasets.items()
)
def size(
self,
index: int
):
"""
Return an example's size as a float or tuple. Here we return the max
across all underlying datasets. This value is used when filtering a
dataset with max-positions.
"""
return max(
dataset.num_tokens(self._map_index_to_dataset(key, index))
for key, dataset in self.datasets.items()
)
@property
def supports_prefetch(self):
return all(
getattr(dataset, 'supports_prefetch', False)
for dataset in self.datasets.values()
)
def prefetch(self, indices):
for key, dataset in self.datasets.items():
dataset.prefetch(
[
self._map_index_to_dataset(key, index) for index in indices
]
)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import (
BaseFairseqModel, FairseqEncoder, register_model, register_model_architecture,
)
from fairseq.modules import (
SinusoidalPositionalEmbedding,
TransformerSentenceEncoder
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params
@register_model('masked_lm')
class MaskedLMModel(BaseFairseqModel):
"""
Class for training a Masked Language Model. It also supports an
additional sentence level prediction if the sent-loss argument is set.
"""
def __init__(self, args, encoder):
super().__init__()
self.args = args
self.encoder = encoder
# if specified then apply bert initialization on the model. We need
        # to explicitly call this to make sure that the output embeddings
# and projection layers are also correctly initialized
if getattr(args, 'apply_bert_init', False):
self.apply(init_bert_params)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# Arguments related to dropout
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', default=0.1, type=float,
metavar='D', help='dropout probability for'
' attention weights')
parser.add_argument('--act-dropout', default=0.1, type=float,
metavar='D', help='dropout probability after'
' activation in FFN')
# Arguments related to hidden states and self-attention
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--no-bias-kv', action='store_true',
help='if set, pads attn with zero instead of'
' adding a learnable bias kv')
# Arguments related to input and output embeddings
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--share-encoder-input-output-embed',
action='store_true', help='share encoder input'
' and output embeddings')
parser.add_argument('--no-token-positional-embeddings',
action='store_true',
help='if set, disables positional embeddings'
' (outside self attention)')
parser.add_argument('--num-segment', type=int, metavar='N', default=2,
help='num segment in the input')
# Arguments related to sentence level prediction
parser.add_argument('--sentence-class-num', type=int, metavar='N',
default=2, help='number of classes for sentence'
' task')
parser.add_argument('--sent-loss', action='store_true', help='if set,'
' calculate sentence level predictions')
# Arguments related to parameter initialization
parser.add_argument('--apply-bert-init', action='store_true',
help='use custom param initialization for BERT')
# layer norm layers
parser.add_argument('--bert-layer-norm', action='store_true',
help='use custom Layer Norm module for BERT')
# misc params
parser.add_argument('--encoder-normalize-before', action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--gelu', action='store_true',
help='Use gelu activation function in encoder'
' Layer')
def forward(self, tokens, segment_labels):
return self.encoder(tokens, segment_labels)
def max_positions(self):
return self.encoder.max_positions
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
if args.task == 'bert':
base_bert_architecture(args)
else:
xlm_architecture(args)
if not hasattr(args, 'max_positions'):
args.max_positions = args.tokens_per_sample
print("Model args: ", args)
encoder = MaskedLMEncoder(args, task.dictionary)
return cls(args, encoder)
class MaskedLMEncoder(FairseqEncoder):
"""
Encoder for Masked Language Modelling.
"""
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.padding_idx = dictionary.pad()
self.vocab_size = dictionary.__len__()
self.max_positions = args.max_positions
use_position_embeddings = (
not getattr(args, 'no_token_positional_embeddings', False)
)
encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
use_bert_layer_norm = getattr(args, 'bert_layer_norm', False)
        use_gelu = getattr(args, 'gelu', False)
apply_bert_init = getattr(args, 'apply_bert_init', False)
self.sentence_encoder = TransformerSentenceEncoder(
padding_idx=self.padding_idx,
vocab_size=self.vocab_size,
num_encoder_layers=args.encoder_layers,
embedding_dim=args.encoder_embed_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.act_dropout,
max_seq_len=self.max_positions,
num_segments=args.num_segment,
use_position_embeddings=use_position_embeddings,
encoder_normalize_before=encoder_normalize_before,
use_bert_layer_norm=use_bert_layer_norm,
use_gelu=use_gelu,
apply_bert_init=apply_bert_init,
)
self.share_input_output_embed = getattr(
args, 'share_encoder_input_output_embed', False)
self.embed_out = None
self.sentence_projection_layer = None
self.sentence_out_dim = args.sentence_class_num
# Remove head is set to true during fine-tuning
self.load_softmax = not getattr(args, 'remove_head', False)
if self.load_softmax:
if not self.share_input_output_embed:
self.embed_out = nn.Linear(
args.encoder_embed_dim,
self.vocab_size,
bias=False
)
if args.sent_loss:
self.sentence_projection_layer = nn.Linear(
args.encoder_embed_dim,
self.sentence_out_dim,
bias=False
)
def forward(self, tokens, segment_labels, **unused):
"""
Forward pass for Masked LM encoder. This first computes the token
embedding using the token embedding matrix, position embeddings (if
specified) and segment embeddings (if specified).
Here we assume that the sentence representation corresponds to the
output of the classification_token (see bert_task or cross_lingual_lm
task for more details).
Args:
- tokens: B x T matrix representing sentences
- segment_labels: B x T matrix representing segment label for tokens
Returns:
- a tuple of the following:
- logits for predictions in format B x T x C to be used in
softmax afterwards
- a dictionary of additional data, where 'sentence_rep' contains
the representation for classification_token and 'inner_states'
is a list of internal model states used to compute the
predictions (similar in ELMO). 'sentence_logits'
is the prediction logit for NSP task and is only computed if
this is specified in the input arguments.
"""
inner_states, sentence_rep = self.sentence_encoder(tokens, segment_labels)
x = inner_states[-1].transpose(0, 1)
# project back to size of vocabulary
if self.share_input_output_embed \
and hasattr(self.sentence_encoder.embed_tokens, 'weight'):
x = F.linear(x, self.sentence_encoder.embed_tokens.weight)
elif self.embed_out is not None:
x = self.embed_out(x)
sentence_logits = None
if self.sentence_projection_layer:
sentence_logits = self.sentence_projection_layer(sentence_rep)
return x, {
'inner_states': inner_states,
'sentence_rep': sentence_rep,
'sentence_logits': sentence_logits
}
def max_positions(self):
"""Maximum output length supported by the encoder."""
return self.max_positions
def upgrade_state_dict_named(self, state_dict, name):
if isinstance(
self.sentence_encoder.position_embeddings,
SinusoidalPositionalEmbedding
):
state_dict[
name + '.sentence_encoder.position_embeddings._float_tensor'
] = torch.FloatTensor(1)
if not self.load_softmax:
for k in list(state_dict.keys()):
if "embed_out.weight" in k or "sentence_projection_layer.weight" in k:
del state_dict[k]
return state_dict
@register_model_architecture('masked_lm', 'bert_base')
def base_bert_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
args.share_encoder_input_output_embed = getattr(
args, 'share_encoder_input_output_embed', True)
args.no_token_positional_embeddings = getattr(
args, 'no_token_positional_embeddings', False)
args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
args.num_segment = getattr(args, 'num_segment', 2)
args.encoder_layers = getattr(args, 'encoder_layers', 12)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
args.no_bias_kv = getattr(args, 'no_bias_kv', True)
args.sent_loss = getattr(args, 'sent_loss', True)
    args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
args.apply_bert_init = getattr(args, 'apply_bert_init', True)
# TODO: validate setups for layernorm
args.encoder_normalize_before = getattr(
args, 'encoder_normalize_before', True)
args.bert_layer_norm = getattr(args, 'bert_layer_norm', True)
args.gelu = getattr(args, 'gelu', True)
@register_model_architecture('masked_lm', 'xlm_base')
def xlm_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.share_encoder_input_output_embed = getattr(
args, 'share_encoder_input_output_embed', True)
args.no_token_positional_embeddings = getattr(
args, 'no_token_positional_embeddings', False)
args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
args.num_segment = getattr(args, 'num_segment', 1)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.no_bias_kv = getattr(args, 'no_bias_kv', True)
args.sent_loss = getattr(args, 'sent_loss', False)
args.encoder_normalize_before = getattr(
args, 'encoder_normalize_before', False)
args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
args.gelu = getattr(args, 'gelu', True)
args.apply_bert_init = getattr(args, 'apply_bert_init', True)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from collections import OrderedDict
from fairseq import tokenizer
from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data import (
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
TokenBlockDataset,
)
from fairseq.data import Dictionary
from fairseq.data.masked_lm_dataset import MaskedLMDataset
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from . import FairseqTask, register_task
@register_task('cross_lingual_lm')
class CrossLingualLMTask(FairseqTask):
"""
Task for training cross-lingual language models.
For more details look at: https://arxiv.org/pdf/1901.07291.pdf
Args:
dictionary (Dictionary): the dictionary for the input of the task
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='path to data directory')
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments'
' per sample')
        parser.add_argument('--monolingual-langs', default='en', type=str,
                            help='comma-separated list of languages to'
                                 ' train XLM on')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--shuffle', action='store_true',
help='shuffle each monolingual dataset while'
' training')
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
self.distributed_world_size = args.distributed_world_size
self.langs2id = self._lang_to_id(args.monolingual_langs)
self.default_key = None
def _lang_to_id(
self,
languages: str
):
"""
Build a map from languages to ids. These ids are used as segment labels
for cross-lingual LM training.
"""
lang2id = {}
langs = [l.strip() for l in languages.split(',')]
for id, lang in enumerate(langs):
lang2id[lang] = id
return lang2id
@classmethod
def load_dictionary(cls, filename):
return MaskedLMDictionary.load(filename)
@classmethod
def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8):
d = MaskedLMDictionary()
for filename in filenames:
Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d
@property
def target_dictionary(self):
return self.dictionary
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task.
"""
dictionary = MaskedLMDictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
dataset_map = OrderedDict()
for lang in self.langs2id.keys():
if self.default_key is None:
self.default_key = lang
# Datasets are expected to be in "split.lang" format (Eg: train.en)
language_split = '{}.{}'.format(split, lang)
path = os.path.join(self.args.data, language_split)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
elif not self.args.raw_text and IndexedDataset.exists(path):
if self.args.lazy_load:
ds = IndexedDataset(path, fix_lua_indexing=True)
else:
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(
language_split, self.args.data))
# Since we append each block with the classification_token,
# we need to effectively create blocks of length
# tokens_per_sample-1
block_dataset = TokenBlockDataset(
dataset=ds,
sizes=ds.sizes,
block_size=self.args.tokens_per_sample-1,
pad=self.dictionary.pad(),
eos=self.dictionary.eos()
)
dataset_map[lang] = MaskedLMDataset(
dataset=block_dataset,
sizes=block_dataset.sizes,
vocab=self.dictionary,
pad_idx=self.dictionary.pad(),
mask_idx=self.dictionary.mask(),
classif_token_idx=self.dictionary.eos(),
sep_token_idx=self.dictionary.eos(),
shuffle=getattr(self.args, 'shuffle', False),
has_pairs=False,
segment_id=self.langs2id[lang],
seed=self.seed,
)
self.datasets[split] = MultiCorpusSampledDataset(
dataset_map, default_key=self.default_key
)
print('| {} {} {} examples'.format(
self.args.data, split, len(self.datasets[split])
)
)