Commit 9e8a8c05 authored by jerrrrry

Initial commit
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import itertools
import math
import os
import statistics
import time
import numpy as np
import torch
from . import FairseqDataset
import fairseq.data.batch_C_v0p5
import fairseq.data.batch_C_v0p5_better
import fairseq.data.batch_C_v0p6
import sys
def infer_language_pair(path):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src, dst = None, None
print('Infer language pair from filename...')
for filename in os.listdir(path):
print('filename:', filename)
parts = filename.split('.')
if len(parts) >= 3 and len(parts[1].split('-')) == 2:
return parts[1].split('-')
return src, dst
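# Illustrative sketch (not part of the original module): the filename below is
# hypothetical and only shows the expected <split>.<lang1>-<lang2>.(...).idx layout.
def _example_infer_language_pair():
    # 'train.de-en.de.idx' splits into ['train', 'de-en', 'de', 'idx'],
    # so parts[1].split('-') yields ['de', 'en'].
    parts = 'train.de-en.de.idx'.split('.')
    assert len(parts) >= 3 and len(parts[1].split('-')) == 2
    return parts[1].split('-')  # ['de', 'en']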
class ShardedIterator(object):
"""A sharded wrapper around an iterable (padded to length)."""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards - 1')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
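# Minimal sketch (not part of the original module) of how ShardedIterator pads
# shards to a common length: with 10 items and 4 shards, every shard yields
# ceil(10 / 4) = 3 elements, with fill_value used where a shard runs short.
def _example_sharded_iterator():
    items = list(range(10))
    shard = ShardedIterator(items, num_shards=4, shard_id=3, fill_value=None)
    assert len(shard) == 3
    return list(shard)  # [3, 7, None]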
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count."""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
return self.count < len(self)
def skip(self, num_to_skip):
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
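# Minimal sketch (not part of the original module): CountingIterator tracks how
# many elements have been consumed, which EpochBatchIterator uses to resume an
# epoch from a checkpoint via skip().
def _example_counting_iterator():
    itr = CountingIterator(['a', 'b', 'c'])
    itr.skip(1)        # fast-forward past one element
    first = next(itr)  # 'b'
    assert itr.count == 2 and itr.has_next()
    return first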
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False, n_seq_per_batch_multiple=8, seq_len_multiple=1):
""" Convert a list of 1d tensors into a padded 2d tensor.
Args:
values: Python list where each element is a PyT 1d tensor
pad_idx: The index into the translation dictionary for the pad token (typically refer to 'dict.pad()')
eos_idx: The index into the translation dictionary for the eos token (typically refer to 'dict.eos()')
left_pad: Bool, left- or right-padding (true: left, false: right)
move_eos_to_beginning: Reverse order of sequence of tokens (true: reverse, false:leave in original order)
n_seq_per_batch_multiple: The number of sequences per batch to round down to
seq_len_multiple: The number of tokens per sequence to round up to
"""
size_of_seq_dim = max(v.size(0) for v in values) # Unpadded size
n_seq_in_batch = len(values)
if n_seq_per_batch_multiple % seq_len_multiple == 0:
n_seq_multiple = n_seq_per_batch_multiple // seq_len_multiple
else:
n_seq_multiple = n_seq_per_batch_multiple
if n_seq_in_batch < n_seq_multiple or n_seq_in_batch % n_seq_multiple > 0:
seq_len_multiple = n_seq_per_batch_multiple
size_of_seq_dim = (size_of_seq_dim + seq_len_multiple - 1) // seq_len_multiple * seq_len_multiple # Padded seq len, rounded up to next multiple
padded_2d_tensor = values[0].new(len(values), size_of_seq_dim).fill_(pad_idx)
def copy_tensor(src, dst):
assert dst.numel() == src.numel()
if move_eos_to_beginning:
assert src[-1] == eos_idx
dst[0] = eos_idx
dst[1:] = src[:-1]
else:
dst.copy_(src)
if left_pad:
for idx, val in enumerate(values):
copy_tensor(val, padded_2d_tensor[idx][size_of_seq_dim - len(val):])
else:
for idx, val in enumerate(values):
copy_tensor(val, padded_2d_tensor[idx][:len(val)])
return padded_2d_tensor
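# Minimal sketch (not part of the original module): collate_tokens pads a list of
# 1d tensors to a common (rounded-up) length. With pad_idx=1, eos_idx=2 and unit
# multiples, two sequences of length 3 and 5 are right-padded to length 5.
def _example_collate_tokens():
    seqs = [torch.LongTensor([4, 5, 2]), torch.LongTensor([6, 7, 8, 9, 2])]
    right_padded = collate_tokens(seqs, pad_idx=1, eos_idx=2, left_pad=False,
                                  n_seq_per_batch_multiple=1, seq_len_multiple=1)
    # tensor([[4, 5, 2, 1, 1],
    #         [6, 7, 8, 9, 2]])
    return right_padded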
class EpochBatchIterator(object):
"""Iterate over a FairseqDataset and yield batches bucketed by size.
Batches may contain sequences of different lengths. This iterator can be
reused across multiple epochs with the next_epoch_itr() method.
Args:
dataset: a FairseqDataset
dataloader_num_workers: number of DataLoader worker processes
dataloader_pin_memory: whether the DataLoader pins host memory
max_tokens: max number of tokens in each batch
max_sentences: max number of sentences in each batch
max_positions: (max_source_positions, max_target_positions) tuple of the max sentence lengths supported by the model
ignore_invalid_inputs: don't raise an Exception for sentences that are too long
required_batch_size_multiple: require the batch size to be a multiple of N
seeds: seeds for the random number generator, one per training epoch, for reproducibility
num_shards: shard the data iterator into N shards
shard_id: which shard of the data iterator to return
epoch: starting epoch number
bucket_growth_factor: growth factor between bucket boundaries (used by the 'v0p6' and reference batching schemes)
seq_len_multiple: pad each sequence length up to a multiple of this value
batching_scheme: batch construction algorithm ('v0p5', 'v0p5_better', 'v0p6', or the Python reference)
batch_multiple_strategy: how the batch size multiple is enforced by the 'v0p6' scheme
"""
def __init__(
self, dataset, dataloader_num_workers=1, dataloader_pin_memory=False, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1, seeds=[1],
num_shards=1, shard_id=0, epoch=0, bucket_growth_factor=1.1, seq_len_multiple=1,
batching_scheme='v0p5', batch_multiple_strategy='multiple_of_sequences',
):
assert isinstance(dataset, FairseqDataset)
self.dataset = dataset
self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
self.dataloader_num_workers = dataloader_num_workers
self.dataloader_pin_memory = dataloader_pin_memory
assert len(max_positions) == 2, "Max positions contains source and target lengths!"
max_src_pos, max_tgt_pos = max_positions
self.max_positions = max_positions
self.max_positions_num = min(max_src_pos, max_tgt_pos)
self.ignore_invalid_inputs = ignore_invalid_inputs
self.bsz_mult = required_batch_size_multiple
self.seeds = seeds
self.num_shards = num_shards
self.shard_id = shard_id
self.seq_len_multiple = seq_len_multiple
self.batching_scheme = batching_scheme
self.batch_multiple_strategy = batch_multiple_strategy
self.epoch = epoch
self._cur_epoch_itr = None
self._next_epoch_itr = None
with numpy_seed(self.seeds[0]):
import time
start = time.time()
indices = self.dataset.ordered_indices(self.seeds[self.epoch])
#need integer, rather than float('Inf') values
max_sentences = max_sentences if max_sentences is not None else sys.maxsize
max_tokens = max_tokens if max_tokens is not None else sys.maxsize
if self.batching_scheme == 'v0p5':
batches = fairseq.data.batch_C_v0p5.make_batches_v0p5(self.dataset.src_sizes, self.dataset.tgt_sizes, indices, max_tokens, max_sentences, self.bsz_mult, self.max_positions_num)
elif self.batching_scheme == 'v0p5_better':
print('self.dataset.src_sizes', self.dataset.src_sizes.size)
print('self.dataset.tgt_sizes', self.dataset.tgt_sizes.size)
batches = fairseq.data.batch_C_v0p5_better.make_batches_v0p5_better(self.dataset.src_sizes, self.dataset.tgt_sizes, indices, max_tokens, max_sentences, self.max_positions_num, self.bsz_mult, self.seq_len_multiple)
elif self.batching_scheme == 'v0p6':
batch_strategy = 2
# accept both spellings; the constructor default is 'multiple_of_sequences'
if self.batch_multiple_strategy in ('multiple_of_sequences', 'mult_of_sequences'):
batch_strategy = 0
elif self.batch_multiple_strategy == 'pad_sequence_to_mult':
batch_strategy = 1
elif self.batch_multiple_strategy == 'dynamic':
batch_strategy = 2
else:
assert False, "Unknown batch multiple strategy!"
bucket_specify_min_boundary = 8
use_efficient_last_pack = False
#batch_strategy = 2
batches = fairseq.data.batch_C_v0p6.make_batches_v0p6(self.dataset.src_sizes,
self.dataset.tgt_sizes,
indices,
max_tokens,
max_sentences,
self.bsz_mult,
self.max_positions_num,
bucket_specify_min_boundary,
bucket_growth_factor,
batch_strategy,
use_efficient_last_pack)
else:  # reference
def roundup(x, multiple):
return (x + multiple - 1) // multiple * multiple
def rounddown(x, multiple):
return x // multiple * multiple
def create_bucket_bounds_lists(max_allowable_seq_length, bucket_specify_min_boundary, bucket_specify_growth_scale):
bucket_boundaries = []
x = bucket_specify_min_boundary
while x < max_allowable_seq_length:
bucket_boundaries.append(x)
x = max(x + 1, int(x * bucket_specify_growth_scale))
if use_efficient_last_pack:
buckets_min_list = [0] + [i+1 for i in bucket_boundaries]
buckets_max_list = bucket_boundaries + [max_allowable_seq_length]
else:
buckets_min_list = [0] + bucket_boundaries
buckets_max_list = bucket_boundaries + [max_allowable_seq_length + 1]
return buckets_min_list, buckets_max_list
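# Worked example (illustrative, not part of the original module): with
# bucket_specify_min_boundary=8 and bucket_specify_growth_scale=1.1, the boundary
# sequence grows roughly as 8, 9, 10, ..., 19, 20, 22, 24, 26, 28, 30, 33, ...
# (each step is max(x + 1, int(x * 1.1))), so short sequences get fine-grained
# buckets while long sequences share coarser ones. With use_efficient_last_pack
# set to False the buckets are the half-open intervals
# [0, 8), [8, 9), [9, 10), ..., [last_boundary, max_allowable_seq_length + 1).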
def create_seq_to_bucket_id_list_and_n_seq_per_batch(n_tok_per_seq, max_allowable_seq_length, max_sentences, pad_seq_per_batch_to_multiple_of, pad_tok_per_seq_to_multiple_of, bucket_specify_min_boundary, bucket_specify_growth_scale):
bucket_interval_min, bucket_interval_max = create_bucket_bounds_lists(max_allowable_seq_length, bucket_specify_min_boundary, bucket_specify_growth_scale)
if do_seq_len_padding_to_multiple:
n_seq_per_batch = [max_tokens // roundup(x, pad_tok_per_seq_to_multiple_of) for x in bucket_interval_max]
elif do_batch_size_rounding_down_to_multiple:
n_seq_per_batch = [rounddown(max_tokens // x, pad_seq_per_batch_to_multiple_of) for x in bucket_interval_max]
elif do_dynamic_batch_size_choice:
n_seq_per_batch_based_on_seq_len = [max_tokens // roundup(x, pad_tok_per_seq_to_multiple_of) for x in bucket_interval_max]
n_seq_per_batch_based_on_n_seq = [rounddown(max_tokens // x, pad_seq_per_batch_to_multiple_of) for x in bucket_interval_max]
n_seq_per_batch = [max(a,b) for a, b in zip(n_seq_per_batch_based_on_seq_len, n_seq_per_batch_based_on_n_seq)]
else:
n_seq_per_batch = [max_tokens // x for x in bucket_interval_max]
n_seq_per_batch = [min(max_sentences, i) if max_sentences is not None else i for i in n_seq_per_batch]
for a, b, c in zip(bucket_interval_min, bucket_interval_max, n_seq_per_batch):
print('bucket:', a, b, c)
token_length_2_bucket_id = {}
for x in range(max_allowable_seq_length+1):
for bucket_id, payload in enumerate(zip(bucket_interval_min, bucket_interval_max)):
bmin, bmax = payload
if (bmin <= x and x <= bmax and use_efficient_last_pack) or (bmin <= x and x < bmax):
token_length_2_bucket_id[x] = bucket_id
break
return ([token_length_2_bucket_id[x] if x <= max_allowable_seq_length else -1 for x in n_tok_per_seq], n_seq_per_batch, len(bucket_interval_min))
# Make adjustments to tuneable parameters here
pad_seq_per_batch_to_multiple_of = self.bsz_mult
pad_tok_per_seq_to_multiple_of = self.bsz_mult
max_allowable_seq_length = self.max_positions_num
bucket_specify_min_boundary = 8
bucket_specify_growth_scale = bucket_growth_factor ##1.035
do_seq_len_padding_to_multiple = False
do_batch_size_rounding_down_to_multiple = False
do_dynamic_batch_size_choice = True
use_efficient_last_pack = False
batches = []
src_token_counts = []
dst_token_counts = []
seq_counts = []
padded_token_counts = []
batch_max_padded_seq_len = 0
batch_seq_count = 0
batches.append([])
src_batch_token_count = 0
dst_batch_token_count = 0
curr_batch_padded_token_count = 0
batch_n_seq = 0
bucket_id = 0
longest_in_batch = []
print('### max_tokens:', max_tokens)
print('### max_sentences:', max_sentences)
pairwise_max_seq_len = [max(a,b) for a, b in zip(dataset.src_sizes, dataset.tgt_sizes)]
bucket_ids, n_seq_per_batch, n_buckets = create_seq_to_bucket_id_list_and_n_seq_per_batch(pairwise_max_seq_len, max_allowable_seq_length, max_sentences, pad_seq_per_batch_to_multiple_of, pad_tok_per_seq_to_multiple_of, bucket_specify_min_boundary, bucket_specify_growth_scale)
buckets = []
for i in range(n_buckets):
buckets.append([])
n_rejected_sequences = 0
for idx, bidx in enumerate(bucket_ids):
if bidx >= 0:
buckets[bidx].append(idx)
else:
n_rejected_sequences += 1
# Remove empty buckets (causes blow-up in eval code).
buckets = [i for i in buckets if len(i) > 0]
print(n_rejected_sequences, 'sequences were omitted for exceeding', max_allowable_seq_length, 'tokens.')
batch_seq_count = 0
#count = 0
seq_len_tracker = 0
for bucket, nspb in zip(buckets, n_seq_per_batch):
for item in bucket:
if batch_n_seq < nspb:
batches[-1].append(item)
src_batch_token_count += dataset.src_sizes[item]
dst_batch_token_count += dataset.tgt_sizes[item]
seq_len_tracker = max(seq_len_tracker, dst_batch_token_count)
batch_n_seq += 1
else:
batches.append([item])
src_token_counts.append(src_batch_token_count)
dst_token_counts.append(dst_batch_token_count)
src_batch_token_count = dataset.src_sizes[item]
dst_batch_token_count = dataset.tgt_sizes[item]
seq_counts.append(batch_n_seq)
batch_n_seq = 1
batches.append([])
batch_n_seq = 0
seq_counts.append(batch_n_seq)
src_batch_token_count = 0
dst_batch_token_count = 0
src_token_counts.append(src_batch_token_count)
dst_token_counts.append(dst_batch_token_count)
seq_cnt2 = []
for batch in batches:
seq_len_tracker = 0
nseqbucket = 0
for item in batch:
a = dataset.src_sizes[item]
b = dataset.tgt_sizes[item]
seq_len_tracker = max(seq_len_tracker, max(a, b))
nseqbucket += 1
longest_in_batch.append(seq_len_tracker)
seq_cnt2.append(nseqbucket)
# In the unlucky case, remove a newly created but empty last batch
if not batches[-1]:
del batches[-1]
del seq_counts[-1]
del src_token_counts[-1]
del dst_token_counts[-1]
tmp_batches = batches
batches = []
for b in tmp_batches:
if b:
batches.append(b)
#padded_token_counts = src_token_counts
#padded_token_counts = [x*0 for x in src_token_counts] # Setting to zero until this is actually implemented
#print('split dataset length:', len(dataset.src))
#print('mean src tokens per batch =', statistics.mean(src_token_counts), statistics.mean(padded_token_counts))
#print('median src tokens per batch =', statistics.median(src_token_counts), statistics.median(padded_token_counts))
#print('stdev src tokens per batch =', statistics.stdev(src_token_counts), statistics.stdev(padded_token_counts))
#print('min src tokens per batch =', min(src_token_counts), min(padded_token_counts))
#print('max src tokens per batch =', max(src_token_counts), max(padded_token_counts))
#print('mean tgt tokens per batch =', statistics.mean(dst_token_counts), statistics.mean(padded_token_counts))
#print('median tgt tokens per batch =', statistics.median(dst_token_counts), statistics.mean(padded_token_counts))
#print('stdev tgt tokens per batch =', statistics.stdev(dst_token_counts), statistics.stdev(padded_token_counts))
#print('min tgt tokens per batch =', min(dst_token_counts), min(padded_token_counts))
#print('max tgt tokens per batch =', max(dst_token_counts), max(padded_token_counts))
#print('mean seq per batch =', statistics.mean(seq_counts), statistics.mean(padded_token_counts))
#print('median seq per batch =', statistics.median(seq_counts), statistics.median(padded_token_counts))
#print('stdev seq per batch =', statistics.stdev(seq_counts), statistics.stdev(padded_token_counts))
#print('min seq per batch =', min(seq_counts), min(padded_token_counts))
#print('max seq per batch =', max(seq_counts), max(padded_token_counts))
#print('pad inc: mean tgt tokens per batch =', statistics.mean(np.array(seq_cnt2) * np.array(longest_in_batch)), longest_in_batch[:3], seq_cnt2[:3])
#print('pad inc: median tgt tokens per batch =', statistics.median(np.array(seq_cnt2) * np.array(longest_in_batch)), longest_in_batch[:3], seq_cnt2[:3])
self.frozen_batches = tuple(batches)
# self.frozen_batches = tuple(self._batch_generator())
print("generated %d batches in %fs" % (len(batches), time.time() - start))
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""Shuffle batches and return a new iterator over the dataset."""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(self.seeds[epoch]):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.dataset.collater,
num_workers=self.dataloader_num_workers,
pin_memory=self.dataloader_pin_memory,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
def _batch_generator(self):
batch = []
def is_batch_full(num_tokens):
if len(batch) == 0:
return False
if len(batch) == self.max_sentences:
return True
if num_tokens > self.max_tokens:
return True
return False
sample_len = 0
sample_lens = []
ignored = []
for idx in self.dataset.ordered_indices(self.seeds[self.epoch]):
if not self.dataset.valid_size(idx, self.max_positions):
if self.ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
'Size of sample #{} is invalid (max_positions={}); skip such examples with --skip-invalid-size-inputs-valid-test'
).format(idx, self.max_positions))
sample_lens.append(self.dataset.num_tokens(idx))
sample_len = max(sample_len, sample_lens[-1])
num_tokens = (len(batch) + 1) * sample_len
if is_batch_full(num_tokens):
mod_len = max(self.bsz_mult * (len(batch) // self.bsz_mult), len(batch) % self.bsz_mult,)
yield batch[:mod_len]
batch = batch[mod_len:]
sample_lens = sample_lens[mod_len:]
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
batch.append(idx)
if len(batch) > 0:
yield batch
if len(ignored) > 0:
print((
'| WARNING: {} samples have invalid sizes and will be skipped, max_positions={}, first few sample ids={}'
).format(len(ignored), self.max_positions, ignored[:10]))
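# Minimal usage sketch (not part of the original module): the dataset argument is
# a hypothetical FairseqDataset, and the default 'v0p5' scheme relies on this
# repo's compiled batch_C extensions. It only illustrates the epoch/iterator
# protocol (next_epoch_itr / end_of_epoch / state_dict) used by the training loop.
def _example_epoch_batch_iterator(dataset, max_tokens=4096, num_epochs=2):
    epoch_itr = EpochBatchIterator(
        dataset,
        max_tokens=max_tokens,
        max_positions=(256, 256),
        seeds=list(range(num_epochs + 1)),
    )
    for _ in range(num_epochs):
        itr = epoch_itr.next_epoch_itr(shuffle=True)
        for batch in itr:
            pass  # forward/backward would go here
        assert epoch_itr.end_of_epoch()
    return epoch_itr.state_dict()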
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and restores the state afterward"""
if seed is None:
yield
return
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
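# Minimal sketch (not part of the original module): numpy_seed temporarily seeds
# NumPy's global PRNG and restores the previous state on exit, so draws inside
# the block are reproducible without perturbing later random draws.
def _example_numpy_seed():
    with numpy_seed(1234):
        first = np.random.permutation(5)
    with numpy_seed(1234):
        second = np.random.permutation(5)
    assert (first == second).all()
    return first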
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os
import torch
# MLPerf compliant dictionary
class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>_', eos='<EOS>_'):
self.pad_word, self.eos_word = pad, eos
self.symbols = []
self.count = []
self.indices = {}
# dictionary indexing starts at 1 for consistency with Lua
# Commented out and hard-coded since pad and eos are in the dictionary files already
self.add_symbol('<lua_index_compat>')
self.pad_index = 1
self.eos_index = 2
#self.pad_index = self.add_symbol(pad)
#self.eos_index = self.add_symbol(eos)
#self.add_symbol('<bypass_unk>')
self.nspecial = 3
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
else:
assert idx < len(self.symbols)
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def index(self, sym):
"""Returns the index of the specified symbol"""
if sym in self.indices:
return self.indices[sym]
else:
assert sym in self.indices
def string(self, tensor, bpe_symbol=None):
"""Helper for converting a tensor of token indices to a string.
Can optionally remove BPE symbols or escape <unk> words.
"""
if torch.is_tensor(tensor) and tensor.dim() == 2:
return '\n'.join(self.string(t) for t in tensor)
def token_string(i):
return self[i]
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None:
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def update(self, new_dict):
"""Updates counts from new dictionary."""
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if nwords <= 0:
nwords = len(self)
new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial]
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(count)
else:
break
threshold_nwords = len(new_symbols)
if padding_factor > 1:
i = 0
while threshold_nwords % padding_factor != 0:
symbol = 'madeupword{:04d}'.format(i)
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(0)
i += 1
threshold_nwords += 1
assert len(new_symbols) % padding_factor == 0
assert len(new_symbols) == len(new_indices)
self.count = list(new_count)
self.symbols = list(new_symbols)
self.indices = new_indices
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
@classmethod
def load(cls, f, ignore_utf_errors=False):
"""Loads the dictionary from a text file with the format:
```
<symbol0>
<symbol1>
...
```
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
return cls.load(fd)
except FileNotFoundError as fnfe:
raise fnfe
except Exception:
raise Exception("Incorrect encoding detected in {}, please rebuild the dataset".format(f))
d = cls()
for line in f.readlines():
word = line.strip()[1:-1] ## Remove the single quotes
count = 1
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
n_pad_tokens_on_end = 33712 - len(d.symbols)
#assert n_pad_tokens_on_end == 3 ## DEBUG: remove later, sanity check
for i in range(n_pad_tokens_on_end):
pad_str = '<pad000' + str(i) + '>'
d.indices[pad_str] = len(d.symbols)
d.symbols.append(pad_str)
d.count.append(1)
return d
def save(self, f):
"""Stores dictionary into a text file"""
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
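# Minimal sketch (not part of the original module): the MLPerf-compliant
# Dictionary is typically built from a vocab file via Dictionary.load(), after
# which pad()/eos() return the hard-coded indices 1 and 2. The path below is
# hypothetical.
def _example_dictionary(vocab_path='/path/to/vocab.bpe.32000'):
    d = Dictionary.load(vocab_path)
    assert d.pad() == 1 and d.eos() == 2
    assert len(d) % 8 == 0  # padded with '<pad000i>' entries up to a fixed size
    return d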
class Dictionary_fairseq(object):
"""A mapping from symbols to consecutive integers"""
def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
# dictionary indexing starts at 1 for consistency with Lua
self.add_symbol('<Lua heritage>')
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def __len__(self):
"""Returns the number of symbols in the dictionary"""
return len(self.symbols)
def index(self, sym):
"""Returns the index of the specified symbol"""
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def string(self, tensor, bpe_symbol=None, escape_unk=False):
"""Helper for converting a tensor of token indices to a string.
Can optionally remove BPE symbols or escape <unk> words.
"""
if torch.is_tensor(tensor) and tensor.dim() == 2:
return '\n'.join(self.string(t) for t in tensor)
def token_string(i):
if i == self.unk():
return self.unk_string(escape_unk)
else:
return self[i]
sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
if bpe_symbol is not None:
sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
return sent
def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>"""
if escape:
return '<{}>'.format(self.unk_word)
else:
return self.unk_word
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
def update(self, new_dict):
"""Updates counts from new dictionary."""
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
"""Sort symbols by frequency in descending order, ignoring special ones.
Args:
- threshold defines the minimum word count
- nwords defines the total number of words in the final dictionary,
including special symbols
- padding_factor can be used to pad the dictionary size to be a
multiple of 8, which is important on some hardware (e.g., Nvidia
Tensor Cores).
"""
if nwords <= 0:
nwords = len(self)
new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
new_symbols = self.symbols[:self.nspecial]
new_count = self.count[:self.nspecial]
c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
for symbol, count in c.most_common(nwords - self.nspecial):
if count >= threshold:
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(count)
else:
break
threshold_nwords = len(new_symbols)
if padding_factor > 1:
i = 0
while threshold_nwords % padding_factor != 0:
symbol = 'madeupword{:04d}'.format(i)
new_indices[symbol] = len(new_symbols)
new_symbols.append(symbol)
new_count.append(0)
i += 1
threshold_nwords += 1
assert len(new_symbols) % padding_factor == 0
assert len(new_symbols) == len(new_indices)
self.count = list(new_count)
self.symbols = list(new_symbols)
self.indices = new_indices
def pad(self):
"""Helper to get index of pad symbol"""
return self.pad_index
def eos(self):
"""Helper to get index of end-of-sentence symbol"""
return self.eos_index
def unk(self):
"""Helper to get index of unk symbol"""
return self.unk_index
@classmethod
def load(cls, f, ignore_utf_errors=False):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
return cls.load(fd)
except FileNotFoundError as fnfe:
raise fnfe
except Exception:
raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f))
d = cls()
for line in f.readlines():
idx = line.rfind(' ')
word = line[:idx]
count = int(line[idx+1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d
def save(self, f):
"""Stores dictionary into a text file"""
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd)
for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
print('{} {}'.format(symbol, count), file=f)
def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.utils.data
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
raise NotImplementedError
def get_dummy_batch(self, num_tokens, max_positions):
"""Return a dummy batch with a given number of tokens."""
raise NotImplementedError
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
raise NotImplementedError
def ordered_indices(self, seed=None):
"""Ordered indices for batching."""
raise NotImplementedError
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import struct
import numpy as np
import torch
from fairseq.tokenizer import Tokenizer
def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
f.readinto(a)
return a
def write_longs(f, a):
f.write(np.array(a, dtype=np.int64))
dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float,
7: np.double,
}
def code(dtype):
for k in dtypes.keys():
if dtypes[k] == dtype:
return k
def index_file_path(prefix_path):
return prefix_path + '.idx'
def data_file_path(prefix_path):
return prefix_path + '.bin'
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path):
super().__init__()
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == b'TNTIDX\x00\x00'
version = f.read(8)
assert struct.unpack('<Q', version) == (1,)
code, self.element_size = struct.unpack('<QQ', f.read(16))
self.dtype = dtypes[code]
self.size, self.s = struct.unpack('<QQ', f.read(16))
self.dim_offsets = read_longs(f, self.size + 1)
self.data_offsets = read_longs(f, self.size + 1)
self.sizes = read_longs(f, self.s)
self.read_data(path)
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb', buffering=0)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
def __del__(self):
self.data_file.close()
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
#a += 1 ## DEBUG: lua_index_compat
item = torch.from_numpy(a).long()
return item
def __len__(self):
return self.size
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
)
class MockedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def __init__(self, path, n_seq_pairs_in_mock_data, uniform_n_seq_per_batch, uniform_seq_len_per_batch):
self.dtype = np.int64
self.uniform_n_seq_per_batch = uniform_n_seq_per_batch
self.uniform_seq_len_per_batch = uniform_seq_len_per_batch
self.size = n_seq_pairs_in_mock_data
self.sizes = []
for i in range(n_seq_pairs_in_mock_data):
self.sizes.append(uniform_seq_len_per_batch)
def __del__(self):
pass
def __getitem__(self, i):
self.check_index(i)
arbitrary_token_id = 55 # Just not a reserved token
a = np.ones((self.uniform_seq_len_per_batch,), dtype=self.dtype) * arbitrary_token_id
a[-1] = 2 # Manually add an <EOS>
#a[self.uniform_seq_len_per_batch-1] = 2 # Manually add an <EOS>
return torch.from_numpy(a).long()
class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
#print('buffer max:', np.max(self.buffer), np.min(self.buffer))
#self.buffer[self.buffer > 0] += 1 ## DEBUG
#self.buffer += 1 ## DEBUG
#print('buffer max after:', np.max(self.buffer), np.min(self.buffer))
self.data_file.close()
def __del__(self):
pass
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a).long()
class IndexedRawTextDataset(IndexedDataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = []
self.sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(path, dictionary)
self.size = len(self.tokens_list)
def read_data(self, path, dictionary):
with open(path, 'r') as f:
for line in f:
self.lines.append(line.strip('\n'))
tokens = Tokenizer.tokenize(
line, dictionary, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
def get_original_text(self, i):
self.check_index(i)
return self.lines[i]
def __del__(self):
pass
def __len__(self):
return self.size
@staticmethod
def exists(path):
return os.path.exists(path)
class IndexedRawTokenIDDataset(IndexedDataset):
"""Takes a text file containing token IDs (integers written in UTF-8 format) as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = []
self.sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(path, dictionary)
self.size = len(self.tokens_list)
def read_data(self, path, dictionary):
with open(path, 'r') as f:
for line in f:
if line != '\n':
self.lines.append(line.strip('\n'))
nwords = len(line.split(' '))
tokens = torch.IntTensor(nwords).long()
for idx, tok in enumerate(line.split(' ')):
tokens[idx] = int(tok)
#tokens = line.split(' ')
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
def get_original_text(self, i):
self.check_index(i)
return self.lines[i]
def __del__(self):
pass
def __len__(self):
return self.size
@staticmethod
def exists(path):
return os.path.exists(path)
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
np.float: 4,
np.double: 8
}
def __init__(self, out_file, dtype=np.int32):
self.out_file = open(out_file, 'wb')
self.dtype = dtype
self.data_offsets = [0]
self.dim_offsets = [0]
self.sizes = []
self.element_size = self.element_sizes[self.dtype]
def add_item(self, tensor):
bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
for s in tensor.size():
self.sizes.append(s)
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def finalize(self, index_file):
self.out_file.close()
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
index.close()
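# Minimal round-trip sketch (not part of the original module): write a couple of
# tensors with IndexedDatasetBuilder, then read them back with
# IndexedInMemoryDataset. The file prefix is hypothetical.
def _example_indexed_dataset_roundtrip(prefix='/tmp/example_dataset'):
    builder = IndexedDatasetBuilder(data_file_path(prefix), dtype=np.int32)
    builder.add_item(torch.IntTensor([4, 5, 6, 2]))
    builder.add_item(torch.IntTensor([7, 8, 2]))
    builder.finalize(index_file_path(prefix))
    ds = IndexedInMemoryDataset(prefix)
    assert len(ds) == 2 and ds[1].tolist() == [7, 8, 2]
    return ds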
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False, bsz_mult=8, seq_len_multiple=1):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx,
left_pad,
move_eos_to_beginning,
bsz_mult,
seq_len_multiple
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# source lengths (note: samples are not re-sorted by length here)
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
prev_output_tokens = None
target = None
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
ntokens = sum(len(s['target']) for s in samples)
else:
ntokens = sum(len(s['source']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
class LanguagePairDataset(FairseqDataset):
"""A pair of torch.utils.data.Datasets."""
def __init__(
self,
src,
src_sizes,
src_dict,
tgt=None,
tgt_sizes=None,
tgt_dict=None,
left_pad_source=True,
left_pad_target=False,
max_source_positions=256,
max_target_positions=256,
seq_len_multiple=1,
shuffle=True
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
self.src = src
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.seq_len_multiple = seq_len_multiple
self.shuffle = shuffle
print("| Sentences are being padded to multiples of: {}".format(self.seq_len_multiple))
def __getitem__(self, index):
return {
'id': index,
'source': self.src[index],
'target': self.tgt[index] if self.tgt is not None else None,
}
def __len__(self):
return len(self.src)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
return collate(
samples,
pad_idx=self.src_dict.pad(),
eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source,
left_pad_target=self.left_pad_target,
bsz_mult=8,
seq_len_multiple=self.seq_len_multiple,
)
def get_dummy_batch(self, max_tokens_per_batch, max_positions, src_len=256, tgt_len=256):
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
n_seq_per_batch_based_on_longest_seq = max_tokens_per_batch // max(src_len, tgt_len)
return self.collater([
{
'id': i,
'source': self.src_dict.dummy_sentence(src_len),
'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
}
for i in range(n_seq_per_batch_based_on_longest_seq)
])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching.
Args:
index: points to the sequence pair
"""
n_tok_per_seq = max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
assert self.seq_len_multiple > 0, "Padding multiple has to be greater than 0"
n_tok_per_seq = (n_tok_per_seq + self.seq_len_multiple - 1) // self.seq_len_multiple * self.seq_len_multiple # Padded seq len, rounded up to next multiple
return n_tok_per_seq
def ordered_indices(self, seed=None):
"""Ordered indices for batching."""
if self.shuffle:
indices = np.random.RandomState(seed).permutation(len(self))
else:
indices = np.arange(len(self))
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
return (
self.src_sizes[index] <= max_source_positions
and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
)
def _get_max_positions(self, max_positions):
if max_positions is None:
return self.max_source_positions, self.max_target_positions
assert len(max_positions) == 2
max_src_pos, max_tgt_pos = max_positions
return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
def collater_isolated(samples, seq_len_multiple, left_pad_source, left_pad_target):
"""Merge a list of samples to form a mini-batch."""
return collate(
samples,
pad_idx=1,
eos_idx=2,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
bsz_mult=8,
seq_len_multiple=seq_len_multiple,
)
def get_dummy_batch_isolated(max_tokens_per_batch, max_positions, seq_len_multiple):
'''Creates a dummy batch'''
max_source_positions, max_target_positions = max_positions[0], max_positions[1]
src_len, tgt_len = max_source_positions, max_target_positions
n_seq_per_batch_based_on_longest_seq = max_tokens_per_batch // max(src_len, tgt_len)
nspecial = 3
ntok_alloc = 33712
eos_id = 2
dummy_seq_src = torch.Tensor(src_len).uniform_(nspecial + 1, ntok_alloc).long()
dummy_seq_src[-1] = eos_id
dummy_seq_tgt = torch.Tensor(tgt_len).uniform_(nspecial + 1, ntok_alloc).long()
dummy_seq_tgt[-1] = eos_id
return collater_isolated([
{
'id': i,
'source': dummy_seq_src,
'target': dummy_seq_tgt
}
for i in range(n_seq_per_batch_based_on_longest_seq)
],
seq_len_multiple,
left_pad_source=True,
left_pad_target=False,
)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
)
return {
'id': torch.LongTensor([s['id'] for s in samples]),
'ntokens': sum(len(s['target']) for s in samples),
'net_input': {
'src_tokens': merge('source'),
},
'target': merge('target'),
}
class MonolingualDataset(FairseqDataset):
"""A wrapper around torch.utils.data.Dataset for monolingual data."""
def __init__(self, dataset, sizes, vocab, shuffle):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = vocab
self.shuffle = shuffle
def __getitem__(self, index):
source, target = self.dataset[index]
return {'id': index, 'source': source, 'target': target}
def __len__(self):
return len(self.dataset)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
return collate(samples, self.vocab.pad(), self.vocab.eos())
def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
assert isinstance(max_positions, float) or isinstance(max_positions, int)
tgt_len = min(tgt_len, max_positions)
bsz = num_tokens // tgt_len
target = self.vocab.dummy_sentence(tgt_len + 1)
source, target = target[:-1], target[1:]
return self.collater([
{'id': i, 'source': source, 'target': target}
for i in range(bsz)
])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
source, target = self.dataset[index]
return len(source)
def ordered_indices(self, seed=None):
"""Ordered indices for batching."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(np.flip(self.sizes, 0))
return np.lexsort(order)
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
assert isinstance(max_positions, float) or isinstance(max_positions, int)
return self.sizes[index] <= max_positions
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import numpy as np
import torch
class TokenBlockDataset(torch.utils.data.Dataset):
"""Break a 1d tensor of tokens into blocks.
The blocks are fetched from the original tensor so no additional memory is allocated.
Args:
tokens: 1d tensor of tokens to break into blocks
sizes: sentence lengths (required for 'complete' and 'eos')
block_size: maximum block size (ignored in 'eos' break mode)
break_mode: Mode used for breaking tokens. Values can be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
each block contains only complete sentences, although block_size may
be exceeded if an individual sentence is longer than block_size
- 'eos': each block contains one sentence (block_size is ignored)
include_targets: return next tokens as targets
"""
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
super().__init__()
self.tokens = tokens
self.total_size = len(tokens)
self.include_targets = include_targets
self.slice_indices = []
if break_mode is None or break_mode == 'none':
length = math.ceil(len(tokens) / block_size)
def block_at(i):
start = i * block_size
end = min(start + block_size, len(tokens))
return (start, end)
self.slice_indices = [block_at(i) for i in range(length)]
elif break_mode == 'complete':
assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
tok_idx = 0
sz_idx = 0
curr_size = 0
while sz_idx < len(sizes):
if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
curr_size += sizes[sz_idx]
sz_idx += 1
else:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
assert sizes is not None and sum(sizes) == len(tokens), '{} != {}'.format(sum(sizes), len(tokens))
curr = 0
for sz in sizes:
# skip samples with just 1 token (which would be just the eos token)
if sz > 1:
self.slice_indices.append((curr, curr + sz))
curr += sz
else:
raise ValueError('Invalid break_mode: ' + break_mode)
self.sizes = np.array([e - s for s, e in self.slice_indices])
def __getitem__(self, index):
s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
# target is the sentence, for source, rotate item one token to the left (would start with eos)
if s == 0:
source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
else:
source = self.tokens[s - 1:e - 1]
return torch.LongTensor(source), item
return item
def __len__(self):
return len(self.slice_indices)
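# Minimal sketch (not part of the original module) of the three break modes on a
# toy token stream of two sentences with sizes [3, 2], each ending in eos=2.
def _example_token_block_dataset():
    tokens = [10, 11, 2, 12, 2]
    sizes = [3, 2]
    none_blocks = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='none')
    complete_blocks = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='complete')
    eos_blocks = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='eos')
    assert none_blocks.slice_indices == [(0, 4), (4, 5)]      # fixed-size chunks
    assert complete_blocks.slice_indices == [(0, 3), (3, 5)]  # whole sentences per block
    assert eos_blocks.slice_indices == [(0, 3), (3, 5)]       # one sentence per block
    return [b.tolist() for b in eos_blocks]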
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import pickle
import torch.distributed
from fairseq import utils
def is_master(args):
return args.distributed_rank == 0
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
if args.distributed_init_method.startswith('tcp://'):
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size, rank=args.distributed_rank)
elif args.distributed_init_method.startswith('env://'):
import os
print("| distributed env init. MASTER_ADDR: " + os.environ['MASTER_ADDR'] + ", MASTER_PORT: " + os.environ['MASTER_PORT'] +
", WORLD_SIZE: " + os.environ['WORLD_SIZE'] + ", RANK: " + os.environ['RANK'], flush=True)
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method)
print("| distributed init done!", flush=True)
args.distributed_world_size = int(os.environ['WORLD_SIZE'])
else:
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size)
args.distributed_rank = torch.distributed.get_rank()
if not is_master(args):
suppress_output()
return args.distributed_rank
def suppress_output():
"""Suppress printing on the current device. Force printing with `force=True`."""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
if 'force' in kwargs:
force = kwargs.pop('force')
if force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def all_gather_list(data, max_size=16384):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != len(all_gather_list._in_buffer):
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
]
in_buffer = all_gather_list._in_buffer
out_buffers = all_gather_list._out_buffers
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
torch.distributed.all_gather(out_buffers, in_buffer.cuda())
result = []
for i in range(world_size):
out_buffer = out_buffers[i]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
return result
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
"""
import torch
import ctypes
from fairseq import optim, utils
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer
def fused_norm(input):
return input.norm(dtype=torch.float32, p=2).item()
class DynamicLossScaler:
def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
self.loss_scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self._iter = 0
self._last_overflow_iter = -1
def update_scale(self, overflow):
if overflow:
self.loss_scale /= self.scale_factor
self._last_overflow_iter = self._iter
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
self.loss_scale *= self.scale_factor
self._iter += 1
@staticmethod
def has_overflow(grad_norm):
# detect inf and nan
if grad_norm == float('inf') or grad_norm != grad_norm:
return True
return False
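# Minimal sketch (not part of the original module): the loss scale is halved
# whenever a gradient overflow (inf/NaN norm) is seen and doubled again every
# scale_window overflow-free iterations after that.
def _example_dynamic_loss_scaler():
    scaler = DynamicLossScaler(init_scale=2.**7, scale_factor=2., scale_window=4)
    for grad_norm in [1.0, float('inf'), 1.0, 1.0, 1.0, 1.0]:
        scaler.update_scale(DynamicLossScaler.has_overflow(grad_norm))
    # one overflow at step 1 halves the scale (128 -> 64); four clean steps
    # after it (steps 2-5) double it back to 128
    assert scaler.loss_scale == 2.**7
    return scaler.loss_scale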
class FP16Trainer(Trainer):
"""Modified trainer for FP16.
We maintain two copies of the model's parameters, both in FP16 and FP32.
We do forward/backward with FP16 and compute the loss + optimize with FP32.
"""
def __init__(self, args, task, model, criterion, allreduce_communicators=None):
super().__init__(args, task, model, criterion, allreduce_communicators)
# convert model to FP16 (but keep criterion FP32)
self.model.half()
# broadcast initial weights from rank=0
# this broadcast isn't required in DistributedFP16Trainer because the
# broadcast is done by DistributedFusedAdam
if torch.distributed.is_available() and torch.distributed.is_initialized():
for p in self.model.parameters():
torch.distributed.broadcast(p, 0)
# dynamically scale loss to reduce overflow
self.scaler = DynamicLossScaler(init_scale=2.**7)
self.meters['loss_scale'] = AverageMeter()
self.grad_denom = 1.0
if self.args.enable_parallel_backward_allred_opt:
import numpy as np
self._flat_grads_parallel = torch.tensor([], dtype=torch.float16).cuda()
self._grads_info = []
grads_size = 0
p_offset = 0
for p_i, p in enumerate([p for p in self.model.parameters() if p.requires_grad]):
p_grads_size = np.prod(list(p.size()))
grads_size += p_grads_size
# register hooks
def wrapper(param, param_i, param_grads_size, param_offset):
def allreduce_hook(grad):
self._do_allreduce(param_i, param_grads_size, param_offset, grad)
if param.requires_grad:
param.register_hook(allreduce_hook)
# print(p_i, p.size(), p_grads_size, p_offset)
self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
wrapper(p, p_i, p_grads_size, p_offset)
p_offset += p_grads_size
self._flat_grads_parallel.resize_(grads_size)
# print(grads_size, len(self._flat_grads_parallel), self._flat_grads_parallel.dtype, self._flat_grads_parallel.get_device())
self._allreduce_flush_min_threshold = self.args.parallel_backward_allred_opt_threshold
print("| parallel all-reduce ENABLED. all-reduce threshold: " + str(self._allreduce_flush_min_threshold))
self._grads_generated = [False]*len(self._grads_info)
self._allreduce_processed_idx = len(self._grads_info)-1
self._num_allreduce_sent = 0
print("| # of parallel all-reduce cuda streams: " + str(self.args.parallel_backward_allred_cuda_nstreams))
if allreduce_communicators:
self._allreduce_groups = allreduce_communicators[0]
self._allreduce_streams = allreduce_communicators[1]
else:
raise RuntimeError('Moved communicator init before RUN_START (invalid code path)')
self._allreduce_groups = [torch.distributed.new_group() for _ in range(self.args.parallel_backward_allred_cuda_nstreams)]
self._allreduce_streams = [torch.cuda.Stream() for _ in range(self.args.parallel_backward_allred_cuda_nstreams)]
if self.args.enable_parallel_backward_allred_opt_correctness_check:
self._num_grads_generated = 0
self._all_grads_generated = False
self._allreduce_schedule = []
def _get_flush_bucket(self):
# print([1 if x else 0 for x in self._grads_generated])
flush_bucket = []
size = 0
allreduce_processed_idx_list = []
allreduce_processed_end_idx = self._allreduce_processed_idx
remaining_grads_for_allreduce = self._grads_generated[allreduce_processed_end_idx-len(self._grads_generated)::-1]
# print([1 if x else 0 for x in remaining_grads_for_allreduce])
for s in remaining_grads_for_allreduce:
# print(s,allreduce_processed_end_idx,size)
if s:
allreduce_processed_idx_list.append(allreduce_processed_end_idx)
size += self._grads_info[allreduce_processed_end_idx]["param_grads_size"]
allreduce_processed_end_idx -= 1
else:
break
# print(size, allreduce_processed_idx_list)
ignore_threshold = all(self._grads_generated)
if size >= self._allreduce_flush_min_threshold or ignore_threshold:
# for i in allreduce_processed_idx_list:
# print(i, self._grads_info[i]["param_grads_size"], self._grads_info[i]["param_offset"],size)
if allreduce_processed_idx_list:
start = self._grads_info[(allreduce_processed_idx_list[-1])]["param_offset"]
end = start + size
# print("->", start, end)
flush_bucket = [start, end]
self._allreduce_processed_idx = allreduce_processed_end_idx
if self._allreduce_processed_idx < 0:
# reset
self._grads_generated = [False]*len(self._grads_info)
self._allreduce_processed_idx = len(self._grads_info)-1
return flush_bucket
def _do_allreduce(self, param_i, param_grads_size, param_offset, grad):
if not self._last_step:
# # ----------------------
# # debugging: do all-reduce in the same stream
# print(self._last_step, self._grads_total, len(self._backward_grads_schedule), param_i, param_offset, param_grads_size, grad.size(), grad.numel(), grad.dtype)
# self._flat_grads_parallel[param_offset:param_offset+param_grads_size].copy_(grad.view(-1))
# self._flat_grads_parallel[param_offset:param_offset+param_grads_size].div_(self.args.distributed_world_size)
# torch.distributed.all_reduce(self._flat_grads_parallel[param_offset:param_offset+param_grads_size])
# # ----------------------
# # ----------------------
# # option #1: send per-layer gradients
# torch.div(grad.view(-1), self.args.distributed_world_size, out=self._flat_grads_parallel[param_offset:param_offset+param_grads_size])
# orig_stream = torch.cuda.current_stream()
# self._reduction_stream.wait_stream(orig_stream)
# with torch.cuda.stream(self._reduction_stream):
# torch.distributed.all_reduce(self._flat_grads_parallel[param_offset:param_offset+param_grads_size])
# # ----------------------
# ----------------------
# option #2: bucket all-reduce based on threshold
self._flat_grads_parallel.record_stream(torch.cuda.current_stream())
torch.div(grad.view(-1), self.args.distributed_world_size, out=self._flat_grads_parallel[param_offset:param_offset+param_grads_size])
self._grads_generated[param_i] = True
flush_bucket = self._get_flush_bucket()
if flush_bucket:
start = flush_bucket[0]
end = flush_bucket[1]
# print("->", start, end)
if self.args.enable_parallel_backward_allred_opt_correctness_check and not self._all_grads_generated:
self._allreduce_schedule.append(flush_bucket)
# orig_stream = torch.cuda.current_stream()
# self._reduction_stream.wait_stream(orig_stream)
# with torch.cuda.stream(self._reduction_stream):
# torch.distributed.all_reduce(self._flat_grads_parallel[start:end])
orig_stream = torch.cuda.current_stream()
allreduce_group = self._allreduce_groups[self._num_allreduce_sent%len(self._allreduce_groups)]
allreduce_stream = self._allreduce_streams[self._num_allreduce_sent%len(self._allreduce_streams)]
allreduce_stream.wait_stream(orig_stream)
with torch.cuda.stream(allreduce_stream):
self._flat_grads_parallel.record_stream(torch.cuda.current_stream())
torch.distributed.all_reduce(self._flat_grads_parallel[start:end], group=allreduce_group)
self._num_allreduce_sent += 1
if self.args.enable_parallel_backward_allred_opt_correctness_check:
self._num_grads_generated += 1
if self._num_grads_generated == len(self._grads_info):
self._all_grads_generated = True
# ----------------------
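# Note on the streams above (comment added for clarity): buckets are assigned
# to the all-reduce streams/process groups round-robin via _num_allreduce_sent,
# so several buckets can be in flight at once. record_stream() tells the
# caching allocator that _flat_grads_parallel is still in use on the side
# stream, and wait_stream() makes the side stream wait for the gradient copy
# issued on the compute stream before reducing that slice.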
def _build_optimizer(self):
# create FP32 copy of parameters and grads
params = [p for p in self.model.parameters() if p.requires_grad]
total_param_size = sum(p.data.numel() for p in params)
self.fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
self.fp32_params = torch.nn.Parameter(self.fp32_params)
#self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
# create optimizer using the copied FP32 params
self._optimizer = optim.build_optimizer(self.args, [self.fp32_params])
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
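# Master-weights pattern (clarifying comment): the optimizer state and the
# parameter updates live in the flat FP32 copy built above, while the model
# itself stays in FP16 for forward/backward; _opt() below steps the FP32 copy
# and writes the updated values back into the FP16 model parameters.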
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
extra_state['loss_scale'] = self.scaler.loss_scale
super().save_checkpoint(filename, extra_state)
def load_checkpoint(self, filename):
"""Load all training state from a checkpoint file."""
extra_state = super().load_checkpoint(filename)
if extra_state is not None and 'loss_scale' in extra_state:
self.scaler.loss_scale = extra_state['loss_scale']
return extra_state
def zero_grad(self):
# clear gradients by dropping the .grad tensors rather than zeroing them;
# the next backward pass re-creates them from scratch
# self.model.zero_grad() # FP16
# self.optimizer.zero_grad() # FP32
for p in self.model.parameters():
p.grad = None
def _backward(self, loss):
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.scaler.loss_scale)
if loss is not None:
# dynamically rescale loss to stay in FP16 range
loss = loss * self.scaler.loss_scale
return super()._backward(loss)
def _all_reduce_and_rescale(self, grad_denom, has_grad=True):
# undo effect of dynamic loss scaling on gradients
self.grad_denom = grad_denom * self.scaler.loss_scale
if self.args.distributed_world_size > 1:
self.grad_denom /= self.args.distributed_world_size
if not self.args.enable_parallel_backward_allred_opt or self._last_step:
# flatten grads into a single buffer
self._flat_grads = self._get_flat_grads(out=None, has_grad=has_grad)
# scale gradients to avoid overflow in all-reduce
self._flat_grads.div_(self.args.distributed_world_size)
# all-reduce flat grads
torch.distributed.all_reduce(self._flat_grads)
else:
# torch.cuda.current_stream().wait_stream(self._reduction_stream)
for allreduce_stream in self._allreduce_streams:
torch.cuda.current_stream().wait_stream(allreduce_stream)
self._flat_grads_parallel.record_stream(torch.cuda.current_stream())
self._flat_grads = self._flat_grads_parallel
if self.args.enable_parallel_backward_allred_opt_correctness_check:
# # ----------------------
# # option #1: send per-layer gradients
# grads = self._get_grads()
# offset = 0
# for g in grads:
# numel = g.numel()
# out = grads[0].new(numel).zero_()
# out.copy_(g.view(-1))
# out.div_(self.args.distributed_world_size)
# torch.distributed.all_reduce(out)
# is_parallel_grads_finite = torch.all(torch.isfinite(self._flat_grads_parallel[offset:offset+numel]))
# is_out_finite = torch.all(torch.isfinite(out))
# assert(is_out_finite == is_parallel_grads_finite)
# if not is_out_finite:
# print("| OVERLAP-CHECK: check inf/nan detected. this batch should be skipped")
# else:
# if not torch.all(torch.eq(out, self._flat_grads_parallel[offset:offset+numel])):
# print(out[0:10], self._flat_grads_parallel[offset:offset+10])
# # for i,_ in enumerate(out):
# # if out[i] != self._flat_grads_parallel[i]:
# # print(i,out[i],self._flat_grads_parallel[i])
# raise RuntimeError('w-gradients received in parallel vs. end differ')
# offset += numel
# # ----------------------
# ----------------------
# option #2: bucket all-reduce based on threshold
# print(self._allreduce_schedule)
out = self._get_flat_grads()
out.div_(self.args.distributed_world_size)
grads_size = 0
for s in self._allreduce_schedule:
start = s[0]
end = s[1]
assert end > start
grads_size += (end - start)
torch.distributed.all_reduce(out[start:end])
is_parallel_grads_finite = torch.all(torch.isfinite(self._flat_grads_parallel[start:end]))
is_out_finite = torch.all(torch.isfinite(out[start:end]))
assert is_out_finite == is_parallel_grads_finite
if not is_out_finite:
print("| OVERLAP-CHECK: check inf/nan detected. this batch should be skipped")
else:
if not torch.all(torch.eq(out[start:end], self._flat_grads_parallel[start:end])):
print(start, end, out[start:end], self._flat_grads_parallel[start:end])
raise RuntimeError('w-gradients received in parallel vs. end differ')
assert grads_size == len(self._flat_grads_parallel)
# ----------------------
else:
# flatten grads into a single buffer
self._flat_grads = self._get_flat_grads(out=None, has_grad=has_grad)
grad_norm = fused_norm(self._flat_grads)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
self.scaler.update_scale(overflow)
if overflow:
if self.scaler.loss_scale <= self.args.min_loss_scale:
raise Exception((
'Minimum loss scale reached ({}). Your loss is probably exploding. '
'Try lowering the learning rate, using gradient clipping or '
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def _opt(self):
# take an optimization step using the FP32 params and grads
#super()._opt()
new_params = self._flat_grads.new_empty(self._flat_grads.size())
self.optimizer.optimizer.step(closure=None, grads=[self._flat_grads], output_params=[new_params], scale=self.grad_denom)
self.zero_grad()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
# copy FP32 params back into FP16 model
offset = 0
with torch.no_grad():
for p in self.model.parameters():
if not p.requires_grad:
continue
numel = p.data.numel()
p.set_(new_params[offset:offset+numel].view_as(p.data))
offset += numel
class DistributedFP16Trainer(Trainer):
"""Modified trainer for FP16.
We maintain two copies of the model's parameters, both in FP16 and FP32.
We do forward/backward with FP16 and compute the loss + optimize with FP32.
"""
def __init__(self, args, task, model, criterion, allreduce_communicators=None):
super().__init__(args, task, model, criterion, allreduce_communicators)
# convert model to FP16 (but keep criterion FP32)
self.model.half()
# dynamically scale loss to reduce overflow
self.scaler = DynamicLossScaler(init_scale=2.**7)
self.meters['loss_scale'] = AverageMeter()
# FIXME: Add more meters
self.grad_denom = 1.0
assert (not self.args.enable_parallel_backward_allred_opt), "--distributed-weight-update cannot be combined with --enable-parallel-backward-allred-opt"
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
# TODO: gather optimizer buffer chunks before saving state
extra_state['loss_scale'] = self.scaler.loss_scale
super().save_checkpoint(filename, extra_state)
def load_checkpoint(self, filename):
"""Load all training state from a checkpoint file."""
# TODO: scatter optimizer buffer chunks after restoring state
extra_state = super().load_checkpoint(filename)
if extra_state is not None and 'loss_scale' in extra_state:
self.scaler.loss_scale = extra_state['loss_scale']
return extra_state
#def zero_grad(self):
# for p in self.model.parameters():
# p.grad = None
def _backward(self, loss):
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.scaler.loss_scale)
if loss is not None:
# dynamically rescale loss to stay in FP16 range
loss = loss * self.scaler.loss_scale
rval = super()._backward(loss)
self.optimizer.optimizer.complete_reductions()
return rval
def __process_overflow(self, overflow):
self.scaler.update_scale(overflow)
if overflow:
if self.scaler.loss_scale <= self.args.min_loss_scale:
raise Exception((
'Minimum loss scale reached ({}). Your loss is probably exploding. '
'Try lowering the learning rate, using gradient clipping or '
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
def _all_reduce_and_rescale(self, grad_denom, has_grad=True):
grad_norm = self.optimizer.optimizer.L2_grad_norm
if grad_norm is not None:
overflow = self.scaler.has_overflow(grad_norm)
self.__process_overflow(overflow)
return grad_norm
else:
return None
def _opt(self):
self.optimizer.optimizer.step(skip_overflow_check=self.args.dwu_compute_L2_grad_norm)
self.zero_grad()
# when the fused optimizer already computed the L2 grad norm, overflow was
# handled in _all_reduce_and_rescale; otherwise fall back to its own flag
overflow = (not self.args.dwu_compute_L2_grad_norm) and bool(self.optimizer.optimizer.has_overflow)
self.__process_overflow(overflow)
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
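# A minimal, self-contained sketch (not part of the trainers above) of the
# dynamic loss scaling pattern both trainers implement: scale the loss before
# backward so FP16 gradients stay representable, unscale before the update,
# and skip the step (shrinking the scale) when the gradients overflowed.
# All names in this sketch (model, criterion, batch, optimizer) are
# illustrative placeholders, not part of the original code.
def _example_dynamic_loss_scaling_step(model, criterion, batch, optimizer, loss_scale=2.**7):
    import math
    loss = criterion(model(batch['input']), batch['target'])
    (loss * loss_scale).backward()                # keep FP16 grads in range
    sq_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            p.grad.div_(loss_scale)               # undo the scaling
            sq_norm += float(p.grad.float().norm()) ** 2
    grad_norm = sq_norm ** 0.5
    if math.isinf(grad_norm) or math.isnan(grad_norm):
        optimizer.zero_grad()                     # overflow: skip this update
        return loss_scale / 2.0, None             # and halve the scale
    optimizer.step()
    optimizer.zero_grad()
    return loss_scale, grad_norm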
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import time
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
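# Minimal usage sketch for AverageMeter (illustrative, not part of the
# original file): the running average is weighted by n, which trainers
# typically set to the number of tokens or sentences in the batch.
def _example_average_meter():
    meter = AverageMeter()
    for loss, ntokens in [(2.5, 3000), (2.1, 2800), (1.9, 3100)]:
        meter.update(loss, n=ntokens)
    return meter.avg  # token-weighted mean of the three losses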
class TimeMeter(object):
"""Computes the average occurrence of some event per second"""
def __init__(self, init=0):
self.reset(init)
def reset(self, init=0):
self.init = init
self.start = time.time()
self.n = 0
def update(self, val=1):
self.n += val
@property
def avg(self):
return self.n / self.elapsed_time
@property
def elapsed_time(self):
return self.init + (time.time() - self.start)
class StopwatchMeter(object):
"""Computes the sum/avg duration of some event in seconds"""
def __init__(self):
self.reset()
def start(self):
self.start_time = time.time()
def stop(self, n=1):
if self.start_time is not None:
delta = time.time() - self.start_time
self.sum += delta
self.n += n
self.start_time = None
def reset(self):
self.sum = 0
self.n = 0
self.start_time = None
@property
def avg(self):
return self.sum / self.n
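# Minimal usage sketch for TimeMeter and StopwatchMeter (illustrative, not
# part of the original file): TimeMeter reports a rate (events per second
# since reset), while StopwatchMeter accumulates the duration of explicitly
# start()/stop()-ed intervals.
def _example_rate_and_stopwatch(n_batches=3, tokens_per_batch=3000):
    wps = TimeMeter()          # e.g. words per second
    train_wall = StopwatchMeter()
    for _ in range(n_batches):
        train_wall.start()
        time.sleep(0.01)       # stand-in for a training step
        train_wall.stop()
        wps.update(tokens_per_batch)
    return wps.avg, train_wall.avg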
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_decoder import FairseqDecoder # noqa: F401
from .fairseq_encoder import FairseqEncoder # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401
from .composite_encoder import CompositeEncoder # noqa: F401
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
def build_model(args, task):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
def register_model(name):
"""Decorator to register a new model (e.g., LSTM)."""
def register_model_cls(cls):
if name in MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model ({})'.format(name))
if not issubclass(cls, BaseFairseqModel):
raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))
MODEL_REGISTRY[name] = cls
return cls
return register_model_cls
def register_model_architecture(model_name, arch_name):
"""Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
def register_model_arch_fn(fn):
if model_name not in MODEL_REGISTRY:
raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
if arch_name in ARCH_MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
if not callable(fn):
raise ValueError('Model architecture must be callable ({})'.format(arch_name))
ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
ARCH_CONFIG_REGISTRY[arch_name] = fn
return fn
return register_model_arch_fn
# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.models.' + module)
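# Minimal sketch of how the registries are meant to be used (illustrative:
# the model name 'dummy', the architecture name 'dummy_base' and DummyModel
# are hypothetical and exist only in this example). Wrapping the registration
# in a function keeps it from running at import time.
def _example_model_registration():
    @register_model('dummy')
    class DummyModel(BaseFairseqModel):
        @classmethod
        def build_model(cls, args, task):
            return cls()

    @register_model_architecture('dummy', 'dummy_base')
    def dummy_base(args):
        # architecture functions usually just fill in default hyperparameters
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)

    # build_model(args, task) would now dispatch on args.arch == 'dummy_base'
    return MODEL_REGISTRY['dummy'], ARCH_MODEL_REGISTRY['dummy_base']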