Commit 7633129b authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
@@ -9,7 +9,6 @@ import numpy as np
 import torch
 from . import data_utils, FairseqDataset
-from typing import List
 def collate(samples, pad_idx, eos_idx):
@@ -53,8 +52,8 @@ class MonolingualDataset(FairseqDataset):
         dataset (torch.utils.data.Dataset): dataset to wrap
         sizes (List[int]): sentence lengths
         vocab (~fairseq.data.Dictionary): vocabulary
-        shuffle (bool, optional): shuffle the elements before batching.
-            Default: ``True``
+        shuffle (bool, optional): shuffle the elements before batching
+            (default: True).
     """
     def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
@@ -66,8 +65,8 @@ class MonolingualDataset(FairseqDataset):
         self.add_eos_for_other_targets = add_eos_for_other_targets
         self.shuffle = shuffle
-        assert targets is None or all(
-            t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'"
+        assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
+            "targets must be none or one of 'self', 'future', 'past'"
         if targets is not None and len(targets) == 0:
             targets = None
         self.targets = targets
@@ -185,7 +184,7 @@ class MonolingualDataset(FairseqDataset):
     @property
     def supports_prefetch(self):
-        return self.dataset.supports_prefetch
+        return getattr(self.dataset, 'supports_prefetch', False)
     def prefetch(self, indices):
         self.dataset.prefetch(indices)
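
The `supports_prefetch` change above (and the matching ones further down in this commit) switches the wrappers to a duck-typed check: a wrapped dataset that never declares the attribute simply reports False instead of raising. A minimal sketch of the pattern, with hypothetical stand-in classes that are not part of fairseq:

# Duck-typed prefetch check, as used by the dataset wrappers in this commit.
# PlainDataset and CachedDataset are illustrative stand-ins only.
class PlainDataset:
    def __getitem__(self, index):
        return index                     # declares no prefetch support

class CachedDataset:
    supports_prefetch = True

    def __getitem__(self, index):
        return index

    def prefetch(self, indices):
        print('caching', list(indices))  # pretend to warm a cache

def wrapper_supports_prefetch(dataset):
    # Safe even when the wrapped dataset has no such attribute.
    return getattr(dataset, 'supports_prefetch', False)

assert wrapper_supports_prefetch(PlainDataset()) is False
assert wrapper_supports_prefetch(CachedDataset()) is True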
@@ -245,11 +245,12 @@ class NoisingDataset(torch.utils.data.Dataset):
         **kwargs
     ):
         """
-        Sets up a noising dataset which takes a src batch, generates
-        a noisy src using a noising config, and returns the
-        corresponding {noisy src, original src} batch
+        Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
+        samples based on the supplied noising configuration.
         Args:
-            src_dataset: dataset which will be used to build self.src_dataset --
+            src_dataset (~torch.utils.data.Dataset): dataset to wrap.
+                to build self.src_dataset --
                 a LanguagePairDataset with src dataset as the source dataset and
                 None as the target dataset. Should NOT have padding so that
                 src_lengths are accurately calculated by language_pair_dataset
@@ -257,26 +258,22 @@ class NoisingDataset(torch.utils.data.Dataset):
                 We use language_pair_dataset here to encapsulate the tgt_dataset
                 so we can re-use the LanguagePairDataset collater to format the
                 batches in the structure that SequenceGenerator expects.
-            src_dict: src dict
-            src_dict: src dictionary
-            seed: seed to use when generating random noise
-            noiser: a pre-initialized noiser. If this is None, a noiser will
-                be created using noising_class and kwargs.
-            noising_class: class to use when initializing noiser
-            kwargs: noising args for configuring noising to apply
-                Note that there is no equivalent argparse code for these args
-                anywhere in our top level train scripts yet. Integration is
-                still in progress. You can still, however, test out this dataset
-                functionality with the appropriate args as in the corresponding
-                unittest: test_noising_dataset.
+            src_dict (~fairseq.data.Dictionary): source dictionary
+            seed (int): seed to use when generating random noise
+            noiser (WordNoising): a pre-initialized :class:`WordNoising`
+                instance. If this is None, a new instance will be created using
+                *noising_class* and *kwargs*.
+            noising_class (class, optional): class to use to initialize a
+                default :class:`WordNoising` instance.
+            kwargs (dict, optional): arguments to initialize the default
+                :class:`WordNoising` instance given by *noiser*.
         """
         self.src_dataset = src_dataset
         self.src_dict = src_dict
-        self.seed = seed
         self.noiser = noiser if noiser is not None else noising_class(
             dictionary=src_dict, **kwargs,
         )
+        self.seed = seed
     def __getitem__(self, index):
         """
...
@@ -13,13 +13,16 @@ from . import FairseqDataset
 class RoundRobinZipDatasets(FairseqDataset):
-    """Zip multiple FairseqDatasets together, repeating shorter datasets in a
-    round-robin fashion to match the length of the longest one.
+    """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
+    Shorter datasets are repeated in a round-robin fashion to match the length
+    of the longest one.
     Args:
-        datasets: a dictionary of FairseqDatasets
-        eval_key: an optional key used at evaluation time that causes this
-            instance to pass-through batches from `datasets[eval_key]`.
+        datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
+            :class:`~fairseq.data.FairseqDataset` instances.
+        eval_key (str, optional): a key used at evaluation time that causes
+            this instance to pass-through batches from *datasets[eval_key]*.
     """
     def __init__(self, datasets, eval_key=None):
@@ -107,3 +110,14 @@ class RoundRobinZipDatasets(FairseqDataset):
             dataset.valid_size(self._map_index(key, index), max_positions[key])
             for key, dataset in self.datasets.items()
         )
+    @property
+    def supports_prefetch(self):
+        return all(
+            getattr(dataset, 'supports_prefetch', False)
+            for dataset in self.datasets.values()
+        )
+    def prefetch(self, indices):
+        for key, dataset in self.datasets.items():
+            dataset.prefetch([self._map_index(key, index) for index in indices])
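
The new `prefetch` above forwards remapped indices to each wrapped dataset via `self._map_index`, which is not shown in this diff. A plausible sketch of the wrap-around behaviour the docstring describes (the helper name and modulo logic here are assumptions, not the verbatim fairseq implementation):

# Round-robin index mapping sketch: an index into the longest dataset
# wraps around the shorter ones (illustrative only).
def map_index(dataset_len, index):
    return index % dataset_len

lengths = {'en-de': 5, 'en-fr': 2}
longest = max(lengths.values())
for index in range(longest):
    print(index, {key: map_index(n, index) for key, n in lengths.items()})
# en-fr cycles through its 2 examples (0, 1, 0, 1, 0) while en-de runs 0..4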
@@ -14,32 +14,32 @@ from . import FairseqDataset
 class TokenBlockDataset(FairseqDataset):
-    """Break a 1d tensor of tokens into blocks.
-    The blocks are fetched from the original tensor so no additional memory is allocated.
+    """Break a Dataset of tokens into blocks.
     Args:
-        tokens: 1d tensor of tokens to break into blocks
-        sizes: sentence lengths (required for 'complete' and 'eos')
-        block_size: maximum block size (ignored in 'eos' break mode)
-        break_mode: Mode used for breaking tokens. Values can be one of:
+        dataset (~torch.utils.data.Dataset): dataset to break into blocks
+        sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
+        block_size (int): maximum block size (ignored in 'eos' break mode)
+        break_mode (str, optional): Mode used for breaking tokens. Values can
+            be one of:
             - 'none': break tokens into equally sized blocks (up to block_size)
            - 'complete': break tokens into blocks (up to block_size) such that
              blocks contains complete sentences, although block_size may be
              exceeded if some sentences exceed block_size
            - 'eos': each block contains one sentence (block_size is ignored)
-        include_targets: return next tokens as targets
+        include_targets (bool, optional): return next tokens as targets
+            (default: False).
     """
-    def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False):
+    def __init__(self, dataset, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
         super().__init__()
-        self.dataset = ds
+        self.dataset = dataset
         self.pad = pad
         self.eos = eos
         self.include_targets = include_targets
         self.slice_indices = []
-        self.cache_index = {}
-        sizes = ds.sizes
+        assert len(dataset) == len(sizes)
         if break_mode is None or break_mode == 'none':
             total_size = sum(sizes)
@@ -77,44 +77,66 @@ class TokenBlockDataset(FairseqDataset):
         self.sizes = np.array([e - s for s, e in self.slice_indices])
-    def __getitem__(self, index):
-        s, e = self.cache_index[index]
-        item = torch.from_numpy(self.cache[s:e]).long()
+        # build index mapping block indices to the underlying dataset indices
+        self.block_to_dataset_index = []
+        ds_idx, ds_remaining = -1, 0
+        for to_consume in self.sizes:
+            if ds_remaining == 0:
+                ds_idx += 1
+                ds_remaining = sizes[ds_idx]
+            start_ds_idx = ds_idx
+            start_offset = sizes[ds_idx] - ds_remaining
+            while to_consume > ds_remaining:
+                to_consume -= ds_remaining
+                ds_idx += 1
+                ds_remaining = sizes[ds_idx]
+            ds_remaining -= to_consume
+            self.block_to_dataset_index.append((
+                start_ds_idx,  # starting index in dataset
+                start_offset,  # starting offset within starting index
+                ds_idx,  # ending index in dataset
+            ))
+        assert ds_remaining == 0
+        assert ds_idx == len(self.dataset) - 1
+    def __getitem__(self, index):
+        start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
+        buffer = torch.cat([
+            self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)
+        ])
+        slice_s, slice_e = self.slice_indices[index]
+        length = slice_e - slice_s
+        s, e = start_offset, start_offset + length
+        item = buffer[s:e]
         if self.include_targets:
-            # target is the sentence, for source, rotate item one token to the left (would start with eos)
-            # past target is rotated to the left by 2 (padded if its first)
+            # *target* is the original sentence (=item)
+            # *source* is rotated left by 1 (maybe left-padded with eos)
+            # *past_target* is rotated left by 2 (left-padded as needed)
             if s == 0:
-                source = np.concatenate([[self.eos], self.cache[0:e - 1]])
-                past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]])
+                source = torch.cat([item.new([self.eos]), buffer[0:e - 1]])
+                past_target = torch.cat([item.new([self.pad, self.eos]), buffer[0:e - 2]])
             else:
-                source = self.cache[s - 1: e - 1]
+                source = buffer[s - 1:e - 1]
                 if s == 1:
-                    past_target = np.concatenate([[self.eos], self.cache[0:e - 2]])
+                    past_target = torch.cat([item.new([self.eos]), buffer[0:e - 2]])
                 else:
-                    past_target = self.cache[s - 2:e - 2]
-            return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long()
+                    past_target = buffer[s - 2:e - 2]
+            return source, item, past_target
         return item
     def __len__(self):
         return len(self.slice_indices)
-    def prefetch(self, indices):
-        indices.sort()
-        total_size = 0
-        for idx in indices:
-            s, e = self.slice_indices[idx]
-            total_size += e - s
-        self.cache = np.empty(total_size, dtype=np.int32)
-        start = 0
-        for idx in indices:
-            s, e = self.slice_indices[idx]
-            self.dataset.read_into(s, self.cache[start:start + e - s])
-            self.cache_index[idx] = (start, start + e - s)
-            start += e - s
     @property
     def supports_prefetch(self):
-        return True
+        return getattr(self.dataset, 'supports_prefetch', False)
+    def prefetch(self, indices):
+        self.dataset.prefetch({
+            ds_idx
+            for index in indices
+            for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
+            for ds_idx in range(start_ds_idx, end_ds_idx + 1)
+        })
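
The `block_to_dataset_index` loop added above walks the per-sentence `sizes` once and records, for every block, the first underlying item it touches, the offset into that item, and the last item it touches; `__getitem__` then concatenates only that range. A standalone trace of the same logic on toy numbers (the sizes and block lengths are made up for illustration):

# Trace of the block -> dataset-index mapping on toy data: three sentences
# of 3, 2 and 4 tokens, cut into blocks of 4, 4 and 1 tokens ('none' mode).
sizes = [3, 2, 4]        # tokens per underlying dataset item
block_sizes = [4, 4, 1]  # tokens per block, as self.sizes would hold

block_to_dataset_index = []
ds_idx, ds_remaining = -1, 0
for to_consume in block_sizes:
    if ds_remaining == 0:
        ds_idx += 1
        ds_remaining = sizes[ds_idx]
    start_ds_idx = ds_idx
    start_offset = sizes[ds_idx] - ds_remaining
    while to_consume > ds_remaining:
        to_consume -= ds_remaining
        ds_idx += 1
        ds_remaining = sizes[ds_idx]
    ds_remaining -= to_consume
    block_to_dataset_index.append((start_ds_idx, start_offset, ds_idx))

print(block_to_dataset_index)
# [(0, 0, 1), (1, 1, 2), (2, 3, 2)]: block 0 spans items 0-1, block 1 starts
# at offset 1 of item 1 and ends in item 2, block 2 is the final token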
@@ -11,7 +11,7 @@ from . import FairseqDataset
 class TransformEosDataset(FairseqDataset):
-    """A dataset wrapper that appends/prepends/strips EOS.
+    """A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
     Note that the transformation is applied in :func:`collater`.
@@ -111,7 +111,7 @@ class TransformEosDataset(FairseqDataset):
     @property
     def supports_prefetch(self):
-        return self.dataset.supports_prefetch()
+        return getattr(self.dataset, 'supports_prefetch', False)
     def prefetch(self, indices):
         return self.dataset.prefetch(indices)
@@ -6,7 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 from collections import namedtuple
+import os
 import pickle
+import subprocess
 import torch
 from torch import nn
@@ -42,6 +44,38 @@ else:
     import torch.distributed as dist_no_c10d
+
+
+def infer_init_method(args):
+    if args.distributed_init_method is not None:
+        return
+
+    # support torch.distributed.launch
+    if all(key in os.environ for key in [
+        'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK'
+    ]):
+        args.distributed_init_method = 'tcp://{addr}:{port}'.format(
+            addr=os.environ['MASTER_ADDR'],
+            port=os.environ['MASTER_PORT'],
+        )
+        args.distributed_world_size = int(os.environ['WORLD_SIZE'])
+        args.distributed_rank = int(os.environ['RANK'])
+
+    # we can determine the init method automatically for Slurm
+    elif args.distributed_port > 0:
+        node_list = os.environ.get('SLURM_JOB_NODELIST')
+        if node_list is not None:
+            try:
+                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
+                args.distributed_init_method = 'tcp://{host}:{port}'.format(
+                    host=hostnames.split()[0].decode('utf-8'),
+                    port=args.distributed_port)
+                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
+                args.device_id = int(os.environ.get('SLURM_LOCALID'))
+            except subprocess.CalledProcessError as e:  # scontrol failed
+                raise e
+            except FileNotFoundError:  # Slurm is not installed
+                pass
+
+
 def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
@@ -158,7 +192,7 @@ def all_gather_list(data, group=None, max_size=16384):
                     pickle.loads(bytes(out_buffer[2:size+2].tolist()))
                 )
         return result
-    except pickle.UnpicklingError as e:
+    except pickle.UnpicklingError:
        raise Exception(
            'Unable to unpickle data from other workers. all_gather_list requires all '
            'workers to enter the function together, so this error usually indicates '
@@ -167,4 +201,3 @@ def all_gather_list(data, group=None, max_size=16384):
            'in your training script that can cause one worker to finish an epoch '
            'while other workers are still iterating over their portions of the data.'
        )
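
`infer_init_method` above covers two launch paths: `torch.distributed.launch`, which exports MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK for each worker, and Slurm via `scontrol`. A small sketch of what the environment-variable branch computes, replayed on a plain `Namespace` with made-up address and port values:

# Replays the torch.distributed.launch branch of infer_init_method on a
# fake args object; the environment values below are hypothetical examples.
import os
from argparse import Namespace

os.environ.update({
    'MASTER_ADDR': '127.0.0.1',  # example address
    'MASTER_PORT': '29500',      # example port
    'WORLD_SIZE': '2',
    'RANK': '1',
})

args = Namespace(distributed_init_method=None, distributed_port=-1,
                 distributed_world_size=1, distributed_rank=0, device_id=0)

if all(key in os.environ for key in ['MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK']):
    args.distributed_init_method = 'tcp://{addr}:{port}'.format(
        addr=os.environ['MASTER_ADDR'], port=os.environ['MASTER_PORT'])
    args.distributed_world_size = int(os.environ['WORLD_SIZE'])
    args.distributed_rank = int(os.environ['RANK'])

print(args.distributed_init_method)                        # tcp://127.0.0.1:29500
print(args.distributed_world_size, args.distributed_rank)  # 2 1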
@@ -12,7 +12,7 @@ computation (e.g., AdaptiveSoftmax) and which therefore do not work with the
c10d version of DDP.
This version also supports the *accumulate_grads* feature, which allows faster
-training with --update-freq.
+training with `--update-freq`.
"""
import copy
@@ -27,18 +27,18 @@ from . import distributed_utils
 class LegacyDistributedDataParallel(nn.Module):
     """Implements distributed data parallelism at the module level.
-    A simplified version of torch.nn.parallel.DistributedDataParallel.
-    This version uses a c10d process group for communication and does
-    not broadcast buffers.
+    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
+    This version uses a c10d process group for communication and does not
+    broadcast buffers.
     Args:
-        module: module to be parallelized
-        world_size: number of parallel workers
+        module (~torch.nn.Module): module to be parallelized
+        world_size (int): number of parallel workers
         process_group (optional): the c10d process group to be used for
             distributed data all-reduction. If None, the default process group
             will be used.
-        buffer_size: number of elements to buffer before performing all-reduce
-            (default: 256M).
+        buffer_size (int, optional): number of elements to buffer before
+            performing all-reduce (default: 256M).
     """
     def __init__(self, module, world_size, process_group=None, buffer_size=2**28):
...
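
The `buffer_size` argument documented above controls how many gradient elements are accumulated before an all-reduce is issued. A minimal, hedged sketch of that bucketing idea (an illustration of the concept only, not fairseq's implementation; it needs an initialized process group to actually run):

# Bucketed gradient all-reduce sketch: flatten grads into buckets of at most
# buffer_size elements, all-reduce each bucket, then copy the results back.
import torch
import torch.distributed as dist

def buffered_all_reduce(grads, buffer_size):
    bucket, filled = [], 0

    def flush():
        if not bucket:
            return
        flat = torch.cat([g.view(-1) for g in bucket])
        dist.all_reduce(flat)                        # sum across workers
        offset = 0
        for g in bucket:                             # write reduced values back
            n = g.numel()
            g.copy_(flat[offset:offset + n].view_as(g))
            offset += n
        bucket.clear()

    for g in grads:
        if bucket and filled + g.numel() > buffer_size:
            flush()
            filled = 0
        bucket.append(g)
        filled += g.numel()
    flush()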
@@ -179,10 +179,8 @@ class FConvEncoder(FairseqEncoder):
            connections are added between layers when ``residual=1`` (which is
            the default behavior).
        dropout (float, optional): dropout to be applied before each conv layer
-        normalization_constant (float, optional): multiplies the result of the
-            residual block by sqrt(value)
-        left_pad (bool, optional): whether the input is left-padded. Default:
-            ``True``
+        left_pad (bool, optional): whether the input is left-padded
+            (default: True).
    """
    def __init__(
@@ -215,7 +213,7 @@ class FConvEncoder(FairseqEncoder):
        self.residuals = []
        layer_in_channels = [in_channels]
-        for i, (out_channels, kernel_size, residual) in enumerate(convolutions):
+        for _, (out_channels, kernel_size, residual) in enumerate(convolutions):
            if residual == 0:
                residual_dim = out_channels
            else:
...
@@ -524,6 +524,7 @@ def base_architecture(args):
    args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '')
    args.pretrained = getattr(args, 'pretrained', 'False')
@register_model_architecture('fconv_self_att', 'fconv_self_att_wp')
def fconv_self_att_wp(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
...
@@ -196,7 +196,6 @@ class LSTMEncoder(FairseqEncoder):
        if bidirectional:
            self.output_units *= 2
    def forward(self, src_tokens, src_lengths):
        if self.left_pad:
            # convert left-padding to right-padding
@@ -235,7 +234,8 @@ class LSTMEncoder(FairseqEncoder):
        if self.bidirectional:
            def combine_bidir(outs):
-                return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1)
+                out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
+                return out.view(self.num_layers, bsz, -1)
            final_hiddens = combine_bidir(final_hiddens)
            final_cells = combine_bidir(final_cells)
@@ -340,7 +340,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
        elif not self.share_input_output_embed:
            self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
        encoder_out = encoder_out_dict['encoder_out']
        encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
@@ -504,6 +503,7 @@ def base_architecture(args):
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
    args.dropout = getattr(args, 'dropout', 0.1)
...
@@ -219,7 +219,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
        # make sure all arguments are present in older models
        base_lm_architecture(args)
-        if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj == False:
+        if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
            # backward compatibility
            args.tie_adaptive_proj = True
@@ -229,15 +229,17 @@ class TransformerLanguageModel(FairseqLanguageModel):
            args.max_target_positions = args.tokens_per_sample
        if args.character_embeddings:
-            embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters),
-                                                  args.character_embedding_dim,
-                                                  args.decoder_embed_dim,
-                                                  args.char_embedder_highway_layers,
-                                                  )
+            embed_tokens = CharacterTokenEmbedder(
+                task.dictionary, eval(args.character_filters),
+                args.character_embedding_dim, args.decoder_embed_dim,
+                args.char_embedder_highway_layers,
+            )
        elif args.adaptive_input:
-            embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
-                                         args.adaptive_input_factor, args.decoder_embed_dim,
-                                         options.eval_str_list(args.adaptive_input_cutoff, type=int))
+            embed_tokens = AdaptiveInput(
+                len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
+                args.adaptive_input_factor, args.decoder_embed_dim,
+                options.eval_str_list(args.adaptive_input_cutoff, type=int),
+            )
        else:
            embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
@@ -248,7 +250,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
            assert args.decoder_input_dim == args.decoder_output_dim
-        decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
+        decoder = TransformerDecoder(
+            args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False,
+        )
        return TransformerLanguageModel(decoder)
@@ -261,8 +265,8 @@ class TransformerEncoder(FairseqEncoder):
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
-        left_pad (bool, optional): whether the input is left-padded. Default:
-            ``True``
+        left_pad (bool, optional): whether the input is left-padded
+            (default: True).
    """
    def __init__(self, args, dictionary, embed_tokens, left_pad=True):
@@ -382,10 +386,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
-        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
-            Default: ``False``
-        left_pad (bool, optional): whether the input is left-padded. Default:
-            ``False``
+        no_encoder_attn (bool, optional): whether to attend to encoder outputs
+            (default: False).
+        left_pad (bool, optional): whether the input is left-padded
+            (default: False).
+        final_norm (bool, optional): apply layer norm to the output of the
+            final decoder layer (default: True).
    """
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
@@ -634,8 +640,8 @@ class TransformerDecoderLayer(nn.Module):
    Args:
        args (argparse.Namespace): parsed command-line arguments
-        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
-            Default: ``False``
+        no_encoder_attn (bool, optional): whether to attend to encoder outputs
+            (default: False).
    """
    def __init__(self, args, no_encoder_attn=False):
...
@@ -7,7 +7,6 @@
 import torch
-import torch.nn.functional as F
 from torch import nn
 from typing import List
@@ -16,13 +15,13 @@ from typing import List
 class AdaptiveInput(nn.Module):
     def __init__(
         self,
         vocab_size: int,
         padding_idx: int,
         initial_dim: int,
         factor: float,
         output_dim: int,
         cutoff: List[int],
     ):
         super().__init__()
...
@@ -113,8 +113,9 @@ class AdaptiveSoftmax(nn.Module):
            m = nn.Sequential(
                proj,
                nn.Dropout(self.dropout),
-                nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) \
-                    if tied_emb is None else TiedLinear(tied_emb, transpose=False)
+                nn.Linear(
+                    dim, self.cutoff[i + 1] - self.cutoff[i], bias=False,
+                ) if tied_emb is None else TiedLinear(tied_emb, transpose=False),
            )
            self.tail.append(m)
...
@@ -9,7 +9,7 @@ import importlib
 import os
 from .fairseq_optimizer import FairseqOptimizer
-from .fp16_optimizer import FP16Optimizer
+from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
 OPTIMIZER_REGISTRY = {}
...
@@ -70,10 +70,11 @@ class FairseqOptimizer(object):
            group.update(optimizer_overrides)
    def backward(self, loss):
+        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()
    def multiply_grads(self, c):
-        """Multiplies grads by a constant ``c``."""
+        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                p.grad.data.mul_(c)
...
@@ -45,6 +45,164 @@ class DynamicLossScaler(object):
        return False
+
+
+class FP16Optimizer(optim.FairseqOptimizer):
+    """
+    Wrap an *optimizer* to support FP16 (mixed precision) training.
+    """
+
+    def __init__(self, args, params, fp32_optimizer, fp32_params):
+        super().__init__(args, params)
+        self.fp32_optimizer = fp32_optimizer
+        self.fp32_params = fp32_params
+
+        if getattr(args, 'fp16_scale_window', None) is None:
+            if len(args.update_freq) > 1:
+                raise ValueError(
+                    '--fp16-scale-window must be given explicitly when using a '
+                    'custom --update-freq schedule'
+                )
+            scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
+        else:
+            scale_window = args.fp16_scale_window
+
+        self.scaler = DynamicLossScaler(
+            init_scale=args.fp16_init_scale,
+            scale_window=scale_window,
+            tolerance=args.fp16_scale_tolerance,
+        )
+
+    @classmethod
+    def build_optimizer(cls, args, params):
+        """
+        Args:
+            args (argparse.Namespace): fairseq args
+            params (iterable): iterable of parameters to optimize
+        """
+        # create FP32 copy of parameters and grads
+        total_param_size = sum(p.data.numel() for p in params)
+        fp32_params = params[0].new(0).float().new(total_param_size)
+        offset = 0
+        for p in params:
+            numel = p.data.numel()
+            fp32_params[offset:offset+numel].copy_(p.data.view(-1))
+            offset += numel
+        fp32_params = torch.nn.Parameter(fp32_params)
+        fp32_params.grad = fp32_params.data.new(total_param_size)
+
+        fp32_optimizer = optim.build_optimizer(args, [fp32_params])
+        return cls(args, params, fp32_optimizer, fp32_params)
+
+    @property
+    def optimizer(self):
+        return self.fp32_optimizer.optimizer
+
+    @property
+    def optimizer_config(self):
+        return self.fp32_optimizer.optimizer_config
+
+    def get_lr(self):
+        return self.fp32_optimizer.get_lr()
+
+    def set_lr(self, lr):
+        self.fp32_optimizer.set_lr(lr)
+
+    def state_dict(self):
+        """Return the optimizer's state dict."""
+        state_dict = self.fp32_optimizer.state_dict()
+        state_dict['loss_scale'] = self.scaler.loss_scale
+        return state_dict
+
+    def load_state_dict(self, state_dict, optimizer_overrides=None):
+        """Load an optimizer state dict.
+
+        In general we should prefer the configuration of the existing optimizer
+        instance (e.g., learning rate) over that found in the state_dict. This
+        allows us to resume training from a checkpoint using a new set of
+        optimizer args.
+        """
+        if 'loss_scale' in state_dict:
+            self.scaler.loss_scale = state_dict['loss_scale']
+        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
+
+    def backward(self, loss):
+        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
+
+        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
+        function additionally dynamically scales the loss to avoid gradient
+        underflow.
+        """
+        loss = loss * self.scaler.loss_scale
+        loss.backward()
+        self._needs_sync = True
+
+    def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
+        if self._needs_sync:
+            # copy FP16 grads to FP32
+            offset = 0
+            for p in self.params:
+                if not p.requires_grad:
+                    continue
+                grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
+                numel = grad_data.numel()
+                self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
+                offset += numel
+
+            # correct for dynamic loss scaler
+            self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
+
+            self._needs_sync = False
+
+    def multiply_grads(self, c):
+        """Multiplies grads by a constant ``c``."""
+        if self._needs_sync:
+            self._sync_fp16_grads_to_fp32(c)
+        else:
+            self.fp32_params.grad.data.mul_(c)
+
+    def clip_grad_norm(self, max_norm):
+        """Clips gradient norm and updates dynamic loss scaler."""
+        self._sync_fp16_grads_to_fp32()
+        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
+
+        # detect overflow and adjust loss scale
+        overflow = DynamicLossScaler.has_overflow(grad_norm)
+        self.scaler.update_scale(overflow)
+        if overflow:
+            if self.scaler.loss_scale <= self.args.min_loss_scale:
+                # Use FloatingPointError as an uncommon error that parent
+                # functions can safely catch to stop training.
+                raise FloatingPointError((
+                    'Minimum loss scale reached ({}). Your loss is probably exploding. '
+                    'Try lowering the learning rate, using gradient clipping or '
+                    'increasing the batch size.'
+                ).format(self.args.min_loss_scale))
+            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
+        return grad_norm
+
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        self._sync_fp16_grads_to_fp32()
+        self.fp32_optimizer.step(closure)
+
+        # copy FP32 params back into FP16 model
+        offset = 0
+        for p in self.params:
+            if not p.requires_grad:
+                continue
+            numel = p.data.numel()
+            p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
+            offset += numel
+
+    def zero_grad(self):
+        """Clears the gradients of all optimized parameters."""
+        self.fp32_optimizer.zero_grad()
+        for p in self.params:
+            if p.grad is not None:
+                p.grad.detach_()
+                p.grad.zero_()
+        self._needs_sync = False
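
The class above is driven by the trainer through a small set of calls: scale-and-backward, optional gradient rescaling, clip (which also syncs the FP16 grads into the flat FP32 buffer), step, and zero_grad. A hedged sketch of one update using only the methods defined above; `args`, `model` and `batch` are assumed to exist, and the forward pass returning a scalar loss is an assumption rather than fairseq's Trainer code:

# One FP16 training step, sketched against the FP16Optimizer API above.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = FP16Optimizer.build_optimizer(args, params)

loss = model(**batch)                              # assumed scalar loss
optimizer.backward(loss)                           # scales loss, backprops fp16 grads
optimizer.multiply_grads(1.0 / batch['ntokens'])   # e.g. normalize by token count
try:
    grad_norm = optimizer.clip_grad_norm(args.clip_norm)  # syncs grads to fp32
    optimizer.step()                               # fp32 update, copied back to fp16
except OverflowError:
    pass                                           # loss scale was reduced; skip update
finally:
    optimizer.zero_grad()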
 class ConvertToFP32(object):
     """
     A wrapper around a list of params that will convert them to FP32 on the
@@ -94,14 +252,13 @@ class ConvertToFP32(object):
        raise StopIteration
-class FP16Optimizer(optim.FairseqOptimizer):
+class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
-    Args:
-        args (argparse.Namespace): fairseq args
-        params (iterable): iterable of parameters to optimize
-        optimizer (~fairseq.optim.FairseqOptimizer): optimizer to wrap
+    Compared to :class:`fairseq.optim.FP16Optimizer`, this version uses less
+    memory by copying between FP16 and FP32 parameters on-the-fly. The tradeoff
+    is reduced optimization speed, which can be mitigated with `--update-freq`.
    """
    def __init__(self, args, params, optimizer):
@@ -124,10 +281,15 @@ class FP16Optimizer(optim.FairseqOptimizer):
            tolerance=args.fp16_scale_tolerance,
        )
-    @staticmethod
-    def build_optimizer(args, params):
+    @classmethod
+    def build_optimizer(cls, args, params):
+        """
+        Args:
+            args (argparse.Namespace): fairseq args
+            params (iterable): iterable of parameters to optimize
+        """
        fp16_optimizer = optim.build_optimizer(args, params)
-        return FP16Optimizer(args, params, fp16_optimizer)
+        return cls(args, params, fp16_optimizer)
    @property
    def optimizer(self):
@@ -164,6 +326,12 @@ class FP16Optimizer(optim.FairseqOptimizer):
        ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
    def backward(self, loss):
+        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.
+
+        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
+        function additionally dynamically scales the loss to avoid gradient
+        underflow.
+        """
        loss = loss * self.scaler.loss_scale
        loss.backward()
        self._grads_are_scaled = True
@@ -178,7 +346,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
            assert multiply_grads == 1.
    def multiply_grads(self, c):
-        """Multiplies grads by a constant ``c``."""
+        """Multiplies grads by a constant *c*."""
        if self._grads_are_scaled:
            self._unscale_grads(c)
        else:
...
@@ -13,18 +13,25 @@ from . import FairseqLRScheduler, register_lr_scheduler
 @register_lr_scheduler('cosine')
 class CosineSchedule(FairseqLRScheduler):
     """Assign LR based on a cyclical schedule that follows the cosine function.
-    See https://arxiv.org/pdf/1608.03983.pdf for details
+    See https://arxiv.org/pdf/1608.03983.pdf for details.
     We also support a warmup phase where we linearly increase the learning rate
-    from some initial learning rate (`--warmup-init-lr`) until the configured
-    learning rate (`--lr`).
-    During warmup:
+    from some initial learning rate (``--warmup-init-lr``) until the configured
+    learning rate (``--lr``).
+
+    During warmup::
        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]
-    After warmup:
+
+    After warmup::
        lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
-    where
-        t_curr is current percentage of updates within the current period range
-        t_i is the current period range, which is scaled by t_mul after every iteration
+
+    where ``t_curr`` is current percentage of updates within the current period
+    range and ``t_i`` is the current period range, which is scaled by ``t_mul``
+    after every iteration.
    """
    def __init__(self, args, optimizer):
@@ -39,7 +46,7 @@ class CosineSchedule(FairseqLRScheduler):
        if args.warmup_init_lr < 0:
            args.warmup_init_lr = args.lr[0]
        self.min_lr = args.lr[0]
        self.max_lr = args.max_lr
        assert self.max_lr > self.min_lr, 'max_lr must be more than lr'
@@ -98,7 +105,7 @@ class CosineSchedule(FairseqLRScheduler):
            t_curr = curr_updates - (self.period * i)
        lr_shrink = self.lr_shrink ** i
        min_lr = self.min_lr * lr_shrink
        max_lr = self.max_lr * lr_shrink
        self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
...
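
The warmup and cosine expressions in the docstring above can be checked numerically. A small sketch with made-up hyperparameters (warmup_init_lr=1e-7, lr=5e-4, 500 warmup updates, max_lr=1e-3, one period of 5000 updates), using the same cos(pi * t_curr / t_i) form as the code:

import math

warmup_init_lr, lr, warmup_updates = 1e-7, 5e-4, 500
max_lr, period = 1e-3, 5000    # made-up values for illustration

# During warmup: linear ramp from warmup_init_lr to lr
lrs = [warmup_init_lr + (lr - warmup_init_lr) * i / warmup_updates
       for i in range(warmup_updates)]
print(lrs[0], lrs[250])        # 1e-07, ~2.5e-04

# After warmup: cosine decay from max_lr down to min_lr (= lr) over one period
min_lr = lr
for t_curr in (0, period // 2, period - 1):
    cos_lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / period))
    print(t_curr, round(cos_lr, 6))
# 0 -> 0.001 (max_lr), 2500 -> 0.00075, 4999 -> ~0.0005 (min_lr)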
@@ -13,22 +13,19 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
     """Decay the LR based on the inverse square root of the update number.
     We also support a warmup phase where we linearly increase the learning rate
-    from some initial learning rate (`--warmup-init-lr`) until the configured
-    learning rate (`--lr`). Thereafter we decay proportional to the number of
+    from some initial learning rate (``--warmup-init-lr``) until the configured
+    learning rate (``--lr``). Thereafter we decay proportional to the number of
    updates, with a decay factor set to align with the configured learning rate.
-    During warmup:
+    During warmup::
        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
        lr = lrs[update_num]
-    After warmup:
-        lr = decay_factor / sqrt(update_num)
-    where
-        decay_factor = args.lr * sqrt(args.warmup_updates)
+    After warmup::
+        decay_factor = args.lr * sqrt(args.warmup_updates)
+        lr = decay_factor / sqrt(update_num)
    """
    def __init__(self, args, optimizer):
...
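
After warmup the schedule above decays as decay_factor / sqrt(update_num), with the decay factor chosen so the curve passes through the configured lr at the end of warmup; every 4x increase in updates therefore halves the learning rate. A quick numeric check with made-up values:

import math

lr, warmup_updates = 5e-4, 4000          # made-up values for illustration
decay_factor = lr * math.sqrt(warmup_updates)

for update_num in (4000, 16000, 64000):
    print(update_num, decay_factor / math.sqrt(update_num))
# 4000 -> 5e-04, 16000 -> 2.5e-04, 64000 -> 1.25e-04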
@@ -14,8 +14,7 @@ from . import FairseqLRScheduler, register_lr_scheduler
 class TriangularSchedule(FairseqLRScheduler):
     """Assign LR based on a triangular cyclical schedule.
-    See https://arxiv.org/pdf/1506.01186.pdf for details
+    See https://arxiv.org/pdf/1506.01186.pdf for details.
    """
    def __init__(self, args, optimizer):
...
@@ -107,6 +107,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
        args.update_freq = eval_str_list(args.update_freq, type=int)
    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences
+    if getattr(args, 'memory_efficient_fp16', False):
+        args.fp16 = True
    # Apply architecture configuration.
    if hasattr(args, 'arch'):
@@ -128,7 +130,10 @@ def get_parser(desc, default_task='translation'):
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
+    parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
+    parser.add_argument('--memory-efficient-fp16', action='store_true',
+                        help='use a memory-efficient version of FP16 training; implies --fp16')
    parser.add_argument('--fp16-init-scale', default=2**7, type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window', type=int,
@@ -147,6 +152,8 @@ def get_parser(desc, default_task='translation'):
def add_dataset_args(parser, train=False, gen=False):
    group = parser.add_argument_group('Dataset and data loading')
    # fmt: off
+    group.add_argument('--num-workers', default=0, type=int, metavar='N',
+                       help='how many subprocesses to use for data loading')
    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                       help='ignore too long or too short lines in valid and test set')
    group.add_argument('--max-tokens', type=int, metavar='N',
@@ -178,7 +185,7 @@ def add_distributed_training_args(parser):
    group = parser.add_argument_group('Distributed training')
    # fmt: off
    group.add_argument('--distributed-world-size', type=int, metavar='N',
-                       default=torch.cuda.device_count(),
+                       default=max(1, torch.cuda.device_count()),
                       help='total number of GPUs across all nodes (default: all visible GPUs)')
    group.add_argument('--distributed-rank', default=0, type=int,
                       help='rank of the current worker')
@@ -189,7 +196,7 @@ def add_distributed_training_args(parser):
                            'establish initial connetion')
    group.add_argument('--distributed-port', default=-1, type=int,
                       help='port number (not required if using --distributed-init-method)')
-    group.add_argument('--device-id', default=0, type=int,
+    group.add_argument('--device-id', '--local_rank', default=0, type=int,
                       help='which GPU to use (usually configured automatically)')
    group.add_argument('--ddp-backend', default='c10d', type=str,
                       choices=['c10d', 'no_c10d'],
@@ -197,8 +204,8 @@ def add_distributed_training_args(parser):
    group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
                       help='bucket size for reduction')
    group.add_argument('--fix-batches-to-gpus', action='store_true',
-                       help='Don\'t shuffle batches between GPUs, this reduces overall '
-                            'randomness and may affect precision but avoids the cost of'
+                       help='don\'t shuffle batches between GPUs; this reduces overall '
+                            'randomness and may affect precision but avoids the cost of '
                            're-reading the data')
    # fmt: on
    return group
@@ -263,7 +270,9 @@ def add_checkpoint_args(parser):
    group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
                       help='save a checkpoint (and validate) every N updates')
    group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
-                       help='keep last N checkpoints saved with --save-interval-updates')
+                       help='keep the last N checkpoints saved with --save-interval-updates')
+    group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
+                       help='keep last N epoch checkpoints')
    group.add_argument('--no-save', action='store_true',
                       help='don\'t save models or checkpoints')
    group.add_argument('--no-epoch-checkpoints', action='store_true',
@@ -280,11 +289,11 @@ def add_common_eval_args(group):
                       help='path(s) to model file(s), colon separated')
    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                       help='remove BPE tokens before scoring')
-    group.add_argument('--cpu', action='store_true', help='generate on CPU')
    group.add_argument('--quiet', action='store_true',
                       help='only print final scores')
    group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
-                       help='a dictionary used to override model args at generation that were used during model training')
+                       help='a dictionary used to override model args at generation '
+                            'that were used during model training')
    # fmt: on
...
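
Registering '--local_rank' as a second option string for --device-id (see the change above) lets the parser accept the per-process flag that torch.distributed.launch passes to each worker, while keeping args.device_id as the destination. A tiny argparse sketch of that behaviour:

import argparse

parser = argparse.ArgumentParser()
# Same pattern as the --device-id change in this commit: two option strings, one dest.
parser.add_argument('--device-id', '--local_rank', default=0, type=int)

print(parser.parse_args(['--device-id', '1']).device_id)   # 1
print(parser.parse_args(['--local_rank', '3']).device_id)  # 3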