Commit 10c2e8c1 authored by Wei Ho, committed by Facebook Github Bot

Apply Black auto-formatting

Reviewed By: sujitoc

Differential Revision: D18738392

fbshipit-source-id: b7b7b75ef97946786c463c1887ef9a8676f030e6
parent cfc4b303
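
Black rewrites formatting only, leaving behavior unchanged: string literals are normalized to double quotes, long expressions are wrapped at the default 88-column limit with one operand or argument per line, and slices whose bounds are expressions gain surrounding spaces. Below is a minimal, self-contained Python sketch mirroring one hunk from this diff; the stand-in variable values and the `black fairseq/` invocation are assumptions for illustration, not details recorded in the commit.

# Assumed invocation: run Black with its defaults over the package, e.g.
#   black fairseq/
import collections

# Stand-in values (hypothetical) so the snippet runs on its own.
epoch, end_of_epoch = 3, True
no_epoch_checkpoints, save_interval = False, 1

checkpoint_conds = collections.OrderedDict()
# Pre-Black style, as on the original side of the hunk:
#   checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
#       end_of_epoch and not no_epoch_checkpoints and
#       epoch % save_interval == 0
#   )
# Black output, as on the reformatted side: double quotes and one
# boolean operand per line.
checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
    end_of_epoch
    and not no_epoch_checkpoints
    and epoch % save_interval == 0
)
print(checkpoint_conds)  # OrderedDict([('checkpoint3.pt', True)])
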
@@ -3,25 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Union
import collections
import logging
import os
import re
import traceback
import shutil
import traceback
from collections import OrderedDict
from typing import Union
import torch
from fairseq.models import FairseqDecoder, FairseqEncoder
from torch.serialization import default_restore_location
from fairseq.models import FairseqEncoder, FairseqDecoder
def save_checkpoint(args, trainer, epoch_itr, val_loss):
from fairseq import distributed_utils, meters
prev_best = getattr(save_checkpoint, 'best', val_loss)
prev_best = getattr(save_checkpoint, "best", val_loss)
if val_loss is not None:
best_function = max if args.maximize_best_checkpoint_metric else min
save_checkpoint.best = best_function(val_loss, prev_best)
@@ -40,56 +39,59 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
updates = trainer.get_num_updates()
checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
end_of_epoch
and not args.no_epoch_checkpoints
and epoch % args.save_interval == 0
)
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
not end_of_epoch
and args.save_interval_updates > 0
and updates % args.save_interval_updates == 0
)
checkpoint_conds['checkpoint_best.pt'] = (
val_loss is not None and
(not hasattr(save_checkpoint, 'best') or is_better(val_loss, save_checkpoint.best))
checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
not hasattr(save_checkpoint, "best")
or is_better(val_loss, save_checkpoint.best)
)
checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints
checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints
extra_state = {
'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss,
}
if hasattr(save_checkpoint, 'best'):
extra_state.update({'best': save_checkpoint.best})
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
checkpoints = [
os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
]
if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state)
for cp in checkpoints[1:]:
try:
from fairseq.fb_pathmgr import fb_pathmgr
fb_pathmgr.copy(checkpoints[0], cp, True)
except (ModuleNotFoundError, ImportError):
shutil.copyfile(checkpoints[0], cp)
write_timer.stop()
print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(
checkpoints[0], epoch, updates, write_timer.sum))
print(
"| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)".format(
checkpoints[0], epoch, updates, write_timer.sum
)
)
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt',
args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
)
for old_chk in checkpoints[args.keep_interval_updates:]:
for old_chk in checkpoints[args.keep_interval_updates :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
args.save_dir, pattern=r'checkpoint(\d+)\.pt',
)
for old_chk in checkpoints[args.keep_last_epochs:]:
checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
for old_chk in checkpoints[args.keep_last_epochs :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
@@ -105,8 +107,8 @@ def load_checkpoint(args, trainer, **passthrough_args):
if args.distributed_rank == 0:
os.makedirs(args.save_dir, exist_ok=True)
if args.restore_file == 'checkpoint_last.pt':
checkpoint_path = os.path.join(args.save_dir, 'checkpoint_last.pt')
if args.restore_file == "checkpoint_last.pt":
checkpoint_path = os.path.join(args.save_dir, "checkpoint_last.pt")
else:
checkpoint_path = args.restore_file
@@ -120,26 +122,22 @@ def load_checkpoint(args, trainer, **passthrough_args):
if (
extra_state is not None
and 'best' in extra_state
and "best" in extra_state
and not args.reset_optimizer
and not args.reset_meters
):
save_checkpoint.best = extra_state['best']
save_checkpoint.best = extra_state["best"]
if extra_state is not None and not args.reset_dataloader:
# restore iterator from checkpoint
itr_state = extra_state['train_iterator']
itr_state = extra_state["train_iterator"]
epoch_itr = trainer.get_train_iterator(
epoch=itr_state['epoch'],
load_dataset=True,
**passthrough_args
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
)
epoch_itr.load_state_dict(itr_state)
else:
epoch_itr = trainer.get_train_iterator(
epoch=0,
load_dataset=True,
**passthrough_args
epoch=0, load_dataset=True, **passthrough_args
)
trainer.lr_step(epoch_itr.epoch)
@@ -151,16 +149,17 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
"""Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
try:
from fairseq.fb_pathmgr import fb_pathmgr
with fb_pathmgr.open(path, "rb") as f:
state = torch.load(
f, map_location=lambda s, l: default_restore_location(s, 'cpu'),
f, map_location=lambda s, l: default_restore_location(s, "cpu")
)
except (ModuleNotFoundError, ImportError):
# if path manager not found, continue with local file.
state = torch.load(
path, map_location=lambda s, l: default_restore_location(s, 'cpu'),
path, map_location=lambda s, l: default_restore_location(s, "cpu")
)
args = state['args']
args = state["args"]
if arg_overrides is not None:
for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val)
@@ -187,21 +186,21 @@ def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):
ensemble = []
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
raise IOError("Model file not found: {}".format(filename))
state = load_checkpoint_to_cpu(filename, arg_overrides)
args = state['args']
args = state["args"]
if task is None:
task = tasks.setup_task(args)
# build model for ensemble
model = task.build_model(args)
model.load_state_dict(state['model'], strict=True, args=args)
model.load_state_dict(state["model"], strict=True, args=args)
ensemble.append(model)
return ensemble, args, task
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
@@ -244,34 +243,46 @@ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
def save_state(
filename, args, model_state_dict, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None,
filename,
args,
model_state_dict,
criterion,
optimizer,
lr_scheduler,
num_updates,
optim_history=None,
extra_state=None,
):
from fairseq import utils
if optim_history is None:
optim_history = []
if extra_state is None:
extra_state = {}
state_dict = {
'args': args,
'model': model_state_dict if model_state_dict else {},
'optimizer_history': optim_history + [
"args": args,
"model": model_state_dict if model_state_dict else {},
"optimizer_history": optim_history
+ [
{
'criterion_name': criterion.__class__.__name__,
'optimizer_name': optimizer.__class__.__name__,
'lr_scheduler_state': lr_scheduler.state_dict(),
'num_updates': num_updates,
"criterion_name": criterion.__class__.__name__,
"optimizer_name": optimizer.__class__.__name__,
"lr_scheduler_state": lr_scheduler.state_dict(),
"num_updates": num_updates,
}
],
'extra_state': extra_state,
"extra_state": extra_state,
}
if utils.has_parameters(criterion):
state_dict['criterion'] = criterion.state_dict()
state_dict["criterion"] = criterion.state_dict()
if not args.no_save_optimizer_state:
state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict())
state_dict["last_optimizer_state"] = convert_state_dict_type(
optimizer.state_dict()
)
try:
from fairseq.fb_pathmgr import fb_pathmgr
with fb_pathmgr.open(filename, "wb") as f:
torch_persistent_save(state_dict, f)
except (ModuleNotFoundError, ImportError):
@@ -284,65 +295,64 @@ def _upgrade_state_dict(state):
from fairseq import models, registry, tasks
# add optimizer_history
if 'optimizer_history' not in state:
state['optimizer_history'] = [
{
'criterion_name': 'CrossEntropyCriterion',
'best_loss': state['best_loss'],
},
if "optimizer_history" not in state:
state["optimizer_history"] = [
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
]
state['last_optimizer_state'] = state['optimizer']
del state['optimizer']
del state['best_loss']
state["last_optimizer_state"] = state["optimizer"]
del state["optimizer"]
del state["best_loss"]
# move extra_state into sub-dictionary
if 'epoch' in state and 'extra_state' not in state:
state['extra_state'] = {
'epoch': state['epoch'],
'batch_offset': state['batch_offset'],
'val_loss': state['val_loss'],
if "epoch" in state and "extra_state" not in state:
state["extra_state"] = {
"epoch": state["epoch"],
"batch_offset": state["batch_offset"],
"val_loss": state["val_loss"],
}
del state['epoch']
del state['batch_offset']
del state['val_loss']
del state["epoch"]
del state["batch_offset"]
del state["val_loss"]
# reduce optimizer history's memory usage (only keep the last state)
if 'optimizer' in state['optimizer_history'][-1]:
state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
for optim_hist in state['optimizer_history']:
del optim_hist['optimizer']
if "optimizer" in state["optimizer_history"][-1]:
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
for optim_hist in state["optimizer_history"]:
del optim_hist["optimizer"]
# record the optimizer class name
if 'optimizer_name' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
if "optimizer_name" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
# move best_loss into lr_scheduler_state
if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['lr_scheduler_state'] = {
'best': state['optimizer_history'][-1]['best_loss'],
if "lr_scheduler_state" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["lr_scheduler_state"] = {
"best": state["optimizer_history"][-1]["best_loss"]
}
del state['optimizer_history'][-1]['best_loss']
del state["optimizer_history"][-1]["best_loss"]
# keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0
if "num_updates" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["num_updates"] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
if hasattr(state["args"], "max_positions") and not hasattr(
state["args"], "max_source_positions"
):
state["args"].max_source_positions = state["args"].max_positions
state["args"].max_target_positions = state["args"].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
if "train_iterator" not in state["extra_state"]:
state["extra_state"]["train_iterator"] = {
"epoch": state["extra_state"]["epoch"],
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
}
# default to translation task
if not hasattr(state['args'], 'task'):
state['args'].task = 'translation'
if not hasattr(state["args"], "task"):
state["args"].task = "translation"
# set any missing default values in the task, model or other registries
registry.set_defaults(state['args'], tasks.TASK_REGISTRY[state['args'].task])
registry.set_defaults(state['args'], models.ARCH_MODEL_REGISTRY[state['args'].arch])
registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch])
for registry_name, REGISTRY in registry.REGISTRIES.items():
choice = getattr(state['args'], registry_name, None)
choice = getattr(state["args"], registry_name, None)
if choice is not None:
cls = REGISTRY['registry'][choice]
registry.set_defaults(state['args'], cls)
cls = REGISTRY["registry"][choice]
registry.set_defaults(state["args"], cls)
return state
@@ -362,26 +372,31 @@ def prune_state_dict(state_dict, args):
# args should not be none, but don't crash if it is.
return state_dict
encoder_layers_to_keep = args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
decoder_layers_to_keep = args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
encoder_layers_to_keep = (
args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
)
decoder_layers_to_keep = (
args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
)
if not encoder_layers_to_keep and not decoder_layers_to_keep:
return state_dict
# apply pruning
print("| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop")
print(
"| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
)
def create_pruning_pass(layers_to_keep, layer_name):
keep_layers = sorted([int(layer_string) for layer_string in layers_to_keep.split(",")])
keep_layers = sorted(
[int(layer_string) for layer_string in layers_to_keep.split(",")]
)
mapping_dict = {}
for i in range(len(keep_layers)):
mapping_dict[str(keep_layers[i])] = str(i)
regex = re.compile("^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
return {
"substitution_regex": regex,
"mapping_dict": mapping_dict
}
return {"substitution_regex": regex, "mapping_dict": mapping_dict}
pruning_passes = []
if encoder_layers_to_keep:
@@ -402,10 +417,18 @@ def prune_state_dict(state_dict, args):
original_layer_number = match.group(1)
# figure out which mapping dict to replace from
for pruning_pass in pruning_passes:
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
"substitution_regex"
].search(layer_name):
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
substitution_match = pruning_pass["substitution_regex"].search(layer_name)
new_state_key = layer_name[:substitution_match.start(1)] + new_layer_number + layer_name[substitution_match.end(1):]
substitution_match = pruning_pass["substitution_regex"].search(
layer_name
)
new_state_key = (
layer_name[: substitution_match.start(1)]
+ new_layer_number
+ layer_name[substitution_match.end(1) :]
)
new_state_dict[new_state_key] = state_dict[layer_name]
# Since layers are now pruned, *_layers_to_keep are no longer needed.
@@ -428,7 +451,7 @@ def load_pretrained_component_from_model(
`checkpoint` file.
"""
if not os.path.exists(checkpoint):
raise IOError('Model file not found: {}'.format(checkpoint))
raise IOError("Model file not found: {}".format(checkpoint))
state = load_checkpoint_to_cpu(checkpoint)
if isinstance(component, FairseqEncoder):
component_type = "encoder"
@@ -443,7 +466,7 @@ def load_pretrained_component_from_model(
for key in state["model"].keys():
if key.startswith(component_type):
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
component_subkey = key[len(component_type) + 1:]
component_subkey = key[len(component_type) + 1 :]
component_state_dict[component_subkey] = state["model"][key]
component.load_state_dict(component_state_dict, strict=True)
return component
@@ -452,12 +475,12 @@ def load_pretrained_component_from_model(
def verify_checkpoint_directory(save_dir: str) -> None:
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
temp_file_path = os.path.join(save_dir, 'dummy')
temp_file_path = os.path.join(save_dir, "dummy")
try:
with open(temp_file_path, 'w'):
with open(temp_file_path, "w"):
pass
except OSError as e:
print('| Unable to access checkpoint save directory: {}'.format(save_dir))
print("| Unable to access checkpoint save directory: {}".format(save_dir))
raise e
else:
os.remove(temp_file_path)
@@ -7,15 +7,14 @@
Train a network across multiple GPUs.
"""
from collections import OrderedDict
import contextlib
from itertools import chain
import math
import os
import sys
from collections import OrderedDict
from itertools import chain
import torch
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler
@@ -66,21 +65,21 @@ class Trainer(object):
def init_meters(self, args):
self.meters = OrderedDict()
self.meters['train_loss'] = AverageMeter()
self.meters['train_nll_loss'] = AverageMeter()
self.meters['valid_loss'] = AverageMeter()
self.meters['valid_nll_loss'] = AverageMeter()
self.meters['wps'] = TimeMeter() # words per second
self.meters['ups'] = TimeMeter() # updates per second
self.meters['wpb'] = AverageMeter() # words per batch
self.meters['bsz'] = AverageMeter() # sentences per batch
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
self.meters["train_loss"] = AverageMeter()
self.meters["train_nll_loss"] = AverageMeter()
self.meters["valid_loss"] = AverageMeter()
self.meters["valid_nll_loss"] = AverageMeter()
self.meters["wps"] = TimeMeter() # words per second
self.meters["ups"] = TimeMeter() # updates per second
self.meters["wpb"] = AverageMeter() # words per batch
self.meters["bsz"] = AverageMeter() # sentences per batch
self.meters["gnorm"] = AverageMeter() # gradient norm
self.meters["clip"] = AverageMeter() # % of updates clipped
self.meters["oom"] = AverageMeter() # out of memory
if args.fp16:
self.meters['loss_scale'] = AverageMeter() # dynamic loss scale
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
self.meters["loss_scale"] = AverageMeter() # dynamic loss scale
self.meters["wall"] = TimeMeter() # wall time in seconds
self.meters["train_wall"] = StopwatchMeter() # train wall time in seconds
@property
def criterion(self):
@@ -102,7 +101,7 @@ class Trainer(object):
if self._wrapped_model is None:
if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
self._wrapped_model = models.DistributedFairseqModel(
self.args, self._model,
self.args, self._model
)
else:
self._wrapped_model = self._model
@@ -130,15 +129,19 @@ class Trainer(object):
if self.args.fp16:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16, '
'please switch to FP32 which is likely to be faster')
print(
"| WARNING: your device does NOT support faster training with --fp16, "
"please switch to FP32 which is likely to be faster"
)
if self.args.memory_efficient_fp16:
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
self.args, params
)
else:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
else:
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
print("| NOTICE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.args, params)
if self.args.use_bmuf:
@@ -152,11 +155,17 @@ class Trainer(object):
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters
extra_state["train_meters"] = self.meters
checkpoint_utils.save_state(
filename, self.args, self.get_model().state_dict(), self.get_criterion(),
self.optimizer, self.lr_scheduler, self.get_num_updates(),
self._optim_history, extra_state,
filename,
self.args,
self.get_model().state_dict(),
self.get_criterion(),
self.optimizer,
self.lr_scheduler,
self.get_num_updates(),
self._optim_history,
extra_state,
)
def load_checkpoint(
@@ -172,6 +181,7 @@ class Trainer(object):
try:
from fairseq.fb_pathmgr import fb_pathmgr
bexists = fb_pathmgr.isfile(filename)
except (ModuleNotFoundError, ImportError):
bexists = os.path.exists(filename)
@@ -181,18 +191,22 @@ class Trainer(object):
# load model parameters
try:
self.get_model().load_state_dict(state['model'], strict=True, args=self.args)
self.get_model().load_state_dict(
state["model"], strict=True, args=self.args
)
if utils.has_parameters(self.get_criterion()):
self.get_criterion().load_state_dict(state['criterion'], strict=True)
self.get_criterion().load_state_dict(
state["criterion"], strict=True
)
except Exception:
raise Exception(
'Cannot load model parameters from checkpoint {}; '
'please ensure that the architectures match.'.format(filename)
"Cannot load model parameters from checkpoint {}; "
"please ensure that the architectures match.".format(filename)
)
extra_state = state['extra_state']
self._optim_history = state['optimizer_history']
last_optim_state = state.get('last_optimizer_state', None)
extra_state = state["extra_state"]
self._optim_history = state["optimizer_history"]
last_optim_state = state.get("last_optimizer_state", None)
if last_optim_state is not None and not reset_optimizer:
# rebuild optimizer after loading model, since params may have changed
@@ -200,41 +214,53 @@ class Trainer(object):
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
assert last_optim['criterion_name'] == self.get_criterion().__class__.__name__, \
'Criterion does not match; please reset the optimizer (--reset-optimizer).'
assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
'Optimizer does not match; please reset the optimizer (--reset-optimizer).'
assert (
last_optim["criterion_name"] == self.get_criterion().__class__.__name__
), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
assert (
last_optim["optimizer_name"] == self.optimizer.__class__.__name__
), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."
if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self.set_num_updates(last_optim['num_updates'])
self.set_num_updates(last_optim["num_updates"])
if extra_state is not None:
epoch = extra_state['train_iterator']['epoch']
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
filename, epoch, self.get_num_updates()))
epoch = extra_state["train_iterator"]["epoch"]
print(
"| loaded checkpoint {} (epoch {} @ {} updates)".format(
filename, epoch, self.get_num_updates()
)
)
self.lr_step(epoch)
if 'train_meters' in extra_state and not reset_meters:
self.meters.update(extra_state['train_meters'])
del extra_state['train_meters']
if "train_meters" in extra_state and not reset_meters:
self.meters.update(extra_state["train_meters"])
del extra_state["train_meters"]
# reset TimeMeters, since their start times don't make sense anymore
for meter in self.meters.values():
if isinstance(meter, TimeMeter):
meter.reset()
else:
print('| no existing checkpoint found {}'.format(filename))
print("| no existing checkpoint found {}".format(filename))
return extra_state
def get_train_iterator(self, epoch, combine=True, load_dataset=True, data_selector=None, shard_batch_itr=True):
def get_train_iterator(
self,
epoch,
combine=True,
load_dataset=True,
data_selector=None,
shard_batch_itr=True,
):
"""Return an EpochBatchIterator over the training set for a given epoch."""
if load_dataset:
print('| loading train data for epoch {}'.format(epoch))
print("| loading train data for epoch {}".format(epoch))
self.task.load_dataset(
self.args.train_subset,
epoch=epoch,
@@ -246,8 +272,7 @@ class Trainer(object):
max_tokens=self.args.max_tokens,
max_sentences=self.args.max_sentences,
max_positions=utils.resolve_max_positions(
self.task.max_positions(),
self.model.max_positions(),
self.task.max_positions(), self.model.max_positions()
),
ignore_invalid_inputs=True,
required_batch_size_multiple=self.args.required_batch_size_multiple,
@@ -269,7 +294,7 @@ class Trainer(object):
self.zero_grad()
if not dummy_batch:
self.meters['train_wall'].start()
self.meters["train_wall"].start()
# forward and backward pass
logging_outputs, sample_sizes, ooms = [], [], 0
@@ -291,7 +316,7 @@ class Trainer(object):
"""
if (
self.args.distributed_world_size > 1
and hasattr(self.model, 'no_sync')
and hasattr(self.model, "no_sync")
and i < len(samples) - 1
):
return self.model.no_sync()
@@ -302,8 +327,7 @@ class Trainer(object):
with maybe_no_sync():
# forward and backward
loss, sample_size, logging_output = self.task.train_step(
sample, self.model, self.criterion, self.optimizer,
ignore_grad
sample, self.model, self.criterion, self.optimizer, ignore_grad
)
if not ignore_grad:
@@ -312,17 +336,21 @@ class Trainer(object):
if self.fast_stat_sync:
self._all_reduce_list[0] += sample_size
self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
self._all_reduce_list[2] += logging_output.get('loss', 0.0)
self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
self._all_reduce_list[1] += logging_output.get(
"nsentences", 0.0
)
self._all_reduce_list[2] += logging_output.get("loss", 0.0)
self._all_reduce_list[3] += logging_output.get("nll_loss", 0.0)
self._all_reduce_list[4] += logging_output.get("ntokens", 0.0)
except RuntimeError as e:
if 'out of memory' in str(e):
if "out of memory" in str(e):
self._log_oom(e)
if raise_oom:
raise e
print("| WARNING: attempting to recover from OOM in forward/backward pass",
file=sys.stderr)
print(
"| WARNING: attempting to recover from OOM in forward/backward pass",
file=sys.stderr,
)
ooms += 1
self.zero_grad()
else:
@@ -331,7 +359,6 @@ class Trainer(object):
if self.fast_stat_sync:
self._all_reduce_list[5] += ooms
if ooms > 0 and self._oom_batch is not None:
self.handle_ooms(ooms)
@@ -347,39 +374,36 @@ class Trainer(object):
# Normalize loss and nll_loss by "sample_size"
# and convert to log base 2
all_reduce_list_tensor[2:4].div_(
(
all_reduce_list_tensor[0:1] *
torch.log(torch.cuda.DoubleTensor([2]))
)
(all_reduce_list_tensor[0:1] * torch.log(torch.cuda.DoubleTensor([2])))
)
self._all_reduce_list = all_reduce_list_tensor.tolist()
logging_output = {}
[
sample_size,
logging_output['nsentences'],
logging_output['loss'],
logging_output['nll_loss'],
logging_output['ntokens'],
logging_output["nsentences"],
logging_output["loss"],
logging_output["nll_loss"],
logging_output["ntokens"],
ooms,
] = self._all_reduce_list
elif self._sync_stats():
logging_outputs, sample_sizes, ooms, prev_norms = \
zip(*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
))
logging_outputs, sample_sizes, ooms, prev_norms = zip(
*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms, self._prev_grad_norm]
)
)
logging_outputs = list(chain.from_iterable(logging_outputs))
sample_sizes = list(chain.from_iterable(sample_sizes))
ooms = sum(ooms)
if not self.args.use_bmuf:
assert (
all(norm == prev_norms[0] for norm in prev_norms)
or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
), 'Fatal error: gradients are inconsistent between workers'
assert all(norm == prev_norms[0] for norm in prev_norms) or all(
math.isnan(norm) or math.isinf(norm) for norm in prev_norms
), "Fatal error: gradients are inconsistent between workers"
self.meters['oom'].update(ooms, len(samples))
self.meters["oom"].update(ooms, len(samples))
if ooms == self.args.distributed_world_size * len(samples):
print('| WARNING: OOM in all workers, skipping update')
print("| WARNING: OOM in all workers, skipping update")
self.zero_grad()
return None
@@ -390,16 +414,20 @@ class Trainer(object):
)
sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
if not all(k in logging_output for k in ['ntokens', 'nsentences']):
raise Exception((
'Please update the {}.aggregate_logging_outputs() method to '
'return ntokens and nsentences'
).format(self.task.__class__.__name__))
if not all(k in logging_output for k in ["ntokens", "nsentences"]):
raise Exception(
(
"Please update the {}.aggregate_logging_outputs() method to "
"return ntokens and nsentences"
).format(self.task.__class__.__name__)
)
try:
# normalize grads by sample size
if sample_size > 0:
self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
self.optimizer.multiply_grads(
self.args.distributed_world_size / float(sample_size)
)
# clip grads
grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
@@ -413,47 +441,57 @@ class Trainer(object):
self.task.update_step(self._num_updates)
# update meters
ntokens = logging_output.get('ntokens', 0)
nsentences = logging_output.get('nsentences', 0)
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(
1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
ntokens = logging_output.get("ntokens", 0)
nsentences = logging_output.get("nsentences", 0)
self.meters["wps"].update(ntokens)
self.meters["ups"].update(1.0)
self.meters["wpb"].update(ntokens)
self.meters["bsz"].update(nsentences)
self.meters["gnorm"].update(grad_norm)
self.meters["clip"].update(
1.0
if grad_norm > self.args.clip_norm and self.args.clip_norm > 0
else 0.0
)
self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
if 'train_acc' in self.meters:
self.meters['train_acc'].update(
logging_output.get('acc', 0), sample_size)
self.meters["train_loss"].update(logging_output.get("loss", 0), sample_size)
if "train_acc" in self.meters:
self.meters["train_acc"].update(
logging_output.get("acc", 0), sample_size
)
if 'nll_loss' in logging_output:
self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
if "nll_loss" in logging_output:
self.meters["train_nll_loss"].update(
logging_output.get("nll_loss", 0), ntokens
)
# clear CUDA cache to reduce memory fragmentation
if (self.args.empty_cache_freq > 0 and
((self.get_num_updates() + self.args.empty_cache_freq - 1) %
self.args.empty_cache_freq) == 0 and
torch.cuda.is_available() and
not self.args.cpu):
if (
self.args.empty_cache_freq > 0
and (
(self.get_num_updates() + self.args.empty_cache_freq - 1)
% self.args.empty_cache_freq
)
== 0
and torch.cuda.is_available()
and not self.args.cpu
):
torch.cuda.empty_cache()
except OverflowError as e:
print('| WARNING: overflow detected, ' + str(e))
print("| WARNING: overflow detected, " + str(e))
self.zero_grad()
logging_output = None
except RuntimeError as e:
if 'out of memory' in str(e):
if "out of memory" in str(e):
self._log_oom(e)
print('| ERROR: OOM during optimization, irrecoverable')
print("| ERROR: OOM during optimization, irrecoverable")
raise e
if self.args.fp16:
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
self.meters["loss_scale"].reset()
self.meters["loss_scale"].update(self.optimizer.scaler.loss_scale)
self.clear_buffered_stats()
self.meters['train_wall'].stop()
self.meters["train_wall"].stop()
return logging_output
@@ -475,10 +513,12 @@ class Trainer(object):
sample, self.model, self.criterion
)
except RuntimeError as e:
if 'out of memory' in str(e):
if "out of memory" in str(e):
self._log_oom(e)
if not raise_oom:
print('| WARNING: ran out of memory in validation step, retrying batch')
print(
"| WARNING: ran out of memory in validation step, retrying batch"
)
for p in self.model.parameters():
if p.grad is not None:
p.grad = None # free some memory
@@ -492,9 +532,9 @@ class Trainer(object):
# gather logging outputs from all replicas
if self.args.distributed_world_size > 1:
logging_output, sample_size = zip(*distributed_utils.all_gather_list(
[logging_output, sample_size],
))
logging_output, sample_size = zip(
*distributed_utils.all_gather_list([logging_output, sample_size])
)
logging_output = list(logging_output)
sample_size = list(sample_size)
else:
@@ -505,19 +545,18 @@ class Trainer(object):
logging_output = self.task.aggregate_logging_outputs(
logging_output, self.get_criterion()
)
sample_size = self.task.grad_denom(
sample_size, self.get_criterion()
)
sample_size = self.task.grad_denom(sample_size, self.get_criterion())
# update meters for validation
ntokens = logging_output.get('ntokens', 0)
self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
if 'valid_acc' in self.meters:
self.meters['valid_acc'].update(
logging_output.get('acc', 0), sample_size)
if 'nll_loss' in logging_output:
self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
ntokens = logging_output.get("ntokens", 0)
self.meters["valid_loss"].update(logging_output.get("loss", 0), sample_size)
if "valid_acc" in self.meters:
self.meters["valid_acc"].update(logging_output.get("acc", 0), sample_size)
if "nll_loss" in logging_output:
self.meters["valid_nll_loss"].update(
logging_output.get("nll_loss", 0), ntokens
)
return logging_output
@@ -604,19 +643,16 @@ class Trainer(object):
torch.cuda.manual_seed(seed)
def _sync_stats(self):
return (
self.args.distributed_world_size > 1 and
(
(not self.args.use_bmuf) or
(
self.args.use_bmuf
and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
)
return self.args.distributed_world_size > 1 and (
(not self.args.use_bmuf)
or (
self.args.use_bmuf
and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
)
)
def _log_oom(self, exc):
msg = '| OOM: Ran out of memory with exception: {}'.format(exc)
msg = "| OOM: Ran out of memory with exception: {}".format(exc)
# TODO: print should really go to logger, this print goes
# to stderr, which is buffered, which in many cases is not
# printed out if another exception happens.
@@ -624,6 +660,5 @@ class Trainer(object):
print(msg, file=sys.stderr)
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
for device_idx in range(torch.cuda.device_count()):
print(torch.cuda.memory_summary(device=device_idx),
file=sys.stderr)
print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
sys.stderr.flush()