Unverified commit 2376ef4a authored by Guolin Ke, committed by GitHub

support Wandb (#29)

* add wandb support

* code cleanup
parent 854b8890
......@@ -2,15 +2,18 @@
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
run_name=bert_example
save_dir="./save/${run_name}"
mkdir -p ${save_dir}
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) ./example_data --user-dir . --valid-subset valid \
--num-workers 0 --ddp-backend=c10d \
--task bert --loss masked_lm --arch bert_base \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 1.0 \
--lr-scheduler polynomial_decay --lr 1e-4 --warmup-updates 100 --total-num-update 10000 --batch-size 4 \
--update-freq 1 --seed 1 \
--fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir ./tsb/ \
--fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir $save_dir/tsb \
--max-update 10000 --log-interval 100 --log-format simple \
--save-interval-updates 5000 --validate-interval-updates 5000 --keep-interval-updates 30 --no-epoch-checkpoints \
--save-dir ./save
--save-interval-updates 1000 --validate-interval-updates 1000 --keep-interval-updates 30 --no-epoch-checkpoints \
--save-dir $save_dir
......@@ -33,7 +33,9 @@ def progress_bar(
epoch: Optional[int] = None,
prefix: Optional[str] = None,
tensorboard_logdir: Optional[str] = None,
wandb_project: Optional[str] = None,
default_log_format: str = "tqdm",
args=None,
):
if log_format is None:
log_format = default_log_format
......@@ -52,44 +54,13 @@ def progress_bar(
raise ValueError("Unknown log format: {}".format(log_format))
if tensorboard_logdir:
try:
# [FB only] custom wrapper for TensorBoard
import palaas # noqa
from .fb_tbmf_wrapper import FbTbmfWrapper
bar = FbTbmfWrapper(bar, log_interval)
except ImportError:
bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
bar = TensorboardProgressBarWrapper(
bar, tensorboard_logdir, wandb_project, args
)
return bar
def build_progress_bar(
args,
iterator,
epoch: Optional[int] = None,
prefix: Optional[str] = None,
default: str = "tqdm",
no_progress_bar: str = "none",
):
"""Legacy wrapper that takes an argparse.Namespace."""
if getattr(args, "no_progress_bar", False):
default = no_progress_bar
if getattr(args, "distributed_rank", 0) == 0:
tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
else:
tensorboard_logdir = None
return progress_bar(
iterator,
log_format=args.log_format,
log_interval=args.log_interval,
epoch=epoch,
prefix=prefix,
tensorboard_logdir=tensorboard_logdir,
default_log_format=default,
)
def format_stat(stat):
if isinstance(stat, Number):
stat = "{:g}".format(stat)
......@@ -306,10 +277,23 @@ except ImportError:
except ImportError:
SummaryWriter = None
try:
_wandb_inited = False
import wandb
wandb_available = True
except ImportError:
wandb_available = False
def _close_writers():
for w in _tensorboard_writers.values():
w.close()
if _wandb_inited:
try:
wandb.finish()
except Exception:
pass
atexit.register(_close_writers)
......@@ -318,7 +302,7 @@ atexit.register(_close_writers)
class TensorboardProgressBarWrapper(BaseProgressBar):
"""Log to tensorboard."""
def __init__(self, wrapped_bar, tensorboard_logdir):
def __init__(self, wrapped_bar, tensorboard_logdir, wandb_project, args):
self.wrapped_bar = wrapped_bar
self.tensorboard_logdir = tensorboard_logdir
......@@ -326,6 +310,17 @@ class TensorboardProgressBarWrapper(BaseProgressBar):
logger.warning(
"tensorboard not found, please install with: pip install tensorboard"
)
global _wandb_inited
if not _wandb_inited and wandb_project and wandb_available:
wandb_name = args.wandb_name or wandb.util.generate_id()
wandb.init(
project=wandb_project,
name=wandb_name,
config=vars(args),
id=wandb_name,
resume="allow",
)
_wandb_inited = True
def _writer(self, key):
if SummaryWriter is None:
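The init block above ties the run id to the run name and passes resume="allow", so a job restarted with the same --wandb-name continues its original wandb run instead of opening a new one. A minimal standalone sketch of that pattern (not part of the commit; the helper name and arguments are illustrative):

import wandb

def init_wandb_run(project, run_name="", config=None):
    # Fall back to a random id when no name is given, mirroring the commit.
    run_name = run_name or wandb.util.generate_id()
    wandb.init(
        project=project,
        name=run_name,
        id=run_name,      # id == name lets a restarted job locate the old run
        resume="allow",   # continue that run rather than start a fresh one
        config=config or {},
    )
    return run_name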
......@@ -362,9 +357,15 @@ class TensorboardProgressBarWrapper(BaseProgressBar):
step = stats["num_updates"]
for key in stats.keys() - {"num_updates"}:
if isinstance(stats[key], AverageMeter):
writer.add_scalar(key, stats[key].val, step)
val = stats[key].val
elif isinstance(stats[key], Number):
writer.add_scalar(key, stats[key], step)
val = stats[key]
elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
writer.add_scalar(key, stats[key].item(), step)
val = stats[key].item()
else:
val = None
if val is not None:
writer.add_scalar(key, val, step)
if _wandb_inited:
wandb.log({"{}_{}".format(tag, key): val}, step=step)
writer.flush()
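The rewritten log_stats first reduces each stat to a plain scalar, then feeds the same value to TensorBoard and, when enabled, wandb. A self-contained sketch of that normalization step (illustrative only; the meter check below stands in for unicore's AverageMeter):

from numbers import Number
import torch

def to_scalar(value):
    if hasattr(value, "val"):                          # AverageMeter-like meter
        return value.val
    if isinstance(value, Number):                      # plain int / float
        return value
    if torch.is_tensor(value) and value.numel() == 1:  # single-element tensor
        return value.item()
    return None                                        # unsupported type: skip

Note the `is not None` guard in the diff: a stat whose value is exactly 0 should still be logged, so a plain truthiness check is not enough.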
......@@ -10,8 +10,15 @@ import torch
from typing import Callable, List, Optional
# this import is for backward compatibility
from unicore.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list, import_user_module # noqa
from unicore.utils import (
csv_str_list,
eval_bool,
eval_str_dict,
eval_str_list,
import_user_module,
) # noqa
def get_training_parser(default_task="translation"):
......@@ -137,7 +144,7 @@ def parse_args_and_arch(
args.no_seed_provided = True
else:
args.no_seed_provided = False
args.validate_with_ema = getattr(args, "validate_with_ema", False)
# Apply architecture configuration.
if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
......@@ -149,11 +156,11 @@ def parse_args_and_arch(
return args
def get_parser(desc, default_task='test'):
def get_parser(desc, default_task="test"):
# Before creating the true parser, we need to import optional user module
# in order to eagerly import custom tasks, optimizers, architectures, etc.
usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
usr_parser.add_argument('--user-dir', default=None)
usr_parser.add_argument("--user-dir", default=None)
usr_args, _ = usr_parser.parse_known_args()
import_user_module(usr_args)
......@@ -167,6 +174,10 @@ def get_parser(desc, default_task='test'):
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
help='path to save logs for tensorboard, should match --logdir '
'of running tensorboard (default: no tensorboard logging)')
parser.add_argument('--wandb-project', metavar='NAME', default='',
help='wandb project name; leave empty to disable wandb logging '
'(for wandb login, set the WANDB_API_KEY env var)')
parser.add_argument('--wandb-name', metavar='NAME', default='',
help='wandb run name, also used as the run id for resuming; '
'leave empty to auto-generate one')
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
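The two options above are meant to be used together: --wandb-project turns wandb logging on and --wandb-name pins the run id for resuming, while authentication comes from the WANDB_API_KEY environment variable. A hedged, standalone sketch of that argument surface (a throwaway parser with placeholder values, not the project's real one):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--wandb-project", metavar="NAME", default="",
                    help="wandb project name; empty disables wandb logging")
parser.add_argument("--wandb-name", metavar="NAME", default="",
                    help="run name, also used as the run id for resuming")

# Placeholder values; in practice these come from the training launch command.
args = parser.parse_args(["--wandb-project", "unicore-bert",
                          "--wandb-name", "bert_example"])
print(args.wandb_project, args.wandb_name)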
......@@ -216,7 +227,7 @@ def get_parser(desc, default_task='test'):
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
group = parser.add_argument_group("Dataset and data loading")
# fmt: off
group.add_argument('--num-workers', default=1, type=int, metavar='N',
help='how many subprocesses to use for data loading')
......@@ -256,7 +267,7 @@ def add_dataset_args(parser, train=False, gen=False):
def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training')
group = parser.add_argument_group("Distributed training")
# fmt: off
group.add_argument('--distributed-world-size', type=int, metavar='N',
default=max(1, torch.cuda.device_count()),
......@@ -301,7 +312,7 @@ def add_distributed_training_args(parser):
def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
group = parser.add_argument_group("Optimization")
# fmt: off
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch')
......@@ -327,7 +338,7 @@ def add_optimization_args(parser):
def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing')
group = parser.add_argument_group("Checkpointing")
# fmt: off
group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
help='path to save checkpoints')
......@@ -397,7 +408,7 @@ def add_common_eval_args(group):
def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
group = parser.add_argument_group("Model configuration")
# fmt: off
# Model definitions can be found under unicore/models/
......
......@@ -40,7 +40,6 @@ logger = logging.getLogger("unicore_cli.train")
def main(args) -> None:
utils.import_user_module(args)
utils.set_jit_fusion_options()
......@@ -84,17 +83,17 @@ def main(args) -> None:
logger.info(
"num. model params: {:,} (num. trained: {:,})".format(
sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()),
sum(getattr(p, "_orig_size", p).numel() for p in model.parameters() if p.requires_grad),
sum(
getattr(p, "_orig_size", p).numel()
for p in model.parameters()
if p.requires_grad
),
)
)
# Build trainer
trainer = Trainer(args, task, model, loss)
logger.info(
"training on {} devices (GPUs)".format(
args.distributed_world_size
)
)
logger.info("training on {} devices (GPUs)".format(args.distributed_world_size))
logger.info(
"batch size per device = {}".format(
args.batch_size,
......@@ -123,7 +122,9 @@ def main(args) -> None:
break
# train for one epoch
valid_losses, should_stop = train(args, trainer, task, epoch_itr, ckp_copy_thread)
valid_losses, should_stop = train(
args, trainer, task, epoch_itr, ckp_copy_thread
)
if should_stop:
break
......@@ -194,11 +195,13 @@ def train(
log_interval=args.log_interval,
epoch=epoch_itr.epoch,
tensorboard_logdir=(
args.tensorboard_logdir
if distributed_utils.is_master(args)
else None
args.tensorboard_logdir if distributed_utils.is_master(args) else None
),
wandb_project=(
args.wandb_project if distributed_utils.is_master(args) else None
),
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
args=args,
)
trainer.begin_epoch(epoch_itr.epoch)
......@@ -267,10 +270,7 @@ def validate_and_save(
)
training_time_hours = trainer.cumulative_training_time() / (60 * 60)
if (
args.stop_time_hours > 0
and training_time_hours > args.stop_time_hours
):
if args.stop_time_hours > 0 and training_time_hours > args.stop_time_hours:
should_stop = True
logger.info(
f"Stopping training due to "
......@@ -279,7 +279,11 @@ def validate_and_save(
)
do_save = (
(end_of_epoch and epoch_itr.epoch % args.save_interval == 0 and not args.no_epoch_checkpoints)
(
end_of_epoch
and epoch_itr.epoch % args.save_interval == 0
and not args.no_epoch_checkpoints
)
or should_stop
or (
args.save_interval_updates > 0
......@@ -290,7 +294,11 @@ def validate_and_save(
)
do_validate = (
(not end_of_epoch and do_save) # validate during mid-epoch saves
or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0 and not args.no_epoch_checkpoints)
or (
end_of_epoch
and epoch_itr.epoch % args.validate_interval == 0
and not args.no_epoch_checkpoints
)
or should_stop
or (
args.validate_interval_updates > 0
......@@ -309,7 +317,12 @@ def validate_and_save(
# Save checkpoint
checkpoint_utils.save_checkpoint(
args, trainer, epoch_itr, valid_losses[0], ckp_copy_thread, do_save=(do_save or should_stop),
args,
trainer,
epoch_itr,
valid_losses[0],
ckp_copy_thread,
do_save=(do_save or should_stop),
)
return valid_losses, should_stop
......@@ -377,11 +390,12 @@ def validate(
return valid_losses
def get_valid_stats(
args, trainer: Trainer, stats: Dict[str, Any]
) -> Dict[str, Any]:
def get_valid_stats(args, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
stats["num_updates"] = trainer.get_num_updates()
if hasattr(checkpoint_utils.save_checkpoint, "best") and args.best_checkpoint_metric in stats:
if (
hasattr(checkpoint_utils.save_checkpoint, "best")
and args.best_checkpoint_metric in stats
):
key = "best_{0}".format(args.best_checkpoint_metric)
best_function = max if args.maximize_best_checkpoint_metric else min
stats[key] = best_function(
......