"docs/vscode:/vscode.git/clone" did not exist on "f3785a34041977f4d76fe29977998fbd2ee59df8"
Commit 10c2e8c1 authored by Wei Ho, committed by Facebook Github Bot

Apply Black auto-formatting

Reviewed By: sujitoc

Differential Revision: D18738392

fbshipit-source-id: b7b7b75ef97946786c463c1887ef9a8676f030e6
parent cfc4b303
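
This commit is purely mechanical: Black normalizes string literals to double quotes, wraps statements that exceed its default 88-character line length (exploding long call sites one argument per line), and puts spaces around slice colons whose bounds are expressions. As a quick illustration of that behavior, here is a small sketch that is not part of the commit; it assumes the black package is installed and uses its public format_str / FileMode API:

import black

# A deliberately over-long line written in the pre-Black style of this repo.
source = (
    "def report(path, epoch, updates, seconds):\n"
    "    print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format(path, epoch, updates, seconds))\n"
)

# format_str applies the same rewrites the black CLI would apply to a file on
# disk; FileMode() selects the defaults (88-character lines, double quotes).
print(black.format_str(source, mode=black.FileMode()))
# The quotes are normalized and, because the call no longer fits on one line,
# it is wrapped across several lines, the same shape seen in the hunks below.

Running the black CLI over the checkout (for example, "black fairseq/") applies the same rewrites file by file, which is how diffs like the one below are produced.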
fairseq/checkpoint_utils.py (Black-formatted result, shown per hunk):

@@ -3,25 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import collections
import logging
import os
import re
import shutil
import traceback
from collections import OrderedDict
from typing import Union

import torch
from fairseq.models import FairseqDecoder, FairseqEncoder
from torch.serialization import default_restore_location


def save_checkpoint(args, trainer, epoch_itr, val_loss):
    from fairseq import distributed_utils, meters

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if args.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

@@ -40,56 +39,59 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
    updates = trainer.get_num_updates()
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}.pt".format(epoch)] = (
        end_of_epoch
        and not args.no_epoch_checkpoints
        and epoch % args.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}.pt".format(epoch, updates)] = (
        not end_of_epoch
        and args.save_interval_updates > 0
        and updates % args.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best.pt"] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    checkpoint_conds["checkpoint_last.pt"] = not args.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            try:
                from fairseq.fb_pathmgr import fb_pathmgr

                fb_pathmgr.copy(checkpoints[0], cp, True)
            except (ModuleNotFoundError, ImportError):
                shutil.copyfile(checkpoints[0], cp)

        write_timer.stop()
        print(
            "| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, write_timer.sum
            )
        )

    if not end_of_epoch and args.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
        )
        for old_chk in checkpoints[args.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if args.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
        for old_chk in checkpoints[args.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

@@ -105,8 +107,8 @@ def load_checkpoint(args, trainer, **passthrough_args):
    if args.distributed_rank == 0:
        os.makedirs(args.save_dir, exist_ok=True)

    if args.restore_file == "checkpoint_last.pt":
        checkpoint_path = os.path.join(args.save_dir, "checkpoint_last.pt")
    else:
        checkpoint_path = args.restore_file

@@ -120,26 +122,22 @@ def load_checkpoint(args, trainer, **passthrough_args):
    if (
        extra_state is not None
        and "best" in extra_state
        and not args.reset_optimizer
        and not args.reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not args.reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=0, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

@@ -151,16 +149,17 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
    try:
        from fairseq.fb_pathmgr import fb_pathmgr

        with fb_pathmgr.open(path, "rb") as f:
            state = torch.load(
                f, map_location=lambda s, l: default_restore_location(s, "cpu")
            )
    except (ModuleNotFoundError, ImportError):
        # if path manager not found, continue with local file.
        state = torch.load(
            path, map_location=lambda s, l: default_restore_location(s, "cpu")
        )
    args = state["args"]
    if arg_overrides is not None:
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

@@ -187,21 +186,21 @@ def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):
    ensemble = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True, args=args)
        ensemble.append(model)
    return ensemble, args, task


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If

@@ -244,34 +243,46 @@ def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
def save_state(
    filename,
    args,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
):
    from fairseq import utils

    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        "args": args,
        "model": model_state_dict if model_state_dict else {},
        "optimizer_history": optim_history
        + [
            {
                "criterion_name": criterion.__class__.__name__,
                "optimizer_name": optimizer.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "extra_state": extra_state,
    }
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()
    if not args.no_save_optimizer_state:
        state_dict["last_optimizer_state"] = convert_state_dict_type(
            optimizer.state_dict()
        )
    try:
        from fairseq.fb_pathmgr import fb_pathmgr

        with fb_pathmgr.open(filename, "wb") as f:
            torch_persistent_save(state_dict, f)
    except (ModuleNotFoundError, ImportError):

@@ -284,65 +295,64 @@ def _upgrade_state_dict(state):
    from fairseq import models, registry, tasks

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state["args"], "max_positions") and not hasattr(
        state["args"], "max_source_positions"
    ):
        state["args"].max_source_positions = state["args"].max_positions
        state["args"].max_target_positions = state["args"].max_positions
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"]["epoch"],
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }
    # default to translation task
    if not hasattr(state["args"], "task"):
        state["args"].task = "translation"
    # set any missing default values in the task, model or other registries
    registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
    registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch])
    for registry_name, REGISTRY in registry.REGISTRIES.items():
        choice = getattr(state["args"], registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            registry.set_defaults(state["args"], cls)

    return state

@@ -362,26 +372,31 @@ def prune_state_dict(state_dict, args):
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = (
        args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
    )
    decoder_layers_to_keep = (
        args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
    )

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    print(
        "| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        keep_layers = sorted(
            [int(layer_string) for layer_string in layers_to_keep.split(",")]
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile("^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:

@@ -402,10 +417,18 @@ def prune_state_dict(state_dict, args):
        original_layer_number = match.group(1)
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1) :]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.

@@ -428,7 +451,7 @@ def load_pretrained_component_from_model(
    `checkpoint` file.
    """
    if not os.path.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"

@@ -443,7 +466,7 @@ def load_pretrained_component_from_model(
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=True)
    return component

@@ -452,12 +475,12 @@ def load_pretrained_component_from_model(
def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        print("| Unable to access checkpoint save directory: {}".format(save_dir))
        raise e
    else:
        os.remove(temp_file_path)
fairseq/trainer.py (Black-formatted result, shown per hunk):

@@ -7,15 +7,14 @@
Train a network across multiple GPUs.
"""

import contextlib
import math
import os
import sys
from collections import OrderedDict
from itertools import chain

import torch
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler

@@ -66,21 +65,21 @@ class Trainer(object):
    def init_meters(self, args):
        self.meters = OrderedDict()
        self.meters["train_loss"] = AverageMeter()
        self.meters["train_nll_loss"] = AverageMeter()
        self.meters["valid_loss"] = AverageMeter()
        self.meters["valid_nll_loss"] = AverageMeter()
        self.meters["wps"] = TimeMeter()  # words per second
        self.meters["ups"] = TimeMeter()  # updates per second
        self.meters["wpb"] = AverageMeter()  # words per batch
        self.meters["bsz"] = AverageMeter()  # sentences per batch
        self.meters["gnorm"] = AverageMeter()  # gradient norm
        self.meters["clip"] = AverageMeter()  # % of updates clipped
        self.meters["oom"] = AverageMeter()  # out of memory
        if args.fp16:
            self.meters["loss_scale"] = AverageMeter()  # dynamic loss scale
        self.meters["wall"] = TimeMeter()  # wall time in seconds
        self.meters["train_wall"] = StopwatchMeter()  # train wall time in seconds

    @property
    def criterion(self):

@@ -102,7 +101,7 @@ class Trainer(object):
        if self._wrapped_model is None:
            if self.args.distributed_world_size > 1 and not self.args.use_bmuf:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.args, self._model
                )
            else:
                self._wrapped_model = self._model

@@ -130,15 +129,19 @@ class Trainer(object):
        if self.args.fp16:
            if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print(
                    "| WARNING: your device does NOT support faster training with --fp16, "
                    "please switch to FP32 which is likely to be faster"
                )
            if self.args.memory_efficient_fp16:
                self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
                    self.args, params
                )
            else:
                self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
            if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print("| NOTICE: your device may support faster training with --fp16")
            self._optimizer = optim.build_optimizer(self.args, params)

        if self.args.use_bmuf:

@@ -152,11 +155,17 @@ class Trainer(object):
    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            extra_state["train_meters"] = self.meters
            checkpoint_utils.save_state(
                filename,
                self.args,
                self.get_model().state_dict(),
                self.get_criterion(),
                self.optimizer,
                self.lr_scheduler,
                self.get_num_updates(),
                self._optim_history,
                extra_state,
            )

    def load_checkpoint(

@@ -172,6 +181,7 @@ class Trainer(object):
        try:
            from fairseq.fb_pathmgr import fb_pathmgr

            bexists = fb_pathmgr.isfile(filename)
        except (ModuleNotFoundError, ImportError):
            bexists = os.path.exists(filename)

@@ -181,18 +191,22 @@ class Trainer(object):
            # load model parameters
            try:
                self.get_model().load_state_dict(
                    state["model"], strict=True, args=self.args
                )
                if utils.has_parameters(self.get_criterion()):
                    self.get_criterion().load_state_dict(
                        state["criterion"], strict=True
                    )
            except Exception:
                raise Exception(
                    "Cannot load model parameters from checkpoint {}; "
                    "please ensure that the architectures match.".format(filename)
                )

            extra_state = state["extra_state"]
            self._optim_history = state["optimizer_history"]
            last_optim_state = state.get("last_optimizer_state", None)

        if last_optim_state is not None and not reset_optimizer:
            # rebuild optimizer after loading model, since params may have changed

@@ -200,41 +214,53 @@ class Trainer(object):
            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
            assert (
                last_optim["criterion_name"] == self.get_criterion().__class__.__name__
            ), "Criterion does not match; please reset the optimizer (--reset-optimizer)."
            assert (
                last_optim["optimizer_name"] == self.optimizer.__class__.__name__
            ), "Optimizer does not match; please reset the optimizer (--reset-optimizer)."

            if not reset_lr_scheduler:
                self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)

            self.set_num_updates(last_optim["num_updates"])

        if extra_state is not None:
            epoch = extra_state["train_iterator"]["epoch"]
            print(
                "| loaded checkpoint {} (epoch {} @ {} updates)".format(
                    filename, epoch, self.get_num_updates()
                )
            )

            self.lr_step(epoch)

            if "train_meters" in extra_state and not reset_meters:
                self.meters.update(extra_state["train_meters"])
                del extra_state["train_meters"]

                # reset TimeMeters, since their start times don't make sense anymore
                for meter in self.meters.values():
                    if isinstance(meter, TimeMeter):
                        meter.reset()
        else:
            print("| no existing checkpoint found {}".format(filename))

        return extra_state

    def get_train_iterator(
        self,
        epoch,
        combine=True,
        load_dataset=True,
        data_selector=None,
        shard_batch_itr=True,
    ):
        """Return an EpochBatchIterator over the training set for a given epoch."""
        if load_dataset:
            print("| loading train data for epoch {}".format(epoch))
            self.task.load_dataset(
                self.args.train_subset,
                epoch=epoch,

@@ -246,8 +272,7 @@ class Trainer(object):
            max_tokens=self.args.max_tokens,
            max_sentences=self.args.max_sentences,
            max_positions=utils.resolve_max_positions(
                self.task.max_positions(), self.model.max_positions()
            ),
            ignore_invalid_inputs=True,
            required_batch_size_multiple=self.args.required_batch_size_multiple,

@@ -269,7 +294,7 @@ class Trainer(object):
        self.zero_grad()
        if not dummy_batch:
            self.meters["train_wall"].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0

@@ -291,7 +316,7 @@ class Trainer(object):
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()

@@ -302,8 +327,7 @@ class Trainer(object):
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer, ignore_grad
                    )

                if not ignore_grad:

@@ -312,17 +336,21 @@ class Trainer(object):
                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get(
                            "nsentences", 0.0
                        )
                        self._all_reduce_list[2] += logging_output.get("loss", 0.0)
                        self._all_reduce_list[3] += logging_output.get("nll_loss", 0.0)
                        self._all_reduce_list[4] += logging_output.get("ntokens", 0.0)

            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    print(
                        "| WARNING: attempting to recover from OOM in forward/backward pass",
                        file=sys.stderr,
                    )
                    ooms += 1
                    self.zero_grad()
                else:

@@ -331,7 +359,6 @@ class Trainer(object):
            if self.fast_stat_sync:
                self._all_reduce_list[5] += ooms

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

@@ -347,39 +374,36 @@ class Trainer(object):
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (all_reduce_list_tensor[0:1] * torch.log(torch.cuda.DoubleTensor([2])))
            )
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output["nsentences"],
                logging_output["loss"],
                logging_output["nll_loss"],
                logging_output["ntokens"],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
            logging_outputs, sample_sizes, ooms, prev_norms = zip(
                *distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm]
                )
            )
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert all(norm == prev_norms[0] for norm in prev_norms) or all(
                    math.isnan(norm) or math.isinf(norm) for norm in prev_norms
                ), "Fatal error: gradients are inconsistent between workers"

        self.meters["oom"].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print("| WARNING: OOM in all workers, skipping update")
            self.zero_grad()
            return None

@@ -390,16 +414,20 @@ class Trainer(object):
            )
            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

        if not all(k in logging_output for k in ["ntokens", "nsentences"]):
            raise Exception(
                (
                    "Please update the {}.aggregate_logging_outputs() method to "
                    "return ntokens and nsentences"
                ).format(self.task.__class__.__name__)
            )

        try:
            # normalize grads by sample size
            if sample_size > 0:
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / float(sample_size)
                )

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

@@ -413,47 +441,57 @@ class Trainer(object):
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get("ntokens", 0)
            nsentences = logging_output.get("nsentences", 0)
            self.meters["wps"].update(ntokens)
            self.meters["ups"].update(1.0)
            self.meters["wpb"].update(ntokens)
            self.meters["bsz"].update(nsentences)
            self.meters["gnorm"].update(grad_norm)
            self.meters["clip"].update(
                1.0
                if grad_norm > self.args.clip_norm and self.args.clip_norm > 0
                else 0.0
            )
            self.meters["train_loss"].update(logging_output.get("loss", 0), sample_size)
            if "train_acc" in self.meters:
                self.meters["train_acc"].update(
                    logging_output.get("acc", 0), sample_size
                )
            if "nll_loss" in logging_output:
                self.meters["train_nll_loss"].update(
                    logging_output.get("nll_loss", 0), ntokens
                )

            # clear CUDA cache to reduce memory fragmentation
            if (
                self.args.empty_cache_freq > 0
                and (
                    (self.get_num_updates() + self.args.empty_cache_freq - 1)
                    % self.args.empty_cache_freq
                )
                == 0
                and torch.cuda.is_available()
                and not self.args.cpu
            ):
                torch.cuda.empty_cache()
        except OverflowError as e:
            print("| WARNING: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                print("| ERROR: OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            self.meters["loss_scale"].reset()
            self.meters["loss_scale"].update(self.optimizer.scaler.loss_scale)

        self.clear_buffered_stats()
        self.meters["train_wall"].stop()

        return logging_output

@@ -475,10 +513,12 @@ class Trainer(object):
                    sample, self.model, self.criterion
                )
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        print(
                            "| WARNING: ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory

@@ -492,9 +532,9 @@ class Trainer(object):
        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(
                *distributed_utils.all_gather_list([logging_output, sample_size])
            )
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:

@@ -505,19 +545,18 @@ class Trainer(object):
        logging_output = self.task.aggregate_logging_outputs(
            logging_output, self.get_criterion()
        )
        sample_size = self.task.grad_denom(sample_size, self.get_criterion())

        # update meters for validation
        ntokens = logging_output.get("ntokens", 0)
        self.meters["valid_loss"].update(logging_output.get("loss", 0), sample_size)
        if "valid_acc" in self.meters:
            self.meters["valid_acc"].update(logging_output.get("acc", 0), sample_size)

        if "nll_loss" in logging_output:
            self.meters["valid_nll_loss"].update(
                logging_output.get("nll_loss", 0), ntokens
            )

        return logging_output

@@ -604,19 +643,16 @@ class Trainer(object):
            torch.cuda.manual_seed(seed)

    def _sync_stats(self):
        return self.args.distributed_world_size > 1 and (
            (not self.args.use_bmuf)
            or (
                self.args.use_bmuf
                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
            )
        )

    def _log_oom(self, exc):
        msg = "| OOM: Ran out of memory with exception: {}".format(exc)
        # TODO: print should really go to logger, this print goes
        # to stderr, which is buffered, which in many cases is not
        # printed out if another exception happens.

@@ -624,6 +660,5 @@ class Trainer(object):
        print(msg, file=sys.stderr)
        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
            for device_idx in range(torch.cuda.device_count()):
                print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
        sys.stderr.flush()