Unverified commit 66415206, authored by Myle Ott and committed by GitHub

fairseq-py goes distributed (#106)

This PR includes breaking API changes to modularize fairseq-py and adds support for distributed training across multiple nodes.

Changes:
- c7033ef: add support for distributed training! See updated README for usage.
- e016299: modularize fairseq-py, adding support for register_model, register_criterion, register_optimizer, etc. (a usage sketch follows below)
- 154e440: update the LSTM implementation to use PackedSequence objects in the encoder, better following best practices and improving performance
- 90c2973 and 1da6265: improve unit test coverage
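
The new registries follow a decorator pattern; below is a minimal sketch of registering a model. The name 'my_rnn' and the MyEncoder/MyDecoder classes are hypothetical placeholders, and the import path assumes the registry is exposed from fairseq.models, as in the LSTM diff further down:

    from fairseq.models import FairseqModel, register_model, register_model_architecture

    @register_model('my_rnn')
    class MyRNNModel(FairseqModel):

        @staticmethod
        def add_args(parser):
            # model-specific command-line options
            parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                                help='encoder embedding dimension')

        @classmethod
        def build_model(cls, args, src_dict, dst_dict):
            # construct encoder/decoder from the parsed args and dictionaries
            encoder = MyEncoder(src_dict, embed_dim=args.encoder_embed_dim)  # hypothetical
            decoder = MyDecoder(dst_dict, embed_dim=args.encoder_embed_dim)  # hypothetical
            return cls(encoder, decoder)

    @register_model_architecture('my_rnn', 'my_rnn')
    def base_architecture(args):
        # fill in defaults for any option not set on the command line
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)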
parent 7e86e30c
@@ -4,41 +4,106 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel
from fairseq import utils
from fairseq.data import LanguagePairDataset
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
@register_model('lstm')
class LSTMModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='number of encoder layers')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='number of decoder layers')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
parser.add_argument('--decoder-attention', type=str, metavar='BOOL',
help='decoder attention')
# Granular dropout settings (if not specified these default to --dropout)
parser.add_argument('--encoder-dropout-in', type=float, metavar='D',
help='dropout probability for encoder input embedding')
parser.add_argument('--encoder-dropout-out', type=float, metavar='D',
help='dropout probability for encoder output')
parser.add_argument('--decoder-dropout-in', type=float, metavar='D',
help='dropout probability for decoder input embedding')
parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
help='dropout probability for decoder output')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
"""Build a new model instance."""
encoder = LSTMEncoder(
src_dict,
embed_dim=args.encoder_embed_dim,
num_layers=args.encoder_layers,
dropout_in=args.encoder_dropout_in,
dropout_out=args.encoder_dropout_out,
)
decoder = LSTMDecoder(
dst_dict,
encoder_embed_dim=args.encoder_embed_dim,
embed_dim=args.decoder_embed_dim,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=args.decoder_layers,
attention=bool(args.decoder_attention),
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
)
return cls(encoder, decoder)
class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
def __init__(self, dictionary, embed_dim=512, num_layers=1, dropout_in=0.1,
dropout_out=0.1):
super().__init__(dictionary)
self.num_layers = num_layers
self.dropout_in = dropout_in
self.dropout_out = dropout_out
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
self.lstm = LSTM(
input_size=embed_dim,
hidden_size=embed_dim,
num_layers=num_layers,
dropout=self.dropout_out,
bidirectional=False,
)
def forward(self, src_tokens, src_lengths):
if LanguagePairDataset.LEFT_PAD_SOURCE:
# convert left-padding to right-padding
src_tokens.data = utils.convert_padding_direction(
src_tokens.data,
src_lengths.data,
self.padding_idx,
left_to_right=True,
)
self.layers = nn.ModuleList([
LSTMCell(embed_dim, embed_dim)
for layer in range(num_layers)
])
def forward(self, src_tokens):
bsz, seqlen = src_tokens.size()
num_layers = len(self.layers)
# embed tokens
x = self.embed_tokens(src_tokens)
@@ -48,27 +113,21 @@ class LSTMEncoder(FairseqEncoder):
# B x T x C -> T x B x C
x = x.transpose(0, 1)
final_hiddens, final_cells = [], []
outs = [x[j] for j in range(seqlen)]
for i, rnn in enumerate(self.layers):
hidden = Variable(x.data.new(bsz, embed_dim).zero_())
cell = Variable(x.data.new(bsz, embed_dim).zero_())
for j in range(seqlen):
# recurrent cell
hidden, cell = rnn(outs[j], (hidden, cell))
# store the most recent hidden state in outs, either to be used
# as the input for the next layer, or as the final output
outs[j] = F.dropout(hidden, p=self.dropout_out, training=self.training)
# pack embedded source tokens into a PackedSequence
packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())
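# NOTE: pack_padded_sequence expects the batch to be sorted by decreasing
# source length; the batches produced by the data loader are assumed to
# already satisfy this.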
# save the final hidden and cell states for every layer
final_hiddens.append(hidden)
final_cells.append(cell)
# apply LSTM
h0 = Variable(x.data.new(self.num_layers, bsz, embed_dim).zero_())
c0 = Variable(x.data.new(self.num_layers, bsz, embed_dim).zero_())
packed_outs, (final_hiddens, final_cells) = self.lstm(
packed_x,
(h0, c0),
)
# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
final_hiddens = torch.cat(final_hiddens, dim=0).view(num_layers, bsz, embed_dim)
final_cells = torch.cat(final_cells, dim=0).view(num_layers, bsz, embed_dim)
# unpack outputs and apply dropout
x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=0.)
x = F.dropout(x, p=self.dropout_out, training=self.training)
assert list(x.size()) == [seqlen, bsz, embed_dim]
return x, final_hiddens, final_cells
@@ -124,20 +183,20 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self.additional_fc = Linear(embed_dim, out_embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, input_tokens, encoder_out):
def forward(self, prev_output_tokens, encoder_out):
if self._is_incremental_eval:
input_tokens = input_tokens[:, -1:]
return self._forward(input_tokens, encoder_out)
prev_output_tokens = prev_output_tokens[:, -1:]
return self._forward(prev_output_tokens, encoder_out)
def _forward(self, input_tokens, encoder_out):
bsz, seqlen = input_tokens.size()
def _forward(self, prev_output_tokens, encoder_out):
bsz, seqlen = prev_output_tokens.size()
# get outputs from encoder
encoder_outs, _, _ = encoder_out
srclen = encoder_outs.size(0)
# embed tokens
x = self.embed_tokens(input_tokens)
x = self.embed_tokens(prev_output_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
embed_dim = x.size(2)
@@ -148,7 +207,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
prev_hiddens = self.get_incremental_state('prev_hiddens')
if not prev_hiddens:
# first time step, initialize previous states
prev_hiddens, prev_cells = self._init_prev_states(input_tokens, encoder_out)
prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
else:
# previous states are cached
@@ -225,7 +284,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number
def _init_prev_states(self, input_tokens, encoder_out):
def _init_prev_states(self, encoder_out):
_, encoder_hiddens, encoder_cells = encoder_out
num_layers = len(self.layers)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
@@ -239,8 +298,16 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
return m
def LSTMCell(input_dim, hidden_dim, **kwargs):
m = nn.LSTMCell(input_dim, hidden_dim, **kwargs)
def LSTM(input_size, hidden_size, **kwargs):
m = nn.LSTM(input_size, hidden_size, **kwargs)
for name, param in m.named_parameters():
if 'weight' in name or 'bias' in name:
param.data.uniform_(-0.1, 0.1)
return m
def LSTMCell(input_size, hidden_size, **kwargs):
m = nn.LSTMCell(input_size, hidden_size, **kwargs)
for name, param in m.named_parameters():
if 'weight' in name or 'bias' in name:
param.data.uniform_(-0.1, 0.1)
@@ -256,50 +323,8 @@ def Linear(in_features, out_features, bias=True, dropout=0):
return m
def get_archs():
return [
'lstm', 'lstm_wiseman_iwslt_de_en', 'lstm_luong_wmt_en_de',
]
def _check_arch(args):
"""Check that the specified architecture is valid and not ambiguous."""
if args.arch not in get_archs():
raise ValueError('Unknown LSTM model architecture: {}'.format(args.arch))
if args.arch != 'lstm':
# check that architecture is not ambiguous
for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers',
'decoder_out_embed_dim']:
if hasattr(args, a):
raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch))
def parse_arch(args):
_check_arch(args)
if args.arch == 'lstm_wiseman_iwslt_de_en':
args.encoder_embed_dim = 256
args.encoder_layers = 1
args.encoder_dropout_in = 0
args.encoder_dropout_out = 0
args.decoder_embed_dim = 256
args.decoder_layers = 1
args.decoder_out_embed_dim = 256
args.decoder_attention = True
args.decoder_dropout_in = 0
elif args.arch == 'lstm_luong_wmt_en_de':
args.encoder_embed_dim = 1000
args.encoder_layers = 4
args.encoder_dropout_out = 0
args.decoder_embed_dim = 1000
args.decoder_layers = 4
args.decoder_out_embed_dim = 1000
args.decoder_attention = True
args.decoder_dropout_out = 0
else:
assert args.arch == 'lstm'
# default architecture
@register_model_architecture('lstm', 'lstm')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_layers = getattr(args, 'encoder_layers', 1)
args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout)
@@ -310,25 +335,30 @@ def parse_arch(args):
args.decoder_attention = getattr(args, 'decoder_attention', True)
args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
return args
def build_model(args, src_dict, dst_dict):
encoder = LSTMEncoder(
src_dict,
embed_dim=args.encoder_embed_dim,
num_layers=int(args.encoder_layers),
dropout_in=args.encoder_dropout_in,
dropout_out=args.encoder_dropout_out,
)
decoder = LSTMDecoder(
dst_dict,
encoder_embed_dim=args.encoder_embed_dim,
embed_dim=args.decoder_embed_dim,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=int(args.decoder_layers),
attention=bool(args.decoder_attention),
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
)
return LSTMModel(encoder, decoder)
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
base_architecture(args)
args.encoder_embed_dim = 256
args.encoder_layers = 1
args.encoder_dropout_in = 0
args.encoder_dropout_out = 0
args.decoder_embed_dim = 256
args.decoder_layers = 1
args.decoder_out_embed_dim = 256
args.decoder_attention = True
args.decoder_dropout_in = 0
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
base_architecture(args)
args.encoder_embed_dim = 1000
args.encoder_layers = 4
args.encoder_dropout_out = 0
args.decoder_embed_dim = 1000
args.decoder_layers = 4
args.decoder_out_embed_dim = 1000
args.decoder_attention = True
args.decoder_dropout_out = 0
@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .beamable_mm import BeamableMM
from .conv_tbc import ConvTBC
......
@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import torch.nn as nn
......
@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Function
......
@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
......
@@ -4,12 +4,10 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
class LearnedPositionalEmbedding(nn.Embedding):
......
@@ -4,9 +4,7 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import torch.nn.functional as F
from fairseq import utils
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import os
import signal
import threading
from torch import multiprocessing
class MultiprocessingEventLoop(object):
"""Start a multiprocessing event loop."""
def __init__(self, device_ids=None, multiprocessing_method='spawn'):
super().__init__()
self.device_ids = tuple(device_ids)
self.num_replicas = len(device_ids)
self.rank = None
self._mp = multiprocessing.get_context(multiprocessing_method)
self._start_error_handler()
self._start_multiprocessing()
def call_async(self, rank, action, **kwargs):
"""Asynchronously call a function in each child process.
Call a function named `action` on the rank'th process and return
a Future with the result.
"""
def result_generator():
yield self.return_pipes[rank].recv()
assert not self.return_pipes[rank].poll(), \
'return pipe must be consumed before calling another function'
self.input_pipes[rank].send((action, kwargs))
return Future(result_generator())
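# Typical call pattern (a sketch; '_async_do_work' is a hypothetical method
# defined on a subclass of MultiprocessingEventLoop):
#
#   futures = [loop.call_async(rank, '_async_do_work', arg=123)
#              for rank in range(loop.num_replicas)]
#   results = Future.gen_list(futures)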
def stop(self, interrupt_children=False):
"""Stop multiprocessing."""
for rank in range(self.num_replicas):
self.input_pipes[rank].close()
self.return_pipes[rank].close()
if interrupt_children:
# send KeyboardInterrupt to children
os.kill(self.procs[rank].pid, signal.SIGINT)
else:
self.procs[rank].join()
self.error_queue.put((None, None)) # poison pill
def _start_error_handler(self):
"""Error handler to catch exceptions in child processes."""
# create a thread to listen for errors in the child processes
self.error_queue = self._mp.SimpleQueue()
error_thread = threading.Thread(target=self._error_listener,
daemon=True)
error_thread.start()
# create signal handler that executes in the main process/thread and
# handles errors from child processes
signal.signal(signal.SIGUSR1, self._signal_handler)
def _error_listener(self):
"""A thread that listens for errors in the child processes.
Errors are handled in a signal handler in the main thread.
"""
(rank, original_trace) = self.error_queue.get()
if rank is None: # poison pill, return
return
# requeue error and switch to main thread for handling the error
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def _signal_handler(self, signal, frame):
"""Signal handler that handles errors from child processes.
This signal handler executes in the main process/thread.
"""
self.stop(interrupt_children=True)
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
def _start_multiprocessing(self):
"""Create child processes to run async event loop.
Each process reads input from a Pipe, performs some computation,
and returns its output to another Pipe.
"""
# create child processes
input_pipes = []
return_pipes = []
procs = []
for rank, id in enumerate(self.device_ids):
recv_input_pipe, send_input_pipe = self._mp.Pipe(duplex=False)
recv_return_pipe, send_return_pipe = self._mp.Pipe(duplex=False)
proc = self._mp.Process(
target=self._process_event_loop,
args=(rank, id, recv_input_pipe, send_return_pipe),
daemon=True)
proc.start()
input_pipes.append(send_input_pipe)
return_pipes.append(recv_return_pipe)
procs.append(proc)
self.input_pipes = input_pipes
self.return_pipes = return_pipes
self.procs = procs
def _process_event_loop(self, rank, device_id, input_pipe, return_pipe):
"""Event loop that runs in each child process.
Event loop:
- take an action from the input pipe
- call the corresponding function in this process
- put the return value in the return pipe
Any exceptions are put in the error queue.
"""
self.rank = rank
try:
# event loop
while True:
action, kwargs = input_pipe.recv()
action_fn = getattr(self, action)
return_pipe.send(action_fn(rank, device_id, **kwargs))
except EOFError:
# input pipe was closed, do nothing
pass
except KeyboardInterrupt:
# killed by parent, do nothing
pass
except Exception:
# propagate exception from child to parent process, keeping
# original traceback
import traceback
self.error_queue.put((rank, traceback.format_exc()))
finally:
# cleanup pipes
input_pipe.close()
return_pipe.close()
class Future(object):
"""A wrapper around a Python generator, with syntactic sugar."""
def __init__(self, generator):
self.generator = generator
def gen(self):
return next(self.generator)
@staticmethod
def gen_list(gens):
return [g.gen() for g in gens]
@staticmethod
def gen_tuple_list(gens):
results = [g.gen() for g in gens]
return zip(*results)
@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import multiprocessing
import os
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
Train a network on multiple GPUs using multiprocessing.
"""
from itertools import cycle, islice
import math
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from fairseq import nccl, utils
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.optim.nag import NAG
from fairseq.optim.adam import Adam
class MultiprocessingTrainer(MultiprocessingEventLoop):
"""Main class for multi-GPU training.
Each GPU has a full copy of the model and is assigned to its own Python
process. Gradients are accumulated with all-reduce and all model replicas
are updated synchronously after each batch.
The methods in this class are divided into synchronous functions, which
prepare and dispatch the input to each process, and asynchronous functions
(prefixed with `_async_`), which run on each process in parallel.
"""
OPTIMIZERS = ['adagrad', 'adam', 'nag', 'sgd']
def __init__(self, args, model, criterion, device_ids=None,
multiprocessing_method='spawn'):
if device_ids is None:
device_ids = tuple(range(torch.cuda.device_count()))
super().__init__(device_ids, multiprocessing_method)
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
model = model.share_memory()
nccl_uid = nccl.get_unique_id()
self.criterion = criterion
Future.gen_list([
self.call_async(rank, '_async_init', args=args, model=model,
criterion=criterion, nccl_uid=nccl_uid)
for rank in range(self.num_replicas)
])
self._grads_initialized = False
def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
"""Initialize child processes."""
self.args = args
# set CUDA device
torch.cuda.set_device(device_id)
# initialize NCCL
nccl.initialize(self.num_replicas, nccl_uid, device_id)
# copy model and criterion to current device
self.model = model.cuda()
self.criterion = criterion.cuda()
# initialize optimizer and LR scheduler
self.args.lr = list(map(float, self.args.lr.split(',')))
self.optimizer = self._build_optimizer()
self.lr_scheduler = self._build_lr_scheduler()
self.loss = None
self._max_bsz_seen = 0
def _build_optimizer(self):
# When resuming training from a checkpoint, we load the old optimizer
# state that includes things like learning rate, momentum factor, etc.
# We use this dictionary to override values stored in the checkpoint,
# e.g., we might prefer the values specified on the command line.
self._override_optim_state = {}
if self.args.optimizer == 'adagrad':
self._override_optim_state = {
'lr': self.args.lr[0],
'weight_decay': self.args.weight_decay,
}
return torch.optim.Adagrad(self.model.parameters(), **self._override_optim_state)
elif self.args.optimizer == 'adam':
self._override_optim_state = {
'lr': self.args.lr[0],
'betas': eval(self.args.adam_betas),
'weight_decay': self.args.weight_decay,
}
return Adam(self.model.parameters(), **self._override_optim_state)
elif self.args.optimizer == 'nag':
self._override_optim_state = {
'lr': self.args.lr[0],
'momentum': self.args.momentum,
'weight_decay': self.args.weight_decay,
}
return NAG(self.model.parameters(), **self._override_optim_state)
elif self.args.optimizer == 'sgd':
self._override_optim_state = {
'lr': self.args.lr[0],
'momentum': self.args.momentum,
'weight_decay': self.args.weight_decay,
}
return torch.optim.SGD(self.model.parameters(), **self._override_optim_state)
else:
raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
def _build_lr_scheduler(self):
if len(self.args.lr) > 1 or self.args.force_anneal > 0:
lrs = self.args.lr
def anneal(e):
if e < self.args.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(e, len(lrs) - 1)]
else:
next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
return next_lr / lrs[0] # correct for scaling from LambdaLR
lr_scheduler = LambdaLR(self.optimizer, anneal)
lr_scheduler.best = None
else:
# decay the LR by a factor every time the validation loss plateaus
lr_scheduler = ReduceLROnPlateau(self.optimizer, patience=0,
factor=self.args.lrshrink)
return lr_scheduler
def get_model(self):
"""Get one of the model replicas."""
# just return the first model, since all replicas are the same
return self.call_async(0, '_async_get_model').gen()
def _async_get_model(self, rank, device_id):
return self.model
def save_checkpoint(self, filename, extra_state):
"""Save a checkpoint for the current model."""
self.call_async(0, '_async_save_checkpoint', filename=filename, extra_state=extra_state).gen()
def _async_save_checkpoint(self, rank, device_id, filename, extra_state):
utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._optim_history, extra_state)
def load_checkpoint(self, filename):
"""Load a checkpoint into the model replicas in each process."""
results = Future.gen_list([
self.call_async(rank, '_async_load_checkpoint', filename=filename)
for rank in range(self.num_replicas)
])
extra_state = results[0]
return extra_state
def _async_load_checkpoint(self, rank, device_id, filename):
extra_state, self._optim_history, last_optim_state = utils.load_model_state(
filename, self.model, cuda_device=device_id)
if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed
self.optimizer = self._build_optimizer()
self.lr_scheduler = self._build_lr_scheduler()
# only load optimizer and lr_scheduler if they match the checkpoint
last_optim = self._optim_history[-1]
if last_optim['criterion_name'] == self.criterion.__class__.__name__:
self.optimizer.load_state_dict(last_optim_state)
self.lr_scheduler.best = last_optim['best_loss']
# override learning rate, momentum, etc. with latest values
for group in self.optimizer.param_groups:
group.update(self._override_optim_state)
return extra_state
def set_seed(self, seed):
Future.gen_list([
self.call_async(rank, '_async_set_seed', seed=seed)
for rank in range(self.num_replicas)
])
def _async_set_seed(self, rank, device_id, seed):
torch.manual_seed(seed)
def train_step(self, samples):
"""Do forward, backward and gradient step in parallel."""
# PyTorch initializes gradient buffers lazily, so the first
# train step needs to send non-empty samples to all replicas
replace_empty_samples = False
if not self._grads_initialized:
replace_empty_samples = True
self._grads_initialized = True
# scatter sample across GPUs
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
# forward pass
sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward')
for rank in range(self.num_replicas)
])
# backward pass, all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norms, ooms_bwd = Future.gen_tuple_list([
self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
for rank in range(self.num_replicas)
])
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['gnorm'] = grad_norms[0] # log the gradient norm
logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)
return logging_output
def _async_forward(self, rank, device_id, eval=False):
if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
with utils.maybe_no_grad(eval):
sample_size, logging_output, oom = 0, {}, False
if self._sample is not None:
try:
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
oom = True
self.loss = None
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
else:
raise e
return sample_size, logging_output, oom
def _async_backward_and_opt(self, rank, device_id, grad_denom):
oom = False
if self.loss is not None:
try:
# backward pass
self.loss.backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
oom = True
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
self.optimizer.zero_grad()
else:
raise e
# all-reduce grads and rescale by grad_denom
self._all_reduce_and_rescale_grads(grad_denom)
# clip grads
if self.args.clip_norm > 0:
grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
else:
grad_norm = math.sqrt(sum([p.grad.data.norm()**2 for p in self.model.parameters()]))
# take an optimization step
self.optimizer.step()
# reset loss
self.loss = None
return grad_norm, oom
def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
"""All-reduce and rescale gradients in chunks of the specified size."""
grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
buffer_t = grads[0].new(math.ceil(buffer_size / grads[0].element_size())).zero_()
buffer = []
def all_reduce_buffer():
# copy grads into buffer_t
offset = 0
for g in buffer:
numel = g.numel()
buffer_t[offset:offset+numel].copy_(g.view(-1))
offset += numel
# all-reduce and rescale
nccl.all_reduce(buffer_t[:offset])
buffer_t.div_(grad_denom)
# copy all-reduced buffer back into grads
offset = 0
for g in buffer:
numel = g.numel()
g.view(-1).copy_(buffer_t[offset:offset+numel])
offset += numel
filled = 0
for g in grads:
sz = g.numel() * g.element_size()
if sz > buffer_size:
# grad is bigger than buffer, all-reduce and rescale directly
nccl.all_reduce(g)
g.div_(grad_denom)
elif filled + sz > buffer_size:
# buffer is full, all-reduce and replace buffer with grad
all_reduce_buffer()
buffer = [g]
filled = sz
else:
# add grad to buffer
buffer.append(g)
filled += sz
if len(buffer) > 0:
all_reduce_buffer()
def valid_step(self, samples):
"""Do forward pass in parallel."""
# scatter sample across GPUs
self._scatter_samples(samples, volatile=True)
# forward pass
_sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward', eval=True)
for rank in range(self.num_replicas)
])
assert sum(ooms_fwd) == 0
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
return logging_output
def get_lr(self):
"""Get the current learning rate."""
return self.call_async(0, '_async_get_lr').gen()
def _async_get_lr(self, rank, device_id):
return self.optimizer.param_groups[0]['lr']
def lr_step(self, val_loss=None, epoch=None):
"""Adjust the learning rate depending on the validation loss."""
lr = Future.gen_list([
self.call_async(rank, '_async_lr_step', val_loss=val_loss, epoch=epoch)
for rank in range(self.num_replicas)
])
return lr[0]
def _async_lr_step(self, rank, device_id, epoch, val_loss):
# update the learning rate
if self.args.force_anneal > 0:
self.lr_scheduler.step(epoch)
else:
self.lr_scheduler.step(val_loss, epoch)
return self.optimizer.param_groups[0]['lr']
def _scatter_samples(self, samples, volatile=False, replace_empty_samples=False):
"""Split and distribute a sample across GPUs."""
if not replace_empty_samples:
# pad with None until its size is equal to the number of replicas
samples = samples + [None]*(self.num_replicas - len(samples))
else:
# pad by cycling through the given samples
samples = list(islice(cycle(samples), self.num_replicas))
Future.gen_list([
self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
for rank in range(self.num_replicas)
])
def _async_prepare_sample(self, rank, device_id, sample, volatile):
if sample is None:
self._sample = None
else:
if hasattr(torch.cuda, 'empty_cache'):
# clear the caching allocator if this is the largest sample we've seen
if sample['target'].size(0) > self._max_bsz_seen:
self._max_bsz_seen = sample['target'].size(0)
torch.cuda.empty_cache()
self._sample = utils.make_variable(sample, volatile=volatile, cuda_device=device_id)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
"""
A modified version of torch.cuda.nccl.all_reduce for launching kernels on each
GPU separately.
"""
import ctypes
from ctypes.util import find_library
lib = None
nccl_2_0 = None
_uid = None
_rank = None
_num_devices = None
_comm = None
__all__ = ['all_reduce', 'initialize', 'get_unique_id']
# ncclDataType_t
nccl_types = {
'torch.cuda.ByteTensor': 0,
'torch.cuda.CharTensor': 0,
'torch.cuda.IntTensor': 1,
'torch.cuda.HalfTensor': 2,
'torch.cuda.FloatTensor': 3,
'torch.cuda.DoubleTensor': 4,
'torch.cuda.LongTensor': 5,
}
nccl_types_2_0 = {
'torch.cuda.ByteTensor': 0,
'torch.cuda.CharTensor': 0,
'torch.cuda.IntTensor': 2,
'torch.cuda.HalfTensor': 6,
'torch.cuda.FloatTensor': 7,
'torch.cuda.DoubleTensor': 8,
'torch.cuda.LongTensor': 4,
}
# ncclRedOp_t
SUM = 0
PROD = 1
MAX = 2
MIN = 3
status_codes_2_0 = {
0: "Success",
1: "Unhandled Cuda Error",
2: "System Error",
3: "Internal Error",
4: "Invalid Argument Error",
5: "Invalid Usage Error",
}
status_codes = {
0: "Success",
1: "Unhandled Cuda Error",
2: "System Error",
3: "Internal Error",
4: "Invalid Device Pointer",
5: "Invalid Rank",
6: "Unsupported Device Count",
7: "Device Not Found",
8: "Invalid Device Index",
9: "Lib Wrapper Not Set",
10: "Cuda Malloc Failed",
11: "Rank Mismatch",
12: "Invalid Argument",
13: "Invalid Type",
14: "Invalid Operation",
}
def _libnccl():
global nccl_2_0
global lib
global status_codes
global nccl_types
if lib is None:
lib = ctypes.pydll.LoadLibrary(find_library('nccl'))
if hasattr(lib, 'ncclCommDestroy'):
lib.ncclCommDestroy.restype = None
else:
lib = None
if hasattr(lib, 'ncclGroupStart'):
nccl_2_0 = True
status_codes = status_codes_2_0
nccl_types = nccl_types_2_0
return lib
class NcclError(RuntimeError):
def __init__(self, status):
self.status = status
msg = '{0} ({1})'.format(status_codes.get(status), status)
super(NcclError, self).__init__(msg)
class NcclComm(ctypes.c_void_p):
def __del__(self):
lib.ncclCommDestroy(self)
class NcclUniqueId(ctypes.Structure):
_fields_ = [
('internal', ctypes.c_uint8 * 128)
]
def check_error(status):
if status != 0:
raise NcclError(status)
_uids = []
def get_unique_id():
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
uid = NcclUniqueId()
check_error(lib.ncclGetUniqueId(ctypes.byref(uid)))
_uids.append(uid) # Don't allow UIDs to be collected
return uid
def initialize(num_devices, uid, rank):
global _num_devices, _uid, _rank
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
_num_devices = num_devices
if rank != 0:
_uid = NcclUniqueId.from_buffer_copy(uid)
else:
_uid = uid
_rank = rank
def communicator():
global _comm
if _libnccl() is None:
raise RuntimeError('Unable to load NCCL library')
if _uid is None:
raise RuntimeError('NCCL not initialized')
if _comm is None:
comm = NcclComm()
check_error(lib.ncclCommInitRank(
ctypes.byref(comm),
ctypes.c_int(_num_devices),
_uid,
ctypes.c_int(_rank)))
_comm = comm
return _comm
def all_reduce(input, output=None, op=SUM, stream=None):
comm = communicator()
if output is None:
output = input
if stream is not None:
stream = stream.cuda_stream
data_type = nccl_types[input.type()]
check_error(lib.ncclAllReduce(
ctypes.c_void_p(input.data_ptr()),
ctypes.c_void_p(output.data_ptr()),
ctypes.c_size_t(input.numel()),
data_type,
op,
comm,
ctypes.c_void_p(stream)))
return output
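# Sketch of the intended call pattern (the function names are those defined
# above; the surrounding setup is hypothetical):
#
#   uid = get_unique_id()               # once, in the parent process
#   initialize(num_devices, uid, rank)  # in each child, after
#                                       # torch.cuda.set_device(device_id)
#   all_reduce(tensor)                  # sums `tensor` in place across devices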
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
OPTIMIZER_REGISTRY = {}
OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
return OPTIMIZER_REGISTRY[args.optimizer](args, params)
def register_optimizer(name):
"""Decorator to register a new optimizer."""
def register_optimizer_cls(cls):
if name in OPTIMIZER_REGISTRY:
raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
if not issubclass(cls, FairseqOptimizer):
raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
if cls.__name__ in OPTIMIZER_CLASS_NAMES:
# We use the optimizer class name as a unique identifier in
# checkpoints, so all optimizers must have unique class names.
raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
OPTIMIZER_REGISTRY[name] = cls
OPTIMIZER_CLASS_NAMES.add(cls.__name__)
return cls
return register_optimizer_cls
# automatically import any Python files in the optim/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim
from . import FairseqOptimizer, register_optimizer
@register_optimizer('adagrad')
class Adagrad(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'weight_decay': self.args.weight_decay,
}
@@ -4,16 +4,48 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
from torch.optim.optimizer import Optimizer
import torch.optim
from . import FairseqOptimizer, register_optimizer
@register_optimizer('adam')
class FairseqAdam(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = Adam(params, **self.optimizer_config)
class Adam(Optimizer):
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
help='betas for Adam optimizer')
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'betas': eval(self.args.adam_betas),
'weight_decay': self.args.weight_decay,
}
class Adam(torch.optim.Optimizer):
"""Implements Adam algorithm.
This implementation is modified from torch.optim.Adam based on:
`Fixing Weight Decay Regularization in Adam`
(see https://arxiv.org/abs/1711.05101)
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim
class FairseqOptimizer(object):
def __init__(self, args, params):
super().__init__()
self.args = args
self.params = params
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
pass
@property
def optimizer(self):
"""Return a torch.optim.optimizer.Optimizer instance."""
if not hasattr(self, '_optimizer'):
raise NotImplementedError
if not isinstance(self._optimizer, torch.optim.Optimizer):
raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
return self._optimizer
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
raise NotImplementedError
def get_lr(self):
"""Return the current learning rate."""
return self.optimizer.param_groups[0]['lr']
def set_lr(self, lr):
"""Set the learning rate."""
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
def state_dict(self):
"""Return the optimizer's state dict."""
return self.optimizer.state_dict()
def load_state_dict(self, state_dict):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
self.optimizer.load_state_dict(state_dict)
# override learning rate, momentum, etc. with latest values
for group in self.optimizer.param_groups:
group.update(self.optimizer_config)
def step(self, closure=None):
"""Performs a single optimization step."""
return self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
return self.optimizer.zero_grad()
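# A concrete subclass only needs to set self._optimizer and define
# optimizer_config, e.g. (a minimal sketch; 'my_sgd' and MySGD are
# hypothetical names, following the same pattern as the Adagrad/Adam
# optimizers above):
#
#   @register_optimizer('my_sgd')
#   class MySGD(FairseqOptimizer):
#       def __init__(self, args, params):
#           super().__init__(args, params)
#           self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
#
#       @property
#       def optimizer_config(self):
#           return {'lr': self.args.lr[0], 'momentum': self.args.momentum}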
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_lr_scheduler import FairseqLRScheduler
LR_SCHEDULER_REGISTRY = {}
def build_lr_scheduler(args, optimizer):
return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)
def register_lr_scheduler(name):
"""Decorator to register a new LR scheduler."""
def register_lr_scheduler_cls(cls):
if name in LR_SCHEDULER_REGISTRY:
raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
if not issubclass(cls, FairseqLRScheduler):
raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
LR_SCHEDULER_REGISTRY[name] = cls
return cls
return register_lr_scheduler_cls
# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.lr_scheduler.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .. import FairseqOptimizer
class FairseqLRScheduler(object):
def __init__(self, args, optimizer):
super().__init__()
if not isinstance(optimizer, FairseqOptimizer):
raise ValueError('optimizer must be an instance of FairseqOptimizer')
self.args = args
self.optimizer = optimizer
self.best = None
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
pass
def state_dict(self):
"""Return the LR scheduler state dict."""
return {'best': self.best}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.best = state_dict['best']
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None:
if self.best is None:
self.best = val_loss
else:
self.best = min(self.best, val_loss)
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.optimizer.get_lr()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer.optimizer, self.anneal)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch')
def anneal(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(epoch, len(lrs) - 1)]
else:
# anneal based on lr_shrink
next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
return next_lr / lrs[0] # correct for scaling from LambdaLR
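# Example (hypothetical values): with --lr 0.25 --force-anneal 5
# --lr-shrink 0.5, anneal() yields an effective lr of 0.25 for epochs 0-4,
# then 0.125 at epoch 5, 0.0625 at epoch 6, and so on.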
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
self.lr_scheduler.step(epoch)
return self.optimizer.get_lr()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('inverse_sqrt')
class InverseSquareRootSchedule(FairseqLRScheduler):
"""Decay the LR based on the inverse square root of the update number.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (`--warmup-init-lr`) until the configured
learning rate (`--lr`). Thereafter the learning rate decays proportionally to
the inverse square root of the update number, with the decay factor chosen so
the schedule is continuous with the peak learning rate at the end of warmup.
During warmup:
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup:
lr = decay_factor / sqrt(update_num)
where
decay_factor = args.lr * sqrt(args.warmup_updates)
"""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
'Cannot use a fixed learning rate schedule with inverse_sqrt.'
' Consider --lr-scheduler=fixed instead.'
)
warmup_end_lr = args.lr[0]
if args.warmup_init_lr < 0:
args.warmup_init_lr = warmup_end_lr
# linearly warmup for the first args.warmup_updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
# then, decay prop. to the inverse square root of the update number
self.decay_factor = warmup_end_lr * args.warmup_updates**0.5
# initial learning rate
self.lr = args.warmup_init_lr
self.optimizer.set_lr(self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
help='initial learning rate during warmup phase; default is args.lr')
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.args.warmup_updates:
self.lr += self.lr_step
else:
self.lr = self.decay_factor * num_updates**-0.5
self.optimizer.set_lr(self.lr)
return self.lr
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
"""Decay the LR by a factor every time the validation loss plateaus."""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
raise ValueError(
'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
' Consider --lr-scheduler=fixed instead.'
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer.optimizer, patience=0, factor=args.lr_shrink)
def state_dict(self):
"""Return the LR scheduler state dict."""
return {
'best': self.lr_scheduler.best,
'last_epoch': self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.lr_scheduler.best = state_dict['best']
if 'last_epoch' in state_dict:
self.lr_scheduler.last_epoch = state_dict['last_epoch']
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None:
self.lr_scheduler.step(val_loss, epoch)
else:
self.lr_scheduler.last_epoch = epoch
return self.optimizer.get_lr()