Commit 9e8a8c05 authored by jerrrrry

Initial commit
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
class FairseqTask(object):
"""
A Task defines the data format, stores shared state (e.g., dictionaries) and
provides helpers for building the model/criterion and calculating the loss.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
pass
def __init__(self, args):
self.args = args
self.datasets = {}
@classmethod
def setup_task(cls, args, **kwargs):
raise NotImplementedError
def load_dataset(self, split, combine=False):
raise NotImplementedError
def dataset(self, split):
"""Return a dataset split."""
from fairseq.data import FairseqDataset
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
if not isinstance(self.datasets[split], FairseqDataset):
raise TypeError('Datasets are expected to be of type FairseqDataset')
return self.datasets[split]
def build_model(self, args):
from fairseq import models
return models.build_model(args, self)
def build_criterion(self, args):
from fairseq import criterions
return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample):
return criterion(model, sample)
@property
def source_dictionary(self):
raise NotImplementedError
@property
def target_dictionary(self):
raise NotImplementedError
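# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal FairseqTask
# subclass showing the contract described in the docstring above. The
# dictionary/dataset wiring here is a placeholder assumption, not a real
# fairseq task; a real task would also be registered with
# @register_task('<name>') so the CLI can find it.
class _ExampleTask(FairseqTask):
    """Toy task that exposes one pre-built dataset and dictionary."""

    def __init__(self, args, dictionary, dataset):
        super().__init__(args)
        self.dictionary = dictionary
        self._dataset = dataset

    @classmethod
    def setup_task(cls, args, dictionary=None, dataset=None, **kwargs):
        # A real task would load dictionaries from args.data here.
        return cls(args, dictionary, dataset)

    def load_dataset(self, split, combine=False):
        # Must be a FairseqDataset for dataset() above to accept it.
        self.datasets[split] = self._dataset

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary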
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from torch.utils.data import ConcatDataset
from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset,
)
from . import FairseqTask, register_task
@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--sample-break-mode', metavar='VAL',
choices=['none', 'complete', 'eos'],
help='If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
'of sentence, but may include multiple sentences per sample. '
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int, metavar='N',
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
@classmethod
def setup_task(cls, args, **kwargs):
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, combine=False):
"""Load a dataset split."""
loaded_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
path = os.path.join(self.args.data, split_k)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = [t for l in ds.tokens_list for t in l]
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path)
tokens = ds.buffer
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
loaded_datasets.append(
TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True
))
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
if not combine:
break
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
sizes = dataset.sizes
else:
dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
@property
def target_dictionary(self):
return self.dictionary
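# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): with combine=True,
# load_dataset() above keeps reading consecutively numbered shards of a split
# (train, train1, train2, ...) and concatenates them. A minimal stand-alone
# version of that naming loop, with `exists` as a caller-supplied check:
def _example_shard_paths(data_dir, split, exists):
    """Yield shard paths following the split/split1/split2... convention."""
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        path = os.path.join(data_dir, split_k)
        if not exists(path):
            break
        yield path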
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from torch.utils.data import ConcatDataset
from fairseq import options
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset,
IndexedRawTextDataset,
)
from fairseq.data.indexed_dataset import IndexedRawTokenIDDataset
from fairseq.data.indexed_dataset import IndexedInMemoryDataset
from fairseq.data.indexed_dataset import MockedInMemoryDataset
from . import FairseqTask, register_task
@register_task('translation')
class TranslationTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
parser.add_argument('--raw-text', action='store_true', help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL', help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL', help='pad the target on the left (default: False)')
parser.add_argument('--max-source-positions', default=256, type=int, metavar='N', help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=256, type=int, metavar='N', help='max number of tokens in the target sequence')
parser.add_argument('--seq-len-multiple', default=1, type=int, metavar='N', help='Pad sequences to a multiple of N')
parser.add_argument('--uniform-n-seq-per-batch', default=None, type=int, metavar='N', help='Make uniform batches with this many sequences')
parser.add_argument('--uniform-seq-len-per-batch', default=None, type=int, metavar='N', help='Make uniform batches with this seq len')
parser.add_argument('--uniform-n-seq-in-dataset', default=None, type=int, metavar='N', help='If creating uniform batches with mock data, this is the dataset size')
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, args, **kwargs):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
if args.source_lang is None or args.target_lang is None:
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split, combine=False):
"""Load a dataset split."""
def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
print('filename:', filename)
print('raw_text:', self.args.raw_text)
if self.args.raw_text and IndexedRawTokenIDDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text and not self.args.uniform_n_seq_per_batch and not self.args.uniform_seq_len_per_batch:
return IndexedRawTokenIDDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path) and not self.args.uniform_n_seq_per_batch and not self.args.uniform_seq_len_per_batch:
return IndexedInMemoryDataset(path)
elif self.args.uniform_n_seq_per_batch and self.args.uniform_seq_len_per_batch:
if self.args.uniform_n_seq_in_dataset:
return MockedInMemoryDataset(path, self.args.uniform_n_seq_in_dataset, self.args.uniform_n_seq_per_batch, self.args.uniform_seq_len_per_batch)
return None
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(split_k, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))
if not combine:
break
assert len(src_datasets) == len(tgt_datasets)
if len(src_datasets) == 1:
src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
src_sizes = src_dataset.sizes
tgt_sizes = tgt_dataset.sizes
else:
src_dataset = ConcatDataset(src_datasets)
tgt_dataset = ConcatDataset(tgt_datasets)
src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
print('srcline:', src_dataset[0])
self.datasets[split] = LanguagePairDataset(
src_dataset, src_sizes, self.src_dict,
tgt_dataset, tgt_sizes, self.tgt_dict,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
seq_len_multiple=self.args.seq_len_multiple,
)
@property
def source_dictionary(self):
return self.src_dict
@property
def target_dictionary(self):
return self.tgt_dict
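# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the file layout that
# setup_task()/load_dataset() above expect inside args.data for a split and
# language pair, i.e. 'dict.{lang}.txt' dictionaries and bitext files named
# with the '{split}.{src}-{tgt}.{lang}' prefix (binarized data additionally
# carries its usual index/data suffixes). Purely a layout example.
def _example_expected_files(data_dir, split, src, tgt):
    prefix = os.path.join(data_dir, '{}.{}-{}.'.format(split, src, tgt))
    return {
        'src_dict': os.path.join(data_dir, 'dict.{}.txt'.format(src)),
        'tgt_dict': os.path.join(data_dir, 'dict.{}.txt'.format(tgt)),
        'source': prefix + src,  # e.g. train.en-de.en
        'target': prefix + tgt,  # e.g. train.en-de.de
    }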
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import Counter
import re
import torch
SPACE_NORMALIZER = re.compile(r"\s+")
def tokenize_line(line):
line = SPACE_NORMALIZER.sub(" ", line)
line = line.strip()
return line.split()
# This is only used to binarize UTF-8 files that already contain tokenized integer IDs (reading raw text data is very slow).
class MockTokenizer:
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize):
with open(filename, mode='r', newline='\n') as f:
for line in f:
for tok in line.split(' '):
dict.add_symbol(tok)
#dict.add_symbol(dict.eos_word) # Already added by MLPerf preprocessing
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line,
append_eos=True, reverse_order=False):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
pass
with open(filename, mode='r', newline='\n') as f:
for line in f:
nwords = len(line.split(' '))
ids = torch.IntTensor(nwords).long()
for idx, tok in enumerate(line.split(' ')):
ids[idx] = int(tok)
nseq += 1
consumer(ids)
ntok += len(ids)
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False, lowercase=False):
return 0 # Do not call this....
class Tokenizer:
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize):
with open(filename, 'r') as f:
for line in f:
for word in tokenize(line):
dict.add_symbol(word)
dict.add_symbol(dict.eos_word)
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line,
append_eos=True, reverse_order=False):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
pass
with open(filename, 'r') as f:
for line in f:
ids = Tokenizer.tokenize(
line=line,
dict=dict,
tokenize=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
consumer(ids)
ntok += len(ids)
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False, lowercase=False):
words = tokenize(line)
if lowercase:
lc_words = []
for word in words:
lc_words.append(word.lower())
words = lc_words
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = dict.add_symbol(word)
else:
idx = dict.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = dict.eos_index
return ids
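# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file): how Tokenizer.tokenize
# above maps a line of text to a tensor of symbol ids with a fairseq
# Dictionary. Guarded so it only runs when the module is executed directly;
# the Dictionary import is assumed to be available as in the rest of fairseq.
if __name__ == '__main__':
    from fairseq.data import Dictionary
    d = Dictionary()
    ids = Tokenizer.tokenize('hello world hello', d, add_if_not_exist=True)
    # One index per word plus the trailing eos index appended by tokenize().
    print(ids.tolist(), 'eos index:', d.eos_index)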
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a network across multiple GPUs.
"""
from collections import defaultdict, OrderedDict
import contextlib
from itertools import chain
import torch
import apex_C
from fairseq import distributed_utils, optim, utils
from fairseq.meters import AverageMeter, TimeMeter
from fairseq.optim import lr_scheduler
class Trainer(object):
"""Main class for data parallel training.
This class supports data parallel training, where multiple workers each
have a full model replica and gradients are accumulated synchronously via
torch.distributed.all_reduce.
"""
def __init__(self, args, task, model, criterion, allreduce_communicators=None):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
self.args = args
# copy model and criterion to current device
self.task = task
self.model = model.cuda()
self.criterion = criterion.cuda()
# initialize meters
self.meters = OrderedDict()
self.meters['train_loss'] = AverageMeter()
self.meters['train_nll_loss'] = AverageMeter()
self.meters['valid_loss'] = AverageMeter()
self.meters['valid_nll_loss'] = AverageMeter()
self.meters['wps'] = TimeMeter() # words per second
self.meters['ups'] = TimeMeter() # updates per second
self.meters['wpb'] = AverageMeter() # words per batch
self.meters['bsz'] = AverageMeter() # sentences per batch
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
self.meters['wall'] = TimeMeter() # wall time in seconds
self._buffered_stats = defaultdict(lambda: [])
self._flat_grads = None
self._num_updates = 0
self._optim_history = None
self._optimizer = None
self._stats_allreduce_stream = torch.cuda.Stream()
self._last_step = False
if self.args.enable_parallel_backward_allred_opt and not self.args.distributed_world_size > 1:
raise RuntimeError('--enable-parallel-backward-allred-opt is only meant for distributed training')
if self.args.enable_parallel_backward_allred_opt and not self.args.fp16:
raise RuntimeError('--enable-parallel-backward-allred-opt only works with FP16 training')
# rework all_gather_list implementation to mitigate memcpy overheads
# [sample_sizes,nsentences,loss,nll_loss,ooms_fwd,ooms_bwd]
self._all_reduce_list = [0.0] * 6
@property
def optimizer(self):
if self._optimizer is None:
self._build_optimizer()
return self._optimizer
def _build_optimizer(self):
self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
# FIXME: Gather optimizer state
if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters
utils.save_state(
filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename, load_optim=True):
"""Load all training state from a checkpoint file."""
# FIXME: Scatter optimizer state
extra_state, optim_history, last_optim_state = \
utils.load_model_state(filename, self.model)
if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed
self._build_optimizer()
if load_optim:
self._optim_history = optim_history
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
if last_optim['criterion_name'] == self.criterion.__class__.__name__:
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
self.optimizer.load_state_dict(last_optim_state)
self._num_updates = last_optim['num_updates']
if extra_state is not None and 'train_meters' in extra_state:
self.meters = extra_state['train_meters']
del extra_state['train_meters']
return extra_state
def train_step(self, sample, update_params=True, last_step=False):
"""Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
# INFO: Given we don't checkpoint, turning off setting the seed.
#seed = self.args.seed + self.get_num_updates()
#torch.manual_seed(seed)
#torch.cuda.manual_seed(seed)
self._last_step = last_step
if self.args.distributed_weight_update >= 2:
self.optimizer.optimizer.set_last_step(self._last_step)
# forward and backward pass
sample = self._prepare_sample(sample)
if sample is not None:
my_ntokens = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
else:
my_ntokens = 0
with torch.cuda.stream(self._stats_allreduce_stream):
global_ntokens = torch.full((1,), my_ntokens, dtype=torch.float32, device='cuda')
if self.args.distributed_world_size > 1:
torch.distributed.all_reduce(global_ntokens)
loss, sample_size, logging_output, oom_fwd = self._forward(sample)
torch.cuda.current_stream().wait_stream(self._stats_allreduce_stream)
if self.args.distributed_weight_update >= 2:
self.optimizer.optimizer.set_global_scale(global_ntokens[0]*self.scaler.loss_scale/torch.distributed.get_world_size())
oom_bwd = self._backward(loss)
# buffer stats and logging outputs
self._buffered_stats['sample_sizes'].append(sample_size)
self._buffered_stats['logging_outputs'].append(logging_output)
self._buffered_stats['ooms_fwd'].append(oom_fwd)
self._buffered_stats['ooms_bwd'].append(oom_bwd)
# rework all_gather_list
assert(sample_size == logging_output.get('sample_size', 0.0))
assert(sample_size == logging_output.get('ntokens', 0.0))
self._all_reduce_list[0] += sample_size
self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
self._all_reduce_list[2] += logging_output.get('loss', 0.0)
self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
self._all_reduce_list[4] += oom_fwd
self._all_reduce_list[5] += oom_bwd
# update parameters
if update_params:
check_against_old_code = False
# check_against_old_code = True
if check_against_old_code:
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
# print(sample_sizes,logging_outputs,ooms_fwd,ooms_bwd)
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
# print("\n",sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return None
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
assert( grad_denom == sum(sample_sizes) )
assert( grad_denom == ntokens )
assert( grad_denom == agg_logging_output['sample_size'] )
all_gather_list_tensor = torch.cuda.DoubleTensor([grad_denom, nsentences, agg_logging_output['loss'], agg_logging_output['nll_loss'], ooms_fwd, ooms_bwd])
print("\n",all_gather_list_tensor)
# rework all_gather_list
all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
if self.args.distributed_world_size > 1 and self.args.enable_global_stats:
torch.distributed.all_reduce(all_reduce_list_tensor)
# Skip `div` if distributed and not reducing stats
if self.args.distributed_world_size == 1 or self.args.enable_global_stats:
all_reduce_list_tensor[2:4].div_((all_reduce_list_tensor[0:1]*torch.log(torch.cuda.DoubleTensor([2]))))
if check_against_old_code:
print(all_reduce_list_tensor)
assert(grad_denom == all_reduce_list_tensor[0].item())
assert(nsentences == all_reduce_list_tensor[1].item())
# compare loss values
# print(all_gather_list_tensor[2:4] - all_reduce_list_tensor[2:4])
assert(torch.all(torch.lt(torch.abs(torch.add(all_gather_list_tensor[2:4], -all_reduce_list_tensor[2:4])), 1e-12)))
assert(ooms_fwd == all_reduce_list_tensor[4].item())
assert(ooms_bwd == all_reduce_list_tensor[5].item())
agg_logging_output = {}
[grad_denom, nsentences, agg_logging_output['loss'], agg_logging_output['nll_loss'], ooms_fwd, ooms_bwd] = all_reduce_list_tensor.tolist()
# `grad_denom` should be based on pre-allreduce, in case we skipped stats allreduce
grad_denom = global_ntokens.item()
agg_logging_output['sample_size'] = grad_denom
ntokens = grad_denom
# print([grad_denom, ntokens, nsentences, agg_logging_output['loss'], agg_logging_output['nll_loss'], ooms_fwd, ooms_bwd])
if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return None
try:
# all-reduce and rescale gradients, then take an optimization step
grad_norm = self._all_reduce_and_rescale(grad_denom, sample is not None)
self._opt()
# update meters
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
if grad_norm is not None:
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)
# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
except OverflowError as e:
self.zero_grad()
print('| WARNING: overflow detected, ' + str(e))
self.clear_buffered_stats()
return agg_logging_output
else:
return None # buffering updates
def _forward(self, sample, eval=False):
loss = None
sample_size = 0
logging_output = {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
oom = 0
try:
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
if sample is not None:
with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
loss = None
else:
raise e
return loss, sample_size, logging_output, oom
def _backward(self, loss):
oom = 0
if loss is not None:
try:
# backward pass
loss.backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
self.zero_grad()
else:
raise e
return oom
def _all_reduce_and_rescale(self, grad_denom, non_empty = True):
# flatten grads into a single buffer and all-reduce
flat_grads = self._flat_grads = self._get_flat_grads(out=self._flat_grads, has_grad = non_empty)
if self.args.distributed_world_size > 1:
torch.distributed.all_reduce(flat_grads)
# rescale and clip gradients
flat_grads.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
# copy grads back into model parameters
self._set_flat_grads(flat_grads)
return grad_norm
def _get_grads(self, has_grad = True):
grads = []
for name, p in self.model.named_parameters():
if not p.requires_grad:
continue
if p.grad is None:
if has_grad:
raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
'Use the param in the forward pass or set requires_grad=False')
else:
p.grad = torch.zeros_like(p)
grads.append(p.grad.data)
return grads
def _get_flat_grads(self, out=None, has_grad = True):
grads = self._get_grads(has_grad)
#if out is None:
# grads_size = sum(g.numel() for g in grads)
# out = grads[0].new(grads_size).zero_()
#offset = 0
#for g in grads:
# numel = g.numel()
# out[offset:offset+numel].copy_(g.view(-1))
# offset += numel
#return out[:offset]
return apex_C.flatten(grads)
def _set_flat_grads(self, new_grads):
grads = self._get_grads()
offset = 0
for g in grads:
numel = g.numel()
g.copy_(new_grads[offset:offset+numel].view_as(g))
offset += numel
def _opt(self):
# take an optimization step
self.optimizer.step()
self.zero_grad()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
def valid_step(self, sample):
"""Do forward pass in evaluation mode."""
# forward pass
sample = self._prepare_sample(sample)
_loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
assert not oom_fwd, 'Ran out of memory during validation'
# gather logging outputs from all GPUs
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
(sample_size, logging_output)
))
else:
sample_sizes = [sample_size]
logging_outputs = [logging_output]
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
# update loss meters for validation
if 'loss' in agg_logging_output:
self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
return agg_logging_output
def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False)
self.zero_grad()
self.clear_buffered_stats()
def zero_grad(self):
self.optimizer.zero_grad()
def clear_buffered_stats(self):
self._buffered_stats.clear()
self._all_reduce_list = [0.0] * 6
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss)
def lr_step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.lr_scheduler.step_update(num_updates)
def get_lr(self):
"""Get the current learning rate."""
return self.optimizer.get_lr()
def get_model(self):
"""Get the model replica."""
return self.model
def get_meter(self, name):
"""Get a specific meter by name."""
if name not in self.meters:
return None
return self.meters[name]
def get_num_updates(self):
"""Get the number of parameters updates."""
return self._num_updates
def _prepare_sample(self, sample):
if sample is None or len(sample) == 0:
return None
return utils.move_to_cuda(sample)
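# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file) of the gradient
# synchronization pattern implemented by Trainer._all_reduce_and_rescale
# above: flatten all gradients into one buffer, all-reduce it across workers,
# rescale by the gradient denominator, clip, and copy back into the model.
# This version uses plain torch.cat instead of apex_C.flatten and is only a
# sketch, not the code path the Trainer actually runs.
def _example_allreduce_and_rescale(model, grad_denom, max_norm):
    grads = [p.grad.data for p in model.parameters()
             if p.requires_grad and p.grad is not None]
    flat = torch.cat([g.view(-1) for g in grads])  # flatten into one buffer
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(flat)         # sum gradients over workers
    flat.div_(grad_denom)                          # rescale by token count
    grad_norm = flat.norm().item()
    if grad_norm > max_norm > 0:
        flat.mul_(max_norm / (grad_norm + 1e-6))   # clip to max_norm
    offset = 0
    for g in grads:                                # copy back into param grads
        numel = g.numel()
        g.copy_(flat[offset:offset + numel].view_as(g))
        offset += numel
    return grad_norm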
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import defaultdict, OrderedDict
import logging
import os
import re
import torch
import traceback
from torch.serialization import default_restore_location
def torch_persistent_save(*args, **kwargs):
for i in range(3):
try:
return torch.save(*args, **kwargs)
except Exception:
if i == 2:
logging.error(traceback.format_exc())
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v, ttype)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v, ttype) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None):
if optim_history is None:
optim_history = []
if extra_state is None:
extra_state = {}
state_dict = {
'args': args,
'model': convert_state_dict_type(model.state_dict()),
'optimizer_history': optim_history + [
{
'criterion_name': criterion.__class__.__name__,
'optimizer_name': optimizer.__class__.__name__,
'lr_scheduler_state': lr_scheduler.state_dict(),
'num_updates': num_updates,
}
],
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'extra_state': extra_state,
}
torch_persistent_save(state_dict, filename)
def load_model_state(filename, model):
if not os.path.exists(filename):
return None, [], None
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
model.upgrade_state_dict(state['model'])
# load model parameters
try:
model.load_state_dict(state['model'], strict=True)
except Exception:
raise Exception('Cannot load model parameters from checkpoint, '
'please ensure that the architectures match')
return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
def _upgrade_state_dict(state):
"""Helper for upgrading old model checkpoints."""
# add optimizer_history
if 'optimizer_history' not in state:
state['optimizer_history'] = [
{
'criterion_name': 'CrossEntropyCriterion',
'best_loss': state['best_loss'],
},
]
state['last_optimizer_state'] = state['optimizer']
del state['optimizer']
del state['best_loss']
# move extra_state into sub-dictionary
if 'epoch' in state and 'extra_state' not in state:
state['extra_state'] = {
'epoch': state['epoch'],
'batch_offset': state['batch_offset'],
'val_loss': state['val_loss'],
}
del state['epoch']
del state['batch_offset']
del state['val_loss']
# reduce optimizer history's memory usage (only keep the last state)
if 'optimizer' in state['optimizer_history'][-1]:
state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
for optim_hist in state['optimizer_history']:
del optim_hist['optimizer']
# record the optimizer class name
if 'optimizer_name' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
# move best_loss into lr_scheduler_state
if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['lr_scheduler_state'] = {
'best': state['optimizer_history'][-1]['best_loss'],
}
del state['optimizer_history'][-1]['best_loss']
# keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
}
return state
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference.
model_arg_overrides allows you to pass a dictionary {'arg_name': arg} to
override model args that were used during model training.
"""
# load model architectures and weights
states = []
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
states.append(state)
ensemble = []
for state in states:
args = state['args']
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
# build model for ensemble
model = task.build_model(args)
model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
return ensemble, args
def _override_model_args(args, model_arg_overrides):
# Uses model_arg_overrides {'arg_name': arg} to override model args
for arg_name, arg_val in model_arg_overrides.items():
setattr(args, arg_name, arg_val)
return args
def move_to_cuda(sample):
if len(sample) == 0:
return {}
def _move_to_cuda(maybe_tensor):
if torch.is_tensor(maybe_tensor):
# Use non-blocking `.copy_()` rather than `.cuda()` to WAR bug 2895644
if maybe_tensor.is_cuda:
return maybe_tensor
else:
cuda_tensor = torch.empty_like(maybe_tensor, device='cuda')
return cuda_tensor.copy_(maybe_tensor, non_blocking=True)
elif isinstance(maybe_tensor, dict):
return {
key: _move_to_cuda(value)
for key, value in maybe_tensor.items()
}
elif isinstance(maybe_tensor, list):
return [_move_to_cuda(x) for x in maybe_tensor]
else:
return maybe_tensor
return _move_to_cuda(sample)
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
def _get_full_incremental_state_key(module_instance, key):
module_name = module_instance.__class__.__name__
# assign a unique ID to each module instance, so that incremental state is
# not shared across module instances
if not hasattr(module_instance, '_fairseq_instance_id'):
INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)
def get_incremental_state(module, incremental_state, key):
"""Helper for getting incremental state for an nn.Module."""
full_key = _get_full_incremental_state_key(module, key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]
def set_incremental_state(module, incremental_state, key, value):
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = _get_full_incremental_state_key(module, key)
incremental_state[full_key] = value
def load_align_dict(replace_unk):
if replace_unk is None:
align_dict = None
elif isinstance(replace_unk, str):
# Load alignment dictionary for unknown word replacement if it was passed as an argument.
align_dict = {}
with open(replace_unk, 'r') as f:
for line in f:
cols = line.split()
align_dict[cols[0]] = cols[1]
else:
# No alignment dictionary provided but we still want to perform unknown word replacement by copying the
# original source word.
align_dict = {}
return align_dict
def print_embed_overlap(embed_dict, vocab_dict):
embed_keys = set(embed_dict.keys())
vocab_keys = set(vocab_dict.symbols)
overlap = len(embed_keys & vocab_keys)
print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))
def parse_embedding(embed_path):
"""Parse embedding text file into a dictionary of word and embedding tensors.
The first line can have vocabulary size and dimension. The following lines
should contain word and embedding separated by spaces.
Example:
2 5
the -0.0230 -0.0264 0.0287 0.0171 0.1403
at -0.0395 -0.1286 0.0275 0.0254 -0.0932
"""
embed_dict = {}
with open(embed_path) as f_embed:
next(f_embed) # skip header
for line in f_embed:
pieces = line.rstrip().split(" ")
embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
return embed_dict
def load_embedding(embed_dict, vocab, embedding):
for idx in range(len(vocab)):
token = vocab[idx]
if token in embed_dict:
embedding.weight.data[idx] = embed_dict[token]
return embedding
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
from fairseq import tokenizer
# Tokens are strings here
hypo_tokens = tokenizer.tokenize_line(hypo_str)
# TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
for i, ht in enumerate(hypo_tokens):
if ht == unk:
src_token = src_tokens[alignment[i]]
# Either take the corresponding value in the aligned dictionary or just copy the original value.
hypo_tokens[i] = align_dict.get(src_token, src_token)
return ' '.join(hypo_tokens)
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe, lowercase=False):
from fairseq import tokenizer
hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
assert not align_dict
return hypo_tokens, hypo_str, alignment
#if align_dict is not None:
# hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
#if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method.
# hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True, lowercase=lowercase)
def make_positions(tensor, padding_idx, left_pad):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
"""
max_pos = padding_idx + 1 + tensor.size(1)
if not hasattr(make_positions, 'range_buf'):
make_positions.range_buf = tensor.new()
make_positions.range_buf = make_positions.range_buf.type_as(tensor)
if make_positions.range_buf.numel() < max_pos:
torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
mask = tensor.ne(padding_idx)
positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
if left_pad:
positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
return tensor.clone().masked_scatter_(mask, positions[mask])
"""
mask = tensor.ne(padding_idx).long()
return torch.cumsum(mask, dim=1) * mask + padding_idx
def strip_pad(tensor, pad):
return tensor[tensor.ne(pad)]
def buffered_arange(max):
if not hasattr(buffered_arange, 'buf'):
buffered_arange.buf = torch.LongTensor()
if max > buffered_arange.buf.numel():
torch.arange(max, out=buffered_arange.buf)
return buffered_arange.buf[:max]
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
if not pad_mask.any():
# no padding, return early
return src_tokens
if left_to_right and not pad_mask[:, 0].any():
# already right padded
return src_tokens
if right_to_left and not pad_mask[:, -1].any():
# already left padded
return src_tokens
max_len = src_tokens.size(1)
range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
num_pads = pad_mask.long().sum(dim=1, keepdim=True)
if right_to_left:
index = torch.remainder(range - num_pads, max_len)
else:
index = torch.remainder(range + num_pads, max_len)
return src_tokens.gather(1, index)
def item(tensor):
if hasattr(tensor, 'item'):
return tensor.item()
if hasattr(tensor, '__getitem__'):
return tensor[0]
return tensor
def clip_grad_norm_(tensor, max_norm):
grad_norm = item(torch.norm(tensor))
if grad_norm > max_norm > 0:
clip_coef = max_norm / (grad_norm + 1e-6)
tensor.mul_(clip_coef)
return grad_norm
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)
entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = int(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
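# ---------------------------------------------------------------------------
# Illustrative example (not part of the original file) of the cumsum trick in
# make_positions() above: padding symbols keep padding_idx and real tokens get
# consecutive positions starting at padding_idx + 1. Guarded demo only.
if __name__ == '__main__':
    pad = 1
    toks = torch.tensor([[1, 1, 5, 6, 7],   # left-padded sentence
                         [4, 5, 6, 7, 8]])  # full-length sentence
    print(make_positions(toks, pad, left_pad=True))
    # expected: [[1, 1, 2, 3, 4],
    #            [2, 3, 4, 5, 6]]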
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for generation!'
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded)
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=models[0].max_positions(),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
num_shards=args.num_shards,
shard_id=args.shard_id,
).next_epoch_itr(shuffle=False)
# Initialize generator
gen_timer = StopwatchMeter()
if args.score_reference:
translator = SequenceScorer(models, task.target_dictionary)
else:
translator = SequenceGenerator(
models, task.target_dictionary, beam_size=args.beam,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
)
if use_cuda:
translator.cuda()
# Generate and compute BLEU score
scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
num_sentences = 0
has_target = True
with progress_bar.build_progress_bar(args, itr) as t:
if args.score_reference:
translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
else:
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
)
wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else:
src_str = src_dict.string(src_tokens, args.remove_bpe)
if has_target:
target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
if has_target:
print('T-{}\t{}'.format(sample_id, target_str))
# Process top predictions
for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
if not args.quiet:
print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
print('P-{}\t{}'.format(
sample_id,
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
if args.print_alignment:
print('A-{}\t{}'.format(
sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
# Score only the top hypothesis
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, tgt_dict, add_if_not_exist=True)
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
if __name__ == '__main__':
parser = options.get_generation_parser()
args = options.parse_args_and_arch(parser)
main(args)
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import namedtuple
import numpy as np
import sys
import torch
from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(buffer_size):
buffer = []
for src_str in sys.stdin:
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
if len(buffer) > 0:
yield buffer
def make_batches(lines, args, src_dict, max_positions):
tokens = [
tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
for src_str in lines
]
lengths = np.array([t.numel() for t in tokens])
itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
srcs=[lines[i] for i in batch['id']],
tokens=batch['net_input']['src_tokens'],
lengths=batch['net_input']['src_lengths'],
), batch['id']
def main(args):
if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
args.max_sentences = 1
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
'--max-sentences/--batch-size cannot be larger than --buffer-size'
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Setup task, e.g., translation
task = tasks.setup_task(args)
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(':')
models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
# Initialize generator
translator = SequenceGenerator(
models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
minlen=args.min_len, sampling_temperature=args.sampling_temperature
)
if use_cuda:
translator.cuda()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
def make_result(src_str, hypos):
result = Translation(
src_str='O\t{}'.format(src_str),
hypos=[],
pos_scores=[],
alignments=[],
)
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
result.pos_scores.append('P\t{}'.format(
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
result.alignments.append(
'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
if args.print_alignment else None
)
return result
def process_batch(batch):
tokens = batch.tokens
lengths = batch.lengths
if use_cuda:
tokens = tokens.cuda()
lengths = lengths.cuda()
translations = translator.generate(
tokens,
lengths,
maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
)
return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:')
for inputs in buffered_read(args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
indices.extend(batch_indices)
results += process_batch(batch)
for i in np.argsort(indices):
result = results[i]
print(result.src_str)
for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
print(hypo)
print(pos_scores)
if align is not None:
print(align)
if __name__ == '__main__':
parser = options.get_generation_parser(interactive=True)
args = options.parse_args_and_arch(parser)
main(args)
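# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): buffered_read() above
# groups stdin into chunks of at most --buffer-size sentences before batching.
# The same grouping applied to an in-memory list, for reference:
def _example_group(lines, buffer_size):
    buffer = []
    for line in lines:
        buffer.append(line.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []
    if len(buffer) > 0:
        yield buffer
# e.g. list(_example_group(['a', 'b', 'c'], 2)) == [['a', 'b'], ['c']]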
# Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import ctypes
import logging.config
import os
import random
import subprocess
import sys
import time
from contextlib import contextmanager
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.init as init
import torch.utils.collect_env
from mlperf_logging.mllog import constants
from mlperf_logging import mllog
mllogger = mllog.get_mllogger()
def log_start(*args, **kwargs):
_log_print(mllogger.start, *args, **kwargs)
def log_end(*args, **kwargs):
_log_print(mllogger.end, *args, **kwargs)
def log_event(*args, **kwargs):
_log_print(mllogger.event, *args, **kwargs)
def _log_print(logger, *args, **kwargs):
"""
Wrapper for MLPerf compliance logging calls.
All arguments but 'sync' and 'log_all_ranks' are passed to
mlperf_logging.mllog.
If 'sync' is set to True then the wrapper will synchronize all distributed
workers. 'sync' should be set to True for all compliance tags that require
accurate timing (RUN_START, RUN_STOP etc.)
If 'log_all_ranks' is set to True then all distributed workers print the
logging message; if set to False then only the worker with rank=0 prints it.
"""
if kwargs.pop('sync', False):
barrier()
if 'stack_offset' not in kwargs:
kwargs['stack_offset'] = 3
if 'value' not in kwargs:
kwargs['value'] = None
if kwargs.pop('log_all_ranks', False):
log = True
else:
log = (get_rank() == 0)
if log:
logger(*args, **kwargs)
def mlperf_submission_log(benchmark):
required_dist_init = ['RANK', 'WORLD_SIZE', 'MASTER_ADDR', 'MASTER_PORT']
if all(var in os.environ for var in required_dist_init):
torch.distributed.init_process_group(backend='nccl', init_method='env://')
num_nodes = os.environ.get('SLURM_NNODES', 1)
mllog.config(filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'transformer.log'))
mllogger = mllog.get_mllogger()
mllogger.logger.propagate = False
log_event(
key=constants.SUBMISSION_BENCHMARK,
value=benchmark,
)
log_event(
key=constants.SUBMISSION_ORG,
value='NVIDIA')
log_event(
key=constants.SUBMISSION_DIVISION,
value='closed')
log_event(
key=constants.SUBMISSION_STATUS,
value='onprem')
log_event(
key=constants.SUBMISSION_PLATFORM,
value=f'{num_nodes}xSUBMISSION_PLATFORM_PLACEHOLDER')
def barrier():
"""
Works as a temporary distributed barrier; currently PyTorch doesn't
implement a barrier for the NCCL backend.
Calls all_reduce on a dummy tensor and synchronizes with the GPU.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
torch.cuda.synchronize()
def get_rank():
"""
Gets distributed rank or returns zero if distributed is not initialized.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
else:
rank = 0
return rank
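# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original file) of the wrappers above:
# timing-sensitive tags pass sync=True so all workers reach the barrier before
# the timestamp is taken, and only rank 0 prints unless log_all_ranks=True.
# RUN_START/RUN_STOP are the constants mentioned in the docstring; guarded
# demo only.
if __name__ == '__main__':
    log_start(key=constants.RUN_START, sync=True)
    # ... training would run here ...
    log_end(key=constants.RUN_STOP, sync=True)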
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import random
import signal
import torch
from fairseq import distributed_utils, options
from train import main as single_process_main
def main(args):
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
args.distributed_init_method = 'tcp://localhost:{port}'.format(
port=random.randint(10000, 20000))
mp = torch.multiprocessing.get_context('spawn')
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
for i in range(args.distributed_world_size):
args.distributed_rank = i
args.device_id = i
procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
procs[i].start()
error_handler.add_child(procs[i].pid)
for p in procs:
p.join()
def run(args, error_queue):
try:
args.distributed_rank = distributed_utils.distributed_init(args)
single_process_main(args)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((args.distributed_rank, traceback.format_exc()))
class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""
def __init__(self, error_queue):
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
self.children_pids.append(pid)
def error_listener(self):
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, signalnum, stackframe):
for pid in self.children_pids:
os.kill(pid, signal.SIGINT)  # kill child processes
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
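# Illustrative invocation (script name and flag values are assumptions, not taken
# from this file): the launcher spawns one training process per visible GPU.
#
#     python multiprocessing_train.py data-bin/wmt14_en_de --max-tokens 5120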
# Copyright 2018 MLBenchmark Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and preprocess WMT17 ende training and evaluation datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import random
import sys
import tarfile
import six
import urllib.request
from utils import tokenizer
# Data sources for training/evaluating the transformer translation model.
# If any of the training sources are changed, then either:
# 1) use the flag `--search` to find the best min count or
# 2) update the _TRAIN_DATA_MIN_COUNT constant.
# min_count is the minimum number of times a token must appear in the data
# before it is added to the vocabulary. "Best min count" refers to the value
# that generates a vocabulary set that is closest in size to _TARGET_VOCAB_SIZE.
_TRAIN_DATA_SOURCES = [
{
"url": "http://data.statmt.org/wmt17/translation-task/"
"training-parallel-nc-v12.tgz",
"input": "news-commentary-v12.de-en.en",
"target": "news-commentary-v12.de-en.de",
},
{
"url": "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
"input": "commoncrawl.de-en.en",
"target": "commoncrawl.de-en.de",
},
{
"url": "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
"input": "europarl-v7.de-en.en",
"target": "europarl-v7.de-en.de",
},
]
# Use pre-defined minimum count to generate subtoken vocabulary.
_TRAIN_DATA_MIN_COUNT = 6
_EVAL_DATA_SOURCES = [
{
"url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
"input": "newstest2013.en",
"target": "newstest2013.de",
}
]
# Vocabulary constants
_TARGET_VOCAB_SIZE = 32768 # Number of subtokens in the vocabulary list.
_TARGET_THRESHOLD = 327 # Accept vocabulary if size is within this threshold
_VOCAB_FILE = "dict.en.txt"
# Strings to include in the generated files.
_PREFIX = "wmt32k"
_COMPILE_TAG = "compiled"
_ENCODE_TAG = ""
_TRAIN_TAG = "train"
_EVAL_TAG = "dev" # Following WMT and Tensor2Tensor conventions, in which the
# evaluation datasets are tagged as "dev" for development.
_TEST_TAG = "test"
# Number of files to split train and evaluation data
_TRAIN_SHARDS = 1
_EVAL_SHARDS = 1
_TEST_SHARDS = 1
def find_file(path, filename, max_depth=5):
"""Returns full filepath if the file is in path or a subdirectory."""
for root, dirs, files in os.walk(path):
if filename in files:
return os.path.join(root, filename)
# Don't search past max_depth
depth = root[len(path) + 1:].count(os.sep)
if depth > max_depth:
del dirs[:] # Clear dirs
return None
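# Example (illustrative paths): locate an extracted corpus file no matter which
# sub-directory the tarball unpacked it into.
#
#     find_file('/tmp/translate_ende_raw', 'newstest2013.en')
#     # -> '/tmp/translate_ende_raw/dev/newstest2013.en' (or None if absent)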
###############################################################################
# Download and extraction functions
###############################################################################
def get_raw_files(raw_dir, data_source):
"""Return raw files from source. Downloads/extracts if needed.
Args:
raw_dir: string directory to store raw files
data_source: dictionary with
{"url": url of compressed dataset containing input and target files
"input": file with data in input language
"target": file with data in target language}
Returns:
dictionary with
{"inputs": list of files containing data in input language
"targets": list of files containing corresponding data in target language
}
"""
raw_files = {
"inputs": [],
"targets": [],
}
for d in data_source:
input_file, target_file = download_and_extract(
raw_dir, d["url"], d["input"], d["target"])
raw_files["inputs"].append(input_file)
raw_files["targets"].append(target_file)
return raw_files
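# Illustrative result for the evaluation sources defined above (paths depend on
# where the archive unpacks):
#
#     get_raw_files('/tmp/translate_ende_raw', _EVAL_DATA_SOURCES)
#     # -> {"inputs":  ['/tmp/translate_ende_raw/dev/newstest2013.en'],
#     #     "targets": ['/tmp/translate_ende_raw/dev/newstest2013.de']}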
def download_report_hook(count, block_size, total_size):
"""Report hook for download progress.
Args:
count: current block number
block_size: block size
total_size: total size
"""
percent = int(count * block_size * 100 / total_size)
print("\r%d%%" % percent + " completed", end="\r")
def download_from_url(path, url):
"""Download content from a url.
Args:
path: string directory where file will be downloaded
url: string url
Returns:
Full path to downloaded file
"""
filename = url.split("/")[-1]
found_file = find_file(path, filename, max_depth=0)
if found_file is None:
filename = os.path.join(path, filename)
print("Downloading from %s to %s." % (url, filename))
inprogress_filepath = filename + ".incomplete"
inprogress_filepath, _ = urllib.request.urlretrieve(
url, inprogress_filepath, reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress.
print()
os.rename(inprogress_filepath, filename)
return filename
else:
print("Already downloaded: %s (at %s)." % (url, found_file))
return found_file
def download_and_extract(path, url, input_filename, target_filename):
"""Extract files from downloaded compressed archive file.
Args:
path: string directory where the files will be downloaded
url: url containing the compressed input and target files
input_filename: name of file containing data in source language
target_filename: name of file containing data in target language
Returns:
Full paths to extracted input and target files.
Raises:
OSError: if the download/extraction fails.
"""
# Check if extracted files already exist in path
input_file = find_file(path, input_filename)
target_file = find_file(path, target_filename)
if input_file and target_file:
print("Already downloaded and extracted %s." % url)
return input_file, target_file
# Download archive file if it doesn't already exist.
compressed_file = download_from_url(path, url)
# Extract compressed files
print("Extracting %s." % compressed_file)
with tarfile.open(compressed_file, "r:gz") as corpus_tar:
corpus_tar.extractall(path)
# Return filepaths of the requested files.
input_file = find_file(path, input_filename)
target_file = find_file(path, target_filename)
if input_file and target_file:
return input_file, target_file
raise OSError("Download/extraction failed for url %s to path %s" %
(url, path))
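# Illustrative use, mirroring the loop in get_raw_files above:
#
#     src = _TRAIN_DATA_SOURCES[0]
#     en_file, de_file = download_and_extract(
#         '/tmp/translate_ende_raw', src["url"], src["input"], src["target"])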
def txt_line_iterator(path):
"""Iterate through lines of file."""
with open(path, mode='r', newline='\n') as f:
for line in f:
yield line.strip()
def compile_files(data_dir, raw_files, tag):
"""Compile raw files into a single file for each language.
Args:
data_dir: Directory where the compiled files will be written.
raw_files: Dict containing filenames of input and target data.
{"inputs": list of files containing data in input language
"targets": list of files containing corresponding data in target language
}
tag: String to append to the compiled filename.
Returns:
Full path of compiled input and target files.
"""
print("Compiling files with tag %s." % tag)
filename = "%s-%s-%s" % (_PREFIX, _COMPILE_TAG, tag)
input_compiled_file = os.path.join(data_dir, filename + ".lang1")
target_compiled_file = os.path.join(data_dir, filename + ".lang2")
with open(input_compiled_file, mode="w", newline='\n') as input_writer:
with open(target_compiled_file, mode="w", newline='\n') as target_writer:
for i in range(len(raw_files["inputs"])):
input_file = raw_files["inputs"][i]
target_file = raw_files["targets"][i]
print("Reading files %s and %s." % (input_file, target_file))
write_file(input_writer, input_file)
write_file(target_writer, target_file)
return input_compiled_file, target_compiled_file
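# With the constants above, the training call writes (for the default data_dir,
# shown only as an illustration):
#     /tmp/translate_ende/wmt32k-compiled-train.lang1  and  .lang2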
def write_file(writer, filename):
"""Write all of lines from file using the writer."""
for line in txt_line_iterator(filename):
writer.write(line)
writer.write("\n")
###############################################################################
# Data preprocessing
###############################################################################
def encode_and_save_files_utf8(subtokenizer, data_dir, raw_files, tag, total_shards):
"""Save data from files as encoded example pairs in UT8 format.
Args:
subtokenizer: Subtokenizer object that will be used to encode the strings.
data_dir: The directory in which to write the examples
raw_files: A tuple of (input, target) data files. Each line in the input and
the corresponding line in target file will be saved in encoded format (vocab integer ids).
tag: String that will be added onto the file names.
total_shards: Number of files to divide the data into.
Returns:
List of all files produced.
"""
# Create a file for each shard.
if total_shards > 1:
filepaths = [shard_filename(data_dir + '/utf8', tag, n + 1, total_shards) for n in range(total_shards)]
else:
filepaths = [shard_filename_single(data_dir + '/utf8', tag) for n in range(total_shards)]
if all_exist(filepaths):
print("Files with tag %s already exist." % tag)
return filepaths
print("Saving files with tag %s." % tag)
input_file = raw_files[0]
target_file = raw_files[1]
# Write examples to each shard in round robin order.
tmp_filepaths = [fname + ".incomplete" for fname in filepaths]
src_writers = [open(fname + '.en', mode='w', newline='\n') for fname in tmp_filepaths]
dst_writers = [open(fname + '.de', mode='w', newline='\n') for fname in tmp_filepaths]
counter, shard_idx = 0, 0
src_sum, tgt_sum = 0, 0
for counter, (input_line, target_line) in enumerate(zip(txt_line_iterator(input_file), txt_line_iterator(target_file))):
if counter > 0 and counter % 100000 == 0:
print("\tSaving case %d." % counter)
src_sum += len(input_line.split(' '))
tgt_sum += len(target_line.split(' '))
#print('src avg:', src_sum / (counter + 1))
#print('tgt avg:', tgt_sum / (counter + 1))
src_writers[shard_idx].write(' '.join([str(idx) for idx in subtokenizer.encode(input_line.strip(), add_eos=True)]))
dst_writers[shard_idx].write(' '.join([str(idx) for idx in subtokenizer.encode(target_line.strip(), add_eos=True)]))
src_writers[shard_idx].write('\n')
dst_writers[shard_idx].write('\n')
shard_idx = (shard_idx + 1) % total_shards
for writer in src_writers:
writer.close()
for writer in dst_writers:
writer.close()
for tmp_name, final_name in zip(tmp_filepaths, filepaths):
src_tmp_name = tmp_name + '.en'
dst_tmp_name = tmp_name + '.de'
src_fnl_name = final_name + '.en'
dst_fnl_name = final_name + '.de'
os.rename(src_tmp_name, src_fnl_name)
os.rename(dst_tmp_name, dst_fnl_name)
print("Saved %d Examples", counter)
joined_file_list = []
for file in filepaths:
joined_file_list.append(file + '.en')
joined_file_list.append(file + '.de')
return joined_file_list
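# Each line written above is a space-separated sequence of subtoken ids with the
# EOS id appended (the ids below are made up purely for illustration):
#
#     "131 9047 28 4 17832 2"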
def shard_filename_single(path, tag):
"""Create filename for data shard."""
return os.path.join(path, "%s.en-de" % (tag))
def shard_filename(path, tag, shard_num, total_shards):
"""Create filename for data shard."""
return os.path.join(path, "%s%d.en-de" % (tag, shard_num))
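# With the defaults above (_TRAIN_SHARDS = 1), encode_and_save_files_utf8 uses the
# single-file variant; for the default data_dir this gives (illustrative path):
#
#     shard_filename_single('/tmp/translate_ende/utf8', 'train')
#     # -> '/tmp/translate_ende/utf8/train.en-de', to which '.en' / '.de' are appended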
def shuffle_records(fname):
"""Shuffle records in a single file."""
print('Shuffling records.')
src = 'en-de.en'
dst = 'en-de.de'
file_sets = {
'train' : {src : [], dst : []},
'dev' : {src : [], dst : []},
'test' : {src : [], dst : []}
}
for f in fname:
tmp_fname = f + ".unshuffled"
os.rename(f, tmp_fname)
for split in file_sets:
if split in f:
print('fname:', tmp_fname)
if src in f:
file_sets[split][src].append(tmp_fname)
else:
file_sets[split][dst].append(tmp_fname)
for split in file_sets:
if file_sets[split][src]:
file_sets[split][src].sort()
if file_sets[split][dst]:
file_sets[split][dst].sort()
print('file_set:', file_sets)
for split in file_sets:
print('split:', split)
if file_sets[split][src] and file_sets[split][dst]:
for fs, fd in zip(file_sets[split][src], file_sets[split][dst]):
print('zipped fs, fd')
with open(fs, mode='r', newline='\n') as src_reader:
with open(fd, mode='r', newline='\n') as tgt_reader:
print('opened for read:', fs, fd)
records = []
for rs, rd in zip(src_reader, tgt_reader):
records.append((rs, rd))
if len(records) % 100000 == 0:
print("\t Read : %d", len(records))
print('performing shuffle')
random.shuffle(records)
with open(fs[:-11], mode='w', newline='\n') as src_writer:
with open(fd[:-11], mode='w', newline='\n') as tgt_writer:
print('opened for write:', fs[:-11], fd[:-11])
for i, record in enumerate(records):
src_writer.write(record[0])
tgt_writer.write(record[1])
if i > 0 and i % 100000 == 0:
print("\tWriting record: %d" % i)
os.remove(fs)
os.remove(fd)
def all_exist(filepaths):
"""Returns true if all files in the list exist."""
for fname in filepaths:
if not os.path.exists(fname):
return False
return True
def make_dir(path):
if not os.path.exists(path):
print("Creating directory %s" % path)
os.mkdir(path)
def main(unused_argv):
"""Obtain training and evaluation data for the Transformer model."""
make_dir(FLAGS.raw_dir)
make_dir(FLAGS.data_dir)
# Get paths of download/extracted training and evaluation files.
print("Step 1/4: Downloading data from source")
train_files = get_raw_files(FLAGS.raw_dir, _TRAIN_DATA_SOURCES)
eval_files = get_raw_files(FLAGS.raw_dir, _EVAL_DATA_SOURCES)
test_files = ['/workspace/translation/newstest2014.en', '/workspace/translation/newstest2014.de']
# Create subtokenizer based on the training files.
print("Step 2/4: Creating subtokenizer and building vocabulary")
train_files_flat = train_files["inputs"] + train_files["targets"]
vocab_file = os.path.join(FLAGS.data_dir, _VOCAB_FILE)
subtokenizer = tokenizer.Subtokenizer.init_from_existing_vocab_file(
vocab_file, train_files_flat, _TARGET_VOCAB_SIZE, _TARGET_THRESHOLD,
min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT, reserved_tokens='assumed_in_file')
print("Step 3/4: Compiling training and evaluation data")
compiled_train_files = compile_files(FLAGS.data_dir, train_files, _TRAIN_TAG)
compiled_eval_files = compile_files(FLAGS.data_dir, eval_files, _EVAL_TAG)
compiled_test_files = test_files
# Tokenize and save data as example pairs in UTF8 format.
print("Step 4/4: Preprocessing and saving data")
train_utf8_files = encode_and_save_files_utf8(subtokenizer, FLAGS.data_dir, compiled_train_files, _TRAIN_TAG, _TRAIN_SHARDS)
encode_and_save_files_utf8(subtokenizer, FLAGS.data_dir, compiled_eval_files, _EVAL_TAG, _EVAL_SHARDS)
encode_and_save_files_utf8(subtokenizer, FLAGS.data_dir, compiled_test_files, _TEST_TAG, _TEST_SHARDS)
shuffle_records(train_utf8_files)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir", "-dd", type=str, default="/tmp/translate_ende",
help="[default: %(default)s] Directory for where the "
"translate_ende_wmt32k dataset is saved.",
metavar="<DD>")
parser.add_argument(
"--raw_dir", "-rd", type=str, default="/tmp/translate_ende_raw",
help="[default: %(default)s] Path where the raw data will be downloaded "
"and extracted.",
metavar="<RD>")
parser.add_argument(
"--search", action="store_true",
help="If set, use binary search to find the vocabulary set with size"
"closest to the target size (%d)." % _TARGET_VOCAB_SIZE)
FLAGS, unparsed = parser.parse_known_args()
main(sys.argv)
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
from itertools import zip_longest
import os
import shutil
from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from mlperf_log_utils import mlperf_print, mlperf_submission_log, set_seeds, get_rank
def get_parser():
parser = argparse.ArgumentParser(description='Data pre-processing: Create dictionary and store data in binary format')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
parser.add_argument('--trainpref', metavar='FP', default=None, help='train file prefix')
parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid file prefixes')
parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test file prefixes')
parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int, help='map words appearing less than threshold times to unknown')
parser.add_argument('--thresholdsrc', metavar='N', default=0, type=int, help='map words appearing less than threshold times to unknown')
parser.add_argument('--tgtdict', metavar='FP', help='reuse given target dictionary')
parser.add_argument('--srcdict', metavar='FP', help='reuse given source dictionary')
parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'], help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language')
parser.add_argument('--padding-factor', metavar='N', default=8, type=int, help='Pad dictionary size to be multiple of N')
return parser
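# Illustrative invocation (file name assumed; flags are the ones defined above):
#
#     python preprocess.py --source-lang en --target-lang de \
#         --trainpref train --validpref valid --testpref test \
#         --destdir data-bin --joined-dictionary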
def main(args):
def build_dictionary(filenames):
d = dictionary.Dictionary()
for filename in filenames:
Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
return d
def train_path(lang):
return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')
def file_name(prefix, lang):
fname = prefix
if lang is not None:
fname += f'.{lang}'
return fname
def dest_path(prefix, lang):
return os.path.join(args.destdir, file_name(prefix, lang))
def dict_path(lang):
return dest_path('dict', lang) + '.txt'
def dataset_dest_path(output_prefix, lang, extension):
base = f'{args.destdir}/{output_prefix}'
lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
return f'{base}{lang_part}.{extension}'
def make_binary_dataset(input_prefix, output_prefix, lang):
dict = dictionary.Dictionary.load(dict_path(lang))
print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))
def consumer(tensor):
ds.add_item(tensor)
input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
res = Tokenizer.binarize(input_file, dict, consumer)
print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
lang, input_file, res['nseq'], res['ntok'], 100 * res['nunk'] / res['ntok'], dict.unk_word))
ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))
def make_dataset(input_prefix, output_prefix, lang):
if args.output_format == 'binary':
make_binary_dataset(input_prefix, output_prefix, lang)
elif args.output_format == 'raw':
# Copy original text file to destination folder
output_text_file = dest_path(output_prefix + '.{}-{}'.format(args.source_lang, args.target_lang), lang,)
shutil.copyfile(file_name(input_prefix, lang), output_text_file)
def make_all(lang):
if args.trainpref:
make_dataset(args.trainpref, 'train', lang)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
make_dataset(validpref, outprefix, lang)
if args.testpref:
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
make_dataset(testpref, outprefix, lang)
# main() code starts here
print(args)
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
if args.joined_dictionary:
assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
src_dict = build_dictionary(set([train_path(lang) for lang in [args.source_lang, args.target_lang]]))
tgt_dict = src_dict
else:
if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict)
else:
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)])
if target:
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else:
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary([train_path(args.target_lang)])
src_dict.finalize(threshold=args.thresholdsrc, nwords=args.nwordssrc, padding_factor=args.padding_factor, )
src_dict.save(dict_path(args.source_lang))
if target:
if not args.joined_dictionary:
tgt_dict.finalize(threshold=args.thresholdtgt, nwords=args.nwordstgt, padding_factor=args.padding_factor, )
tgt_dict.save(dict_path(args.target_lang))
make_all(args.source_lang)
if target:
make_all(args.target_lang)
print('| Wrote preprocessed data to {}'.format(args.destdir))
if args.alignfile:
assert args.trainpref, "--trainpref must be set if --alignfile is specified"
src_file_name = train_path(args.source_lang)
tgt_file_name = train_path(args.target_lang)
src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
freq_map = {}
with open(args.alignfile, 'r') as align_file:
with open(src_file_name, 'r') as src_file:
with open(tgt_file_name, 'r') as tgt_file:
for a, s, t in zip_longest(align_file, src_file, tgt_file):
si = Tokenizer.tokenize(s, src_dict, add_if_not_exist=False)
ti = Tokenizer.tokenize(t, tgt_dict, add_if_not_exist=False)
ai = list(map(lambda x: tuple(x.split('-')), a.split()))
for sai, tai in ai:
srcidx = si[int(sai)]
tgtidx = ti[int(tai)]
if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
assert srcidx != src_dict.pad()
assert srcidx != src_dict.eos()
assert tgtidx != tgt_dict.pad()
assert tgtidx != tgt_dict.eos()
if srcidx not in freq_map:
freq_map[srcidx] = {}
if tgtidx not in freq_map[srcidx]:
freq_map[srcidx][tgtidx] = 1
else:
freq_map[srcidx][tgtidx] += 1
align_dict = {}
for srcidx in freq_map.keys():
align_dict[srcidx] = max(freq_map[srcidx], key=freq_map[srcidx].get)
with open(os.path.join(args.destdir, 'alignment.{}-{}.txt'.format(args.source_lang, args.target_lang)), 'w') as f:
for k, v in align_dict.items():
print('{} {}'.format(src_dict[k], tgt_dict[v]), file=f)
if __name__ == '__main__':
parser = get_parser()
args = parser.parse_args()
main(args)
cffi==1.14.0
#numpy==1.18.1
sacrebleu==1.2.10
torch
tqdm==4.31.1
#!/bin/bash
#SBATCH --job-name translation
set -euxo pipefail
# Vars without defaults
: "${DGXSYSTEM:?DGXSYSTEM not set}"
: "${CONT:?CONT not set}"
# Vars with defaults
: "${NEXP:=5}"
: "${DATESTAMP:=$(date +'%y%m%d%H%M%S%N')}"
: "${CLEAR_CACHES:=1}"
: "${DATADIR:=/raid/datasets/xformer_v0p6/utf8}"
: "${LOGDIR:=./results}"
# Other vars
readonly _seed_override=${SEED:-}
readonly _logfile_base="${LOGDIR}/${DATESTAMP}"
readonly _cont_name=translation
_cont_mounts="${DATADIR}:/data,${LOGDIR}:/results"
# Setup directories
mkdir -p "${LOGDIR}"
srun --ntasks="${SLURM_JOB_NUM_NODES}" mkdir -p "${LOGDIR}"
# Setup container
srun --ntasks="${SLURM_JOB_NUM_NODES}" --container-image="${CONT}" --container-name="${_cont_name}" true
# Run experiments
for _experiment_index in $(seq 1 "${NEXP}"); do
(
echo "Beginning trial ${_experiment_index} of ${NEXP}"
# Print system info
srun --ntasks="${SLURM_JOB_NUM_NODES}" --container-name="${_cont_name}" python -c "
import mlperf_log_utils
from mlperf_logging.mllog import constants
mlperf_log_utils.mlperf_submission_log(constants.TRANSFORMER)"
# Clear caches
if [ "${CLEAR_CACHES}" -eq 1 ]; then
srun --ntasks="${SLURM_JOB_NUM_NODES}" bash -c "echo -n 'Clearing cache on ' && hostname && sync && sudo /sbin/sysctl vm.drop_caches=3"
srun --ntasks="${SLURM_JOB_NUM_NODES}" --container-name="${_cont_name}" python -c "
from mlperf_logging.mllog import constants
from mlperf_log_utils import log_event
log_event(key=constants.CACHE_CLEAR, value=True)"
fi
# Run experiment
export SEED=${_seed_override:-$RANDOM}
srun --mpi=none --ntasks="$(( SLURM_JOB_NUM_NODES * DGXNGPU ))" --ntasks-per-node="${DGXNGPU}" \
--container-name="${_cont_name}" --container-mounts="${_cont_mounts}" \
./run_and_time.sh
) |& tee "${_logfile_base}_${_experiment_index}.log"
done
#!/bin/bash
SLURM_NTASKS_PER_NODE=${SLURM_NTASKS_PER_NODE:-$DGXNGPU}
SLURM_JOB_ID=${SLURM_JOB_ID:-$RANDOM}
MULTI_NODE=${MULTI_NODE:-''}
echo "Run vars: id $SLURM_JOB_ID gpus $SLURM_NTASKS_PER_NODE mparams $MULTI_NODE"
# Options
set -x
SEED=${SEED:-$RANDOM}
MAX_TOKENS=${MAX_TOKENS:-5120}
# DATASET_DIR="/data"
DATASET_DIR=/mnt/fs/user/llama/custom_model/libo_test/mlperf_test/transformer/implementations/dataset/translation/examples/translation/wmt14_en_de/utf8
MODE=${MODE:-TRAIN}
NUMEPOCHS=${NUMEPOCHS:-30}
case "$MODE" in
PREPROCESS) source run_preprocessing.sh;;
TRAIN) source run_training.sh;;
esac
set +x