Commit 7633129b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
......@@ -9,7 +9,6 @@ import numpy as np
import torch
from . import data_utils, FairseqDataset
from typing import List
def collate(samples, pad_idx, eos_idx):
......@@ -53,8 +52,8 @@ class MonolingualDataset(FairseqDataset):
dataset (torch.utils.data.Dataset): dataset to wrap
sizes (List[int]): sentence lengths
vocab (~fairseq.data.Dictionary): vocabulary
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
shuffle (bool, optional): shuffle the elements before batching
(default: True).
"""
def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
......@@ -66,8 +65,8 @@ class MonolingualDataset(FairseqDataset):
self.add_eos_for_other_targets = add_eos_for_other_targets
self.shuffle = shuffle
assert targets is None or all(
t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'"
assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
"targets must be none or one of 'self', 'future', 'past'"
if targets is not None and len(targets) == 0:
targets = None
self.targets = targets
......@@ -185,7 +184,7 @@ class MonolingualDataset(FairseqDataset):
@property
def supports_prefetch(self):
return self.dataset.supports_prefetch
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
self.dataset.prefetch(indices)
......@@ -245,11 +245,12 @@ class NoisingDataset(torch.utils.data.Dataset):
**kwargs
):
"""
Sets up a noising dataset which takes a src batch, generates
a noisy src using a noising config, and returns the
corresponding {noisy src, original src} batch
Wrap a :class:`~torch.utils.data.Dataset` and apply noise to the
samples based on the supplied noising configuration.
Args:
src_dataset: dataset which will be used to build self.src_dataset --
src_dataset (~torch.utils.data.Dataset): dataset to wrap; used
to build self.src_dataset --
a LanguagePairDataset with src dataset as the source dataset and
None as the target dataset. Should NOT have padding so that
src_lengths are accurately calculated by language_pair_dataset
......@@ -257,26 +258,22 @@ class NoisingDataset(torch.utils.data.Dataset):
We use language_pair_dataset here to encapsulate the tgt_dataset
so we can re-use the LanguagePairDataset collater to format the
batches in the structure that SequenceGenerator expects.
src_dict: src dict
src_dict: src dictionary
seed: seed to use when generating random noise
noiser: a pre-initialized noiser. If this is None, a noiser will
be created using noising_class and kwargs.
noising_class: class to use when initializing noiser
kwargs: noising args for configuring noising to apply
Note that there is no equivalent argparse code for these args
anywhere in our top level train scripts yet. Integration is
still in progress. You can still, however, test out this dataset
functionality with the appropriate args as in the corresponding
unittest: test_noising_dataset.
src_dict (~fairseq.data.Dictionary): source dictionary
seed (int): seed to use when generating random noise
noiser (WordNoising): a pre-initialized :class:`WordNoising`
instance. If this is None, a new instance will be created using
*noising_class* and *kwargs*.
noising_class (class, optional): class to use to initialize a
default :class:`WordNoising` instance.
kwargs (dict, optional): arguments to initialize the default
:class:`WordNoising` instance given by *noiser*.
"""
self.src_dataset = src_dataset
self.src_dict = src_dict
self.seed = seed
self.noiser = noiser if noiser is not None else noising_class(
dictionary=src_dict, **kwargs,
)
self.seed = seed
def __getitem__(self, index):
"""
......
......@@ -13,13 +13,16 @@ from . import FairseqDataset
class RoundRobinZipDatasets(FairseqDataset):
"""Zip multiple FairseqDatasets together, repeating shorter datasets in a
round-robin fashion to match the length of the longest one.
"""Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.
Shorter datasets are repeated in a round-robin fashion to match the length
of the longest one.
Args:
datasets: a dictionary of FairseqDatasets
eval_key: an optional key used at evaluation time that causes this
instance to pass-through batches from `datasets[eval_key]`.
datasets (Dict[~fairseq.data.FairseqDataset]): a dictionary of
:class:`~fairseq.data.FairseqDataset` instances.
eval_key (str, optional): a key used at evaluation time that causes
this instance to pass-through batches from *datasets[eval_key]*.
"""
def __init__(self, datasets, eval_key=None):
......@@ -107,3 +110,14 @@ class RoundRobinZipDatasets(FairseqDataset):
dataset.valid_size(self._map_index(key, index), max_positions[key])
for key, dataset in self.datasets.items()
)
@property
def supports_prefetch(self):
return all(
getattr(dataset, 'supports_prefetch', False)
for dataset in self.datasets.values()
)
def prefetch(self, indices):
for key, dataset in self.datasets.items():
dataset.prefetch([self._map_index(key, index) for index in indices])
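For illustration, a minimal sketch of the round-robin index mapping described in the docstring above (pure Python with assumed dataset names and lengths; this is not the fairseq implementation):

```python
# Shorter datasets are revisited modulo their own length so that every
# dataset yields as many items as the longest one (toy lengths assumed).
lengths = {'en-de': 5, 'en-fr': 3}
longest = max(lengths.values())

def map_index(key, index):
    return index % lengths[key]

assert [map_index('en-fr', i) for i in range(longest)] == [0, 1, 2, 0, 1]
```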
......@@ -14,32 +14,32 @@ from . import FairseqDataset
class TokenBlockDataset(FairseqDataset):
"""Break a 1d tensor of tokens into blocks.
The blocks are fetched from the original tensor so no additional memory is allocated.
"""Break a Dataset of tokens into blocks.
Args:
tokens: 1d tensor of tokens to break into blocks
sizes: sentence lengths (required for 'complete' and 'eos')
block_size: maximum block size (ignored in 'eos' break mode)
break_mode: Mode used for breaking tokens. Values can be one of:
dataset (~torch.utils.data.Dataset): dataset to break into blocks
sizes (List[int]): sentence lengths (required for 'complete' and 'eos')
block_size (int): maximum block size (ignored in 'eos' break mode)
break_mode (str, optional): Mode used for breaking tokens. Values can
be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
blocks contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size
- 'eos': each block contains one sentence (block_size is ignored)
include_targets: return next tokens as targets
include_targets (bool, optional): return next tokens as targets
(default: False).
"""
def __init__(self, ds, block_size, pad, eos, break_mode=None, include_targets=False):
def __init__(self, dataset, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
super().__init__()
self.dataset = ds
self.dataset = dataset
self.pad = pad
self.eos = eos
self.include_targets = include_targets
self.slice_indices = []
self.cache_index = {}
sizes = ds.sizes
assert len(dataset) == len(sizes)
if break_mode is None or break_mode == 'none':
total_size = sum(sizes)
......@@ -77,44 +77,66 @@ class TokenBlockDataset(FairseqDataset):
self.sizes = np.array([e - s for s, e in self.slice_indices])
def __getitem__(self, index):
s, e = self.cache_index[index]
# build index mapping block indices to the underlying dataset indices
self.block_to_dataset_index = []
ds_idx, ds_remaining = -1, 0
for to_consume in self.sizes:
if ds_remaining == 0:
ds_idx += 1
ds_remaining = sizes[ds_idx]
start_ds_idx = ds_idx
start_offset = sizes[ds_idx] - ds_remaining
while to_consume > ds_remaining:
to_consume -= ds_remaining
ds_idx += 1
ds_remaining = sizes[ds_idx]
ds_remaining -= to_consume
self.block_to_dataset_index.append((
start_ds_idx, # starting index in dataset
start_offset, # starting offset within starting index
ds_idx, # ending index in dataset
))
assert ds_remaining == 0
assert ds_idx == len(self.dataset) - 1
item = torch.from_numpy(self.cache[s:e]).long()
def __getitem__(self, index):
start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
buffer = torch.cat([
self.dataset[idx] for idx in range(start_ds_idx, end_ds_idx + 1)
])
slice_s, slice_e = self.slice_indices[index]
length = slice_e - slice_s
s, e = start_offset, start_offset + length
item = buffer[s:e]
if self.include_targets:
# target is the sentence, for source, rotate item one token to the left (would start with eos)
# past target is rotated to the left by 2 (padded if its first)
# *target* is the original sentence (=item)
# *source* is rotated left by 1 (maybe left-padded with eos)
# *past_target* is rotated left by 2 (left-padded as needed)
if s == 0:
source = np.concatenate([[self.eos], self.cache[0:e - 1]])
past_target = np.concatenate([[self.pad, self.eos], self.cache[0:e - 2]])
source = torch.cat([item.new([self.eos]), buffer[0:e - 1]])
past_target = torch.cat([item.new([self.pad, self.eos]), buffer[0:e - 2]])
else:
source = self.cache[s - 1: e - 1]
source = buffer[s - 1:e - 1]
if s == 1:
past_target = np.concatenate([[self.eos], self.cache[0:e - 2]])
past_target = torch.cat([item.new([self.eos]), buffer[0:e - 2]])
else:
past_target = self.cache[s - 2:e - 2]
past_target = buffer[s - 2:e - 2]
return torch.from_numpy(source).long(), item, torch.from_numpy(past_target).long()
return source, item, past_target
return item
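As a quick illustration of the target construction above, consider a toy block that starts at the beginning of the buffer (s == 0), with assumed token ids (eos=2, pad=1); this sketch is not part of the diff:

```python
import torch

eos, pad = 2, 1
buffer = torch.tensor([10, 11, 12])  # toy block with s == 0, e == 3
item = buffer[0:3]                   # *target* is the block itself
source = torch.cat([item.new([eos]), buffer[0:2]])            # rotate left by 1
past_target = torch.cat([item.new([pad, eos]), buffer[0:1]])  # rotate left by 2
assert source.tolist() == [2, 10, 11]
assert past_target.tolist() == [1, 2, 10]
```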
def __len__(self):
return len(self.slice_indices)
def prefetch(self, indices):
indices.sort()
total_size = 0
for idx in indices:
s, e = self.slice_indices[idx]
total_size += e - s
self.cache = np.empty(total_size, dtype=np.int32)
start = 0
for idx in indices:
s, e = self.slice_indices[idx]
self.dataset.read_into(s, self.cache[start:start + e - s])
self.cache_index[idx] = (start, start + e - s)
start += e - s
@property
def supports_prefetch(self):
return True
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
self.dataset.prefetch({
ds_idx
for index in indices
for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
for ds_idx in range(start_ds_idx, end_ds_idx + 1)
})
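A small sketch of the 'none' break mode described in the docstring above, using assumed toy sentence lengths (plain Python, independent of fairseq):

```python
# With break_mode='none', tokens are concatenated and cut into blocks of at
# most block_size, regardless of sentence boundaries.
sizes = [3, 2, 5]   # toy sentence lengths
block_size = 4
total = sum(sizes)  # 10 tokens -> blocks [0, 4), [4, 8), [8, 10)
slice_indices = [(s, min(s + block_size, total)) for s in range(0, total, block_size)]
assert slice_indices == [(0, 4), (4, 8), (8, 10)]
```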
......@@ -11,7 +11,7 @@ from . import FairseqDataset
class TransformEosDataset(FairseqDataset):
"""A dataset wrapper that appends/prepends/strips EOS.
"""A :class:`~fairseq.data.FairseqDataset` wrapper that appends/prepends/strips EOS.
Note that the transformation is applied in :func:`collater`.
......@@ -111,7 +111,7 @@ class TransformEosDataset(FairseqDataset):
@property
def supports_prefetch(self):
return self.dataset.supports_prefetch()
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
......@@ -6,7 +6,9 @@
# can be found in the PATENTS file in the same directory.
from collections import namedtuple
import os
import pickle
import subprocess
import torch
from torch import nn
......@@ -42,6 +44,38 @@ else:
import torch.distributed as dist_no_c10d
def infer_init_method(args):
if args.distributed_init_method is not None:
return
# support torch.distributed.launch
if all(key in os.environ for key in [
'MASTER_ADDR', 'MASTER_PORT', 'WORLD_SIZE', 'RANK'
]):
args.distributed_init_method = 'tcp://{addr}:{port}'.format(
addr=os.environ['MASTER_ADDR'],
port=os.environ['MASTER_PORT'],
)
args.distributed_world_size = int(os.environ['WORLD_SIZE'])
args.distributed_rank = int(os.environ['RANK'])
# we can determine the init method automatically for Slurm
elif args.distributed_port > 0:
node_list = os.environ.get('SLURM_JOB_NODELIST')
if node_list is not None:
try:
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
args.distributed_init_method = 'tcp://{host}:{port}'.format(
host=hostnames.split()[0].decode('utf-8'),
port=args.distributed_port)
args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
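For example, when a job is launched with torch.distributed.launch, the environment variables checked above resolve to a TCP init URL; a minimal sketch with assumed values (not taken from the diff):

```python
import os

# Assumed values, as torch.distributed.launch would set them.
os.environ.update({'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500',
                   'WORLD_SIZE': '2', 'RANK': '0'})
init_method = 'tcp://{addr}:{port}'.format(
    addr=os.environ['MASTER_ADDR'], port=os.environ['MASTER_PORT'])
assert init_method == 'tcp://127.0.0.1:29500'
```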
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
......@@ -158,7 +192,7 @@ def all_gather_list(data, group=None, max_size=16384):
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
return result
except pickle.UnpicklingError as e:
except pickle.UnpicklingError:
raise Exception(
'Unable to unpickle data from other workers. all_gather_list requires all '
'workers to enter the function together, so this error usually indicates '
......@@ -167,4 +201,3 @@ def all_gather_list(data, group=None, max_size=16384):
'in your training script that can cause one worker to finish an epoch '
'while other workers are still iterating over their portions of the data.'
)
......@@ -12,7 +12,7 @@ computation (e.g., AdaptiveSoftmax) and which therefore do not work with the
c10d version of DDP.
This version also supports the *accumulate_grads* feature, which allows faster
training with --update-freq.
training with `--update-freq`.
"""
import copy
......@@ -27,18 +27,18 @@ from . import distributed_utils
class LegacyDistributedDataParallel(nn.Module):
"""Implements distributed data parallelism at the module level.
A simplified version of torch.nn.parallel.DistributedDataParallel.
This version uses a c10d process group for communication and does
not broadcast buffers.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and does not
broadcast buffers.
Args:
module: module to be parallelized
world_size: number of parallel workers
module (~torch.nn.Module): module to be parallelized
world_size (int): number of parallel workers
process_group (optional): the c10d process group to be used for
distributed data all-reduction. If None, the default process group
will be used.
buffer_size: number of elements to buffer before performing all-reduce
(default: 256M).
buffer_size (int, optional): number of elements to buffer before
performing all-reduce (default: 256M).
"""
def __init__(self, module, world_size, process_group=None, buffer_size=2**28):
......
......@@ -179,10 +179,8 @@ class FConvEncoder(FairseqEncoder):
connections are added between layers when ``residual=1`` (which is
the default behavior).
dropout (float, optional): dropout to be applied before each conv layer
normalization_constant (float, optional): multiplies the result of the
residual block by sqrt(value)
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
left_pad (bool, optional): whether the input is left-padded
(default: True).
"""
def __init__(
......@@ -215,7 +213,7 @@ class FConvEncoder(FairseqEncoder):
self.residuals = []
layer_in_channels = [in_channels]
for i, (out_channels, kernel_size, residual) in enumerate(convolutions):
for _, (out_channels, kernel_size, residual) in enumerate(convolutions):
if residual == 0:
residual_dim = out_channels
else:
......
......@@ -524,6 +524,7 @@ def base_architecture(args):
args.pretrained_checkpoint = getattr(args, 'pretrained_checkpoint', '')
args.pretrained = getattr(args, 'pretrained', 'False')
@register_model_architecture('fconv_self_att', 'fconv_self_att_wp')
def fconv_self_att_wp(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
......
......@@ -196,7 +196,6 @@ class LSTMEncoder(FairseqEncoder):
if bidirectional:
self.output_units *= 2
def forward(self, src_tokens, src_lengths):
if self.left_pad:
# convert left-padding to right-padding
......@@ -235,7 +234,8 @@ class LSTMEncoder(FairseqEncoder):
if self.bidirectional:
def combine_bidir(outs):
return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1)
out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
return out.view(self.num_layers, bsz, -1)
final_hiddens = combine_bidir(final_hiddens)
final_cells = combine_bidir(final_cells)
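A minimal shape check of the combine_bidir reshape above, with assumed toy dimensions:

```python
import torch

num_layers, bsz, hidden = 2, 3, 4
# LSTM hidden states come out as (num_layers * num_directions, bsz, hidden).
outs = torch.zeros(num_layers * 2, bsz, hidden)
out = outs.view(num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
combined = out.view(num_layers, bsz, -1)
# Forward and backward states end up concatenated on the feature dimension.
assert combined.shape == (num_layers, bsz, 2 * hidden)
```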
......@@ -340,7 +340,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
elif not self.share_input_output_embed:
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
encoder_out = encoder_out_dict['encoder_out']
encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
......@@ -504,6 +503,7 @@ def base_architecture(args):
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
args.dropout = getattr(args, 'dropout', 0.1)
......
......@@ -219,7 +219,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
# make sure all arguments are present in older models
base_lm_architecture(args)
if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj == False:
if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj is False:
# backward compatibility
args.tie_adaptive_proj = True
......@@ -229,15 +229,17 @@ class TransformerLanguageModel(FairseqLanguageModel):
args.max_target_positions = args.tokens_per_sample
if args.character_embeddings:
embed_tokens = CharacterTokenEmbedder(task.dictionary, eval(args.character_filters),
args.character_embedding_dim,
args.decoder_embed_dim,
args.char_embedder_highway_layers,
)
embed_tokens = CharacterTokenEmbedder(
task.dictionary, eval(args.character_filters),
args.character_embedding_dim, args.decoder_embed_dim,
args.char_embedder_highway_layers,
)
elif args.adaptive_input:
embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
args.adaptive_input_factor, args.decoder_embed_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int))
embed_tokens = AdaptiveInput(
len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
args.adaptive_input_factor, args.decoder_embed_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int),
)
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
......@@ -248,7 +250,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
assert args.decoder_input_dim == args.decoder_output_dim
decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
decoder = TransformerDecoder(
args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False,
)
return TransformerLanguageModel(decoder)
......@@ -261,8 +265,8 @@ class TransformerEncoder(FairseqEncoder):
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
left_pad (bool, optional): whether the input is left-padded
(default: True).
"""
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
......@@ -382,10 +386,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
left_pad (bool, optional): whether the input is left-padded. Default:
``False``
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
left_pad (bool, optional): whether the input is left-padded
(default: False).
final_norm (bool, optional): apply layer norm to the output of the
final decoder layer (default: True).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
......@@ -634,8 +640,8 @@ class TransformerDecoderLayer(nn.Module):
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, no_encoder_attn=False):
......
......@@ -7,7 +7,6 @@
import torch
import torch.nn.functional as F
from torch import nn
from typing import List
......@@ -16,13 +15,13 @@ from typing import List
class AdaptiveInput(nn.Module):
def __init__(
self,
vocab_size: int,
padding_idx: int,
initial_dim: int,
factor: float,
output_dim: int,
cutoff: List[int],
self,
vocab_size: int,
padding_idx: int,
initial_dim: int,
factor: float,
output_dim: int,
cutoff: List[int],
):
super().__init__()
......
......@@ -113,8 +113,9 @@ class AdaptiveSoftmax(nn.Module):
m = nn.Sequential(
proj,
nn.Dropout(self.dropout),
nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) \
if tied_emb is None else TiedLinear(tied_emb, transpose=False)
nn.Linear(
dim, self.cutoff[i + 1] - self.cutoff[i], bias=False,
) if tied_emb is None else TiedLinear(tied_emb, transpose=False),
)
self.tail.append(m)
......
......@@ -9,7 +9,7 @@ import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
from .fp16_optimizer import FP16Optimizer
from .fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
OPTIMIZER_REGISTRY = {}
......
......@@ -70,10 +70,11 @@ class FairseqOptimizer(object):
group.update(optimizer_overrides)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
loss.backward()
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
"""Multiplies grads by a constant *c*."""
for p in self.params:
if p.grad is not None:
p.grad.data.mul_(c)
......
......@@ -45,6 +45,164 @@ class DynamicLossScaler(object):
return False
class FP16Optimizer(optim.FairseqOptimizer):
"""
Wrap an *optimizer* to support FP16 (mixed precision) training.
"""
def __init__(self, args, params, fp32_optimizer, fp32_params):
super().__init__(args, params)
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
if getattr(args, 'fp16_scale_window', None) is None:
if len(args.update_freq) > 1:
raise ValueError(
'--fp16-scale-window must be given explicitly when using a '
'custom --update-freq schedule'
)
scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
else:
scale_window = args.fp16_scale_window
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
)
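A worked example of the default --fp16-scale-window heuristic above, with assumed world size and update frequency:

```python
# scale_window = 2**14 / distributed_world_size / update_freq[0]
distributed_world_size, update_freq = 8, [2]
scale_window = 2 ** 14 / distributed_world_size / update_freq[0]
assert scale_window == 1024.0
```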
@classmethod
def build_optimizer(cls, args, params):
"""
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
# create FP32 copy of parameters and grads
total_param_size = sum(p.data.numel() for p in params)
fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.data.new(total_param_size)
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
return cls(args, params, fp32_optimizer, fp32_params)
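The FP32 master copy built in build_optimizer above is a single flat buffer; a toy sketch of that flattening with assumed FP16 parameters (not the fairseq code itself):

```python
import torch

params = [torch.zeros(3, dtype=torch.half), torch.ones(2, 2, dtype=torch.half)]
total_param_size = sum(p.numel() for p in params)
fp32_params = params[0].new(0).float().new(total_param_size)  # flat FP32 buffer
offset = 0
for p in params:
    numel = p.numel()
    fp32_params[offset:offset + numel].copy_(p.view(-1))  # cast-and-copy into the flat buffer
    offset += numel
assert fp32_params.shape == (7,) and fp32_params.dtype == torch.float32
```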
@property
def optimizer(self):
return self.fp32_optimizer.optimizer
@property
def optimizer_config(self):
return self.fp32_optimizer.optimizer_config
def get_lr(self):
return self.fp32_optimizer.get_lr()
def set_lr(self, lr):
self.fp32_optimizer.set_lr(lr)
def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.fp32_optimizer.state_dict()
state_dict['loss_scale'] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
if 'loss_scale' in state_dict:
self.scaler.loss_scale = state_dict['loss_scale']
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
"""
loss = loss * self.scaler.loss_scale
loss.backward()
self._needs_sync = True
def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
if self._needs_sync:
# copy FP16 grads to FP32
offset = 0
for p in self.params:
if not p.requires_grad:
continue
grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
numel = grad_data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
offset += numel
# correct for dynamic loss scaler
self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
self._needs_sync = False
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
if self._needs_sync:
self._sync_fp16_grads_to_fp32(c)
else:
self.fp32_params.grad.data.mul_(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm and updates dynamic loss scaler."""
self._sync_fp16_grads_to_fp32()
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
self.scaler.update_scale(overflow)
if overflow:
if self.scaler.loss_scale <= self.args.min_loss_scale:
# Use FloatingPointError as an uncommon error that parent
# functions can safely catch to stop training.
raise FloatingPointError((
'Minimum loss scale reached ({}). Your loss is probably exploding. '
'Try lowering the learning rate, using gradient clipping or '
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def step(self, closure=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
self.fp32_optimizer.step(closure)
# copy FP32 params back into FP16 model
offset = 0
for p in self.params:
if not p.requires_grad:
continue
numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.fp32_optimizer.zero_grad()
for p in self.params:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
self._needs_sync = False
class ConvertToFP32(object):
"""
A wrapper around a list of params that will convert them to FP32 on the
......@@ -94,14 +252,13 @@ class ConvertToFP32(object):
raise StopIteration
class FP16Optimizer(optim.FairseqOptimizer):
class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
"""
Wrap an *optimizer* to support FP16 (mixed precision) training.
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
optimizer (~fairseq.optim.FairseqOptimizer): optimizer to wrap
Compared to :class:`fairseq.optim.FP16Optimizer`, this version uses less
memory by copying between FP16 and FP32 parameters on-the-fly. The tradeoff
is reduced optimization speed, which can be mitigated with `--update-freq`.
"""
def __init__(self, args, params, optimizer):
......@@ -124,10 +281,15 @@ class FP16Optimizer(optim.FairseqOptimizer):
tolerance=args.fp16_scale_tolerance,
)
@staticmethod
def build_optimizer(args, params):
@classmethod
def build_optimizer(cls, args, params):
"""
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
fp16_optimizer = optim.build_optimizer(args, params)
return FP16Optimizer(args, params, fp16_optimizer)
return cls(args, params, fp16_optimizer)
@property
def optimizer(self):
......@@ -164,6 +326,12 @@ class FP16Optimizer(optim.FairseqOptimizer):
ConvertToFP32.unwrap_optimizer_(self.wrapped_optimizer.optimizer)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
"""
loss = loss * self.scaler.loss_scale
loss.backward()
self._grads_are_scaled = True
......@@ -178,7 +346,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
assert multiply_grads == 1.
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
"""Multiplies grads by a constant *c*."""
if self._grads_are_scaled:
self._unscale_grads(c)
else:
......
......@@ -13,18 +13,25 @@ from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('cosine')
class CosineSchedule(FairseqLRScheduler):
"""Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details
See https://arxiv.org/pdf/1608.03983.pdf for details.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured
learning rate (`--lr`).
During warmup:
from some initial learning rate (``--warmup-init-lr``) until the configured
learning rate (``--lr``).
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
After warmup::
lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(pi * t_curr / t_i))
where
t_curr is current percentage of updates within the current period range
t_i is the current period range, which is scaled by t_mul after every iteration
where ``t_curr`` is the current percentage of updates within the current period
range and ``t_i`` is the current period range, which is scaled by ``t_mul``
after every iteration.
"""
def __init__(self, args, optimizer):
......@@ -39,7 +46,7 @@ class CosineSchedule(FairseqLRScheduler):
if args.warmup_init_lr < 0:
args.warmup_init_lr = args.lr[0]
self.min_lr = args.lr[0]
self.min_lr = args.lr[0]
self.max_lr = args.max_lr
assert self.max_lr > self.min_lr, 'max_lr must be more than lr'
......@@ -98,7 +105,7 @@ class CosineSchedule(FairseqLRScheduler):
t_curr = curr_updates - (self.period * i)
lr_shrink = self.lr_shrink ** i
min_lr = self.min_lr * lr_shrink
min_lr = self.min_lr * lr_shrink
max_lr = self.max_lr * lr_shrink
self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))
......
......@@ -13,22 +13,19 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
"""Decay the LR based on the inverse square root of the update number.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured
learning rate (`--lr`). Thereafter we decay proportional to the number of
from some initial learning rate (``--warmup-init-lr``) until the configured
learning rate (``--lr``). Thereafter we decay proportional to the number of
updates, with a decay factor set to align with the configured learning rate.
During warmup:
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
lr = decay_factor / sqrt(update_num)
where
After warmup::
decay_factor = args.lr * sqrt(args.warmup_updates)
lr = decay_factor / sqrt(update_num)
"""
def __init__(self, args, optimizer):
......
......@@ -14,8 +14,7 @@ from . import FairseqLRScheduler, register_lr_scheduler
class TriangularSchedule(FairseqLRScheduler):
"""Assign LR based on a triangular cyclical schedule.
See https://arxiv.org/pdf/1506.01186.pdf for details
See https://arxiv.org/pdf/1506.01186.pdf for details.
"""
def __init__(self, args, optimizer):
......
......@@ -107,6 +107,8 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
if getattr(args, 'memory_efficient_fp16', False):
args.fp16 = True
# Apply architecture configuration.
if hasattr(args, 'arch'):
......@@ -128,7 +130,10 @@ def get_parser(desc, default_task='translation'):
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
parser.add_argument('--fp16', action='store_true', help='use FP16')
parser.add_argument('--memory-efficient-fp16', action='store_true',
help='use a memory-efficient version of FP16 training; implies --fp16')
parser.add_argument('--fp16-init-scale', default=2**7, type=int,
help='default FP16 loss scale')
parser.add_argument('--fp16-scale-window', type=int,
......@@ -147,6 +152,8 @@ def get_parser(desc, default_task='translation'):
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
# fmt: off
group.add_argument('--num-workers', default=0, type=int, metavar='N',
help='how many subprocesses to use for data loading')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N',
......@@ -178,7 +185,7 @@ def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training')
# fmt: off
group.add_argument('--distributed-world-size', type=int, metavar='N',
default=torch.cuda.device_count(),
default=max(1, torch.cuda.device_count()),
help='total number of GPUs across all nodes (default: all visible GPUs)')
group.add_argument('--distributed-rank', default=0, type=int,
help='rank of the current worker')
......@@ -189,7 +196,7 @@ def add_distributed_training_args(parser):
'establish initial connection')
group.add_argument('--distributed-port', default=-1, type=int,
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int,
group.add_argument('--device-id', '--local_rank', default=0, type=int,
help='which GPU to use (usually configured automatically)')
group.add_argument('--ddp-backend', default='c10d', type=str,
choices=['c10d', 'no_c10d'],
......@@ -197,8 +204,8 @@ def add_distributed_training_args(parser):
group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
help='bucket size for reduction')
group.add_argument('--fix-batches-to-gpus', action='store_true',
help='Don\'t shuffle batches between GPUs, this reduces overall '
'randomness and may affect precision but avoids the cost of'
help='don\'t shuffle batches between GPUs; this reduces overall '
'randomness and may affect precision but avoids the cost of '
're-reading the data')
# fmt: on
return group
......@@ -263,7 +270,9 @@ def add_checkpoint_args(parser):
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
help='save a checkpoint (and validate) every N updates')
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
help='keep last N checkpoints saved with --save-interval-updates')
help='keep the last N checkpoints saved with --save-interval-updates')
group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
help='keep last N epoch checkpoints')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
......@@ -280,11 +289,11 @@ def add_common_eval_args(group):
help='path(s) to model file(s), colon separated')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation that were used during model training')
help='a dictionary used to override model args at generation '
'that were used during model training')
# fmt: on
......