Commit ba5f829f authored by Jeff Cai, committed by Facebook Github Bot

Parameterized criterions (#808)

Summary:
Support criterions that have their own parameters, such as the AutoSegmentationCriterion (ASG) used in wav2letter, which has a transition matrix parameter. This is needed to integrate wav2letter's ASG into PySpeech.

With this diff, parameters in criterions will be:
(1) updated by optimizers, with a configurable learning rate;
(2) saved to and loaded from checkpoints, preserving backward compatibility for criterions without parameters;
(3) synchronized across nodes in distributed training.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/808

Reviewed By: jcai1

Differential Revision: D16934097

Pulled By: okhonko

fbshipit-source-id: 121ec9382459385c6f9cbef3a8274bec1a434038
parent a2f5361d
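For context, the kind of criterion this enables is easiest to see with a toy module. The sketch below is illustrative only: it is not the wav2letter ASG implementation and does not use fairseq's FairseqCriterion API, and the names (ToyTransitionCriterion, num_labels) are invented here. The point is simply that the criterion is an nn.Module whose parameters() yields a learnable transition matrix, which is what the changes below feed into the optimizer, the checkpoint, and the distributed wrapper.

# Toy parameterized criterion (ASG-like in spirit only; not the real implementation).
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyTransitionCriterion(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        # learnable transition scores between labels, like ASG's transition matrix
        self.transitions = nn.Parameter(torch.zeros(num_labels, num_labels))

    def forward(self, emissions, targets):
        # emissions: (batch, time, num_labels); targets: (batch, time)
        loss = F.cross_entropy(emissions.transpose(1, 2), targets)
        # reward transitions seen along the target label path
        loss = loss - self.transitions[targets[:, :-1], targets[:, 1:]].mean()
        return loss

criterion = ToyTransitionCriterion(num_labels=10)
print(sum(p.numel() for p in criterion.parameters()))  # 100 -> visible to the optimizer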
@@ -222,6 +222,7 @@ def save_state(
     filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
     num_updates, optim_history=None, extra_state=None,
 ):
+    from fairseq import utils
     if optim_history is None:
         optim_history = []
     if extra_state is None:
@@ -239,6 +240,8 @@ def save_state(
         ],
         'extra_state': extra_state,
     }
+    if utils.has_parameters(criterion):
+        state_dict['criterion'] = criterion.state_dict()
     if not args.no_save_optimizer_state:
         state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
     torch_persistent_save(state_dict, filename)
......
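The guard above is what keeps old checkpoints compatible: a 'criterion' entry is written only when the criterion actually has parameters. Below is a minimal standalone sketch of the round trip, with invented names (save_sketch, load_sketch, ckpt_path) and no fairseq dependencies; it is not the fairseq code itself.

import torch

def save_sketch(ckpt_path, model, criterion):
    state = {'model': model.state_dict()}
    if any(True for _ in criterion.parameters()):   # same test as utils.has_parameters below
        state['criterion'] = criterion.state_dict()
    torch.save(state, ckpt_path)

def load_sketch(ckpt_path, model, criterion):
    state = torch.load(ckpt_path)
    model.load_state_dict(state['model'], strict=True)
    if 'criterion' in state:                        # key is absent for parameterless criterions
        criterion.load_state_dict(state['criterion'], strict=True)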
@@ -5,7 +5,7 @@
 import inspect
 
-from torch.nn import parallel
+import torch.nn as nn
 
 from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel
 from fairseq.models import BaseFairseqModel
@@ -25,9 +25,9 @@ def DistributedFairseqModel(args, model):
         model (BaseFairseqModel): model to wrap
     """
     # determine which DDP class to extend
-    assert isinstance(model, BaseFairseqModel)
+    assert isinstance(model, nn.Module)
     if args.ddp_backend == 'c10d':
-        ddp_class = parallel.DistributedDataParallel
+        ddp_class = nn.parallel.DistributedDataParallel
         init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
......
@@ -19,18 +19,13 @@ __all__ = [
 ]
 
-_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
+build_optimizer, register_optimizer, OPTIMIZER_REGISTRY = registry.setup_registry(
     '--optimizer',
     base_class=FairseqOptimizer,
     default='nag',
 )
-
-def build_optimizer(args, params, *extra_args, **extra_kwargs):
-    params = list(filter(lambda p: p.requires_grad, params))
-    return _build_optimizer(args, params, *extra_args, **extra_kwargs)
 
 # automatically import any Python files in the optim/ directory
 for file in os.listdir(os.path.dirname(__file__)):
     if file.endswith('.py') and not file.startswith('_'):
......
@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adadelta')
 class Adadelta(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)
 
     @staticmethod
......
@@ -13,7 +13,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adafactor')
 class FairseqAdafactor(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adafactor(params, **self.optimizer_config)
 
     @staticmethod
......
@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adagrad')
 class Adagrad(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
 
     @staticmethod
......
@@ -16,7 +16,7 @@ from . import FairseqOptimizer, register_optimizer
 class FairseqAdam(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         if torch.cuda.is_available():
             try:
                 from apex.optimizers import FusedAdam as _FusedAdam  # noqa
......
@@ -12,7 +12,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('adamax')
 class FairseqAdamax(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = Adamax(params, **self.optimizer_config)
 
     @staticmethod
......
@@ -19,11 +19,10 @@ class FairseqBMUF(FairseqOptimizer):
     model-update filtering
     """
 
-    def __init__(self, args, params, optimizer):
-        super().__init__(args, params)
+    def __init__(self, args, optimizer):
+        super().__init__(args)
         self._optimizer = optimizer
-        self.params = params
         self._num_updates = 0
         self.sync_iter = self.args.global_sync_iter
         self.block_momentum = self.args.block_momentum
......
@@ -10,10 +10,9 @@ import torch
 class FairseqOptimizer(object):
 
-    def __init__(self, args, params):
+    def __init__(self, args):
         super().__init__()
         self.args = args
-        self.params = list(params)
 
     @staticmethod
     def add_args(parser):
@@ -39,6 +38,13 @@ class FairseqOptimizer(object):
         """
         raise NotImplementedError
 
+    @property
+    def params(self):
+        """Return an iterable of the parameters held by the optimizer."""
+        for param_group in self.optimizer.param_groups:
+            for p in param_group['params']:
+                yield p
+
     def __getstate__(self):
         return self._optimizer.__getstate__()
@@ -93,9 +99,8 @@ class FairseqOptimizer(object):
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                p.grad = None
+        for p in self.params:
+            p.grad = None
         self.optimizer.zero_grad()
 
     @property
......
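With the base class above, params is now derived from the wrapped torch optimizer's param_groups instead of being stored separately, so it always reflects exactly what the optimizer is stepping, including any criterion parameters. Below is a standalone sketch of the same idea against a plain torch optimizer; the Linear module and the extra parameter are just stand-ins.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
criterion_param = nn.Parameter(torch.zeros(4, 4))        # stand-in for a criterion parameter
opt = torch.optim.SGD(list(model.parameters()) + [criterion_param], lr=0.1)

# equivalent of the new FairseqOptimizer.params property
params = [p for group in opt.param_groups for p in group['params']]
print(len(params))  # 3: weight, bias, and the criterion parameter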
@@ -60,7 +60,8 @@ class FP16Optimizer(optim.FairseqOptimizer):
     """
     def __init__(self, args, params, fp32_optimizer, fp32_params):
-        super().__init__(args, params)
+        super().__init__(args)
+        self.fp16_params = params
         self.fp32_optimizer = fp32_optimizer
         self.fp32_params = fp32_params
@@ -149,7 +150,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
         if self._needs_sync:
             # copy FP16 grads to FP32
             offset = 0
-            for p in self.params:
+            for p in self.fp16_params:
                 if not p.requires_grad:
                     continue
                 grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
@@ -196,7 +197,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
         # copy FP32 params back into FP16 model
         offset = 0
-        for p in self.params:
+        for p in self.fp16_params:
             if not p.requires_grad:
                 continue
             numel = p.data.numel()
@@ -205,7 +206,7 @@ class FP16Optimizer(optim.FairseqOptimizer):
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        for p in self.params:
+        for p in self.fp16_params:
             p.grad = None
         self._needs_sync = False
@@ -232,7 +233,7 @@ class MemoryEfficientFP16Optimizer(optim.FairseqOptimizer):
             'Unsupported optimizer: {}'.format(optimizer.__class__.__name__)
         )
-        super().__init__(args, params)
+        super().__init__(args)
         self.wrapped_optimizer = optimizer
 
         if getattr(args, 'fp16_scale_window', None) is None:
......
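Naming the model-side parameters fp16_params makes the two parameter sets explicit: fp16_params are the original half-precision leaves, while the inherited params property now walks the wrapped FP32 optimizer's param_groups. A rough standalone sketch of that split follows; the flattening is a simplification for illustration, not the exact fairseq code.

import torch

fp16_params = [torch.zeros(3, dtype=torch.float16, requires_grad=True),
               torch.zeros(5, dtype=torch.float16, requires_grad=True)]
# simplified fp32 master copy: one flat buffer holding all parameter values
fp32_flat = torch.cat([p.detach().float().view(-1) for p in fp16_params]).requires_grad_()
fp32_optimizer = torch.optim.SGD([fp32_flat], lr=0.1)

print(len(fp16_params))                               # 2 half-precision leaves
print(len(fp32_optimizer.param_groups[0]['params']))  # 1 flat fp32 master parameter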
@@ -12,7 +12,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('nag')
 class FairseqNAG(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = NAG(params, **self.optimizer_config)
 
     @staticmethod
......
@@ -11,7 +11,7 @@ from . import FairseqOptimizer, register_optimizer
 @register_optimizer('sgd')
 class SGD(FairseqOptimizer):
     def __init__(self, args, params):
-        super().__init__(args, params)
+        super().__init__(args)
         self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
 
     @staticmethod
......
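All of the optimizer subclasses above follow the same new convention: the base constructor takes only args, and the raw parameter list goes straight to the wrapped torch optimizer. A hedged sketch of a custom optimizer written against that convention follows; the 'rmsprop' registration and its flag are hypothetical and not part of this diff.

import torch
from fairseq.optim import FairseqOptimizer, register_optimizer

@register_optimizer('rmsprop')
class FairseqRMSprop(FairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)  # note: no params argument any more
        self._optimizer = torch.optim.RMSprop(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        parser.add_argument('--rmsprop-alpha', type=float, default=0.99,
                            help='smoothing constant')

    @property
    def optimizer_config(self):
        return {'lr': self.args.lr[0], 'alpha': self.args.rmsprop_alpha}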
@@ -36,13 +36,14 @@ class Trainer(object):
         self.task = task
 
         # copy model and criterion to current device
-        self.criterion = criterion
+        self._criterion = criterion
         self._model = model
         self.cuda = torch.cuda.is_available() and not args.cpu
         if args.fp16:
+            self._criterion = self._criterion.half()
             self._model = self._model.half()
         if self.cuda:
-            self.criterion = self.criterion.cuda()
+            self._criterion = self._criterion.cuda()
             self._model = self._model.cuda()
 
         self._dummy_batch = dummy_batch
@@ -53,6 +54,7 @@ class Trainer(object):
         self._optim_history = None
         self._optimizer = None
         self._prev_grad_norm = None
+        self._wrapped_criterion = None
         self._wrapped_model = None
 
         self.init_meters(args)
@@ -75,6 +77,21 @@ class Trainer(object):
         self.meters['wall'] = TimeMeter()       # wall time in seconds
         self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
 
+    @property
+    def criterion(self):
+        if self._wrapped_criterion is None:
+            if (
+                utils.has_parameters(self._criterion)
+                and self.args.distributed_world_size > 1
+                and not self.args.use_bmuf
+            ):
+                self._wrapped_criterion = models.DistributedFairseqModel(
+                    self.args, self._criterion
+                )
+            else:
+                self._wrapped_criterion = self._criterion
+        return self._wrapped_criterion
+
     @property
     def model(self):
         if self._wrapped_model is None:
@@ -99,7 +116,13 @@ class Trainer(object):
         return self._lr_scheduler
 
     def _build_optimizer(self):
-        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
+        params = list(
+            filter(
+                lambda p: p.requires_grad,
+                chain(self.model.parameters(), self.criterion.parameters()),
+            )
+        )
         if self.args.fp16:
             if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                 print('| WARNING: your device does NOT support faster training with --fp16, '
@@ -114,7 +137,7 @@ class Trainer(object):
             self._optimizer = optim.build_optimizer(self.args, params)
 
         if self.args.use_bmuf:
-            self._optimizer = optim.FairseqBMUF(self.args, params, self._optimizer)
+            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
 
         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
@@ -126,7 +149,7 @@ class Trainer(object):
         if distributed_utils.is_master(self.args):  # only save one checkpoint
             extra_state['train_meters'] = self.meters
             checkpoint_utils.save_state(
-                filename, self.args, self.get_model().state_dict(), self.criterion,
+                filename, self.args, self.get_model().state_dict(), self.get_criterion(),
                 self.optimizer, self.lr_scheduler, self.get_num_updates(),
                 self._optim_history, extra_state,
             )
@@ -148,6 +171,8 @@ class Trainer(object):
             # load model parameters
             try:
                 self.get_model().load_state_dict(state['model'], strict=True)
+                if utils.has_parameters(self.get_criterion()):
+                    self.get_criterion().load_state_dict(state['criterion'], strict=True)
             except Exception:
                 raise Exception(
                     'Cannot load model parameters from checkpoint {}; '
@@ -164,7 +189,7 @@ class Trainer(object):
             # only reload optimizer and lr_scheduler if they match
             last_optim = self._optim_history[-1]
-            assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
+            assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
                 'Criterion does not match; please reset the optimizer (--reset-optimizer).'
             assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
                 'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
@@ -322,9 +347,9 @@ class Trainer(object):
         # aggregate logging outputs and sample sizes
         logging_output = self.task.aggregate_logging_outputs(
-            logging_outputs, self.criterion
+            logging_outputs, self.get_criterion()
         )
-        sample_size = self.task.grad_denom(sample_sizes, self.criterion)
+        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
 
         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((
@@ -424,10 +449,10 @@ class Trainer(object):
         # aggregate logging outputs and sample sizes
         logging_output = self.task.aggregate_logging_outputs(
-            logging_output, self.criterion
+            logging_output, self.get_criterion()
         )
         sample_size = self.task.grad_denom(
-            sample_size, self.criterion
+            sample_size, self.get_criterion()
         )
 
         # update meters for validation
@@ -477,6 +502,10 @@ class Trainer(object):
         """Get the (non-wrapped) model instance."""
         return self._model
 
+    def get_criterion(self):
+        """Get the (non-wrapped) criterion instance."""
+        return self._criterion
+
     def get_meter(self, name):
         """Get a specific meter by name."""
         if name not in self.meters:
......
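The net effect in the trainer is that a single optimizer is built over the trainable parameters of both the model and the criterion, and everything downstream (checkpointing, BMUF, DDP wrapping) goes through criterion/get_criterion(). Below is a standalone sketch of that parameter chaining, using a dummy module in place of a real fairseq criterion.

from itertools import chain
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
criterion = nn.Module()                                  # dummy parameterized criterion
criterion.transitions = nn.Parameter(torch.zeros(8, 8))

params = list(
    filter(
        lambda p: p.requires_grad,
        chain(model.parameters(), criterion.parameters()),
    )
)
optimizer = torch.optim.SGD(params, lr=0.1)
print(len(optimizer.param_groups[0]['params']))  # 3: weight, bias, transitions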
@@ -351,3 +351,11 @@ def eval(model):
     model.eval()
     yield
     model.train(is_training)
+
+
+def has_parameters(module):
+    try:
+        next(module.parameters())
+        return True
+    except StopIteration:
+        return False
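A quick usage check for the new helper, assuming a fairseq install that includes this change; nn.Linear and nn.ReLU stand in for a parameterized and a parameterless criterion respectively.

import torch.nn as nn
from fairseq import utils

print(utils.has_parameters(nn.Linear(2, 2)))  # True  -> state is checkpointed and synced
print(utils.has_parameters(nn.ReLU()))        # False -> previous behavior, nothing extra saved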