Unverified Commit c9d31d39 authored by Nopileos2's avatar Nopileos2 Committed by GitHub
Browse files

Changed dist to dist_train for loading parameters (#621)

Passing dist_train instead of dist as the to_cpu parameter to the loading functions.
parent 8e055f02
......@@ -123,17 +123,17 @@ def main():
start_epoch = it = 0
last_epoch = -1
if args.pretrained_model is not None:
model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist, logger=logger)
model.load_params_from_file(filename=args.pretrained_model, to_cpu=dist_train, logger=logger)
if args.ckpt is not None:
it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist, optimizer=optimizer, logger=logger)
it, start_epoch = model.load_params_with_optimizer(args.ckpt, to_cpu=dist_train, optimizer=optimizer, logger=logger)
last_epoch = start_epoch + 1
else:
ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
if len(ckpt_list) > 0:
ckpt_list.sort(key=os.path.getmtime)
it, start_epoch = model.load_params_with_optimizer(
ckpt_list[-1], to_cpu=dist, optimizer=optimizer, logger=logger
ckpt_list[-1], to_cpu=dist_train, optimizer=optimizer, logger=logger
)
last_epoch = start_epoch + 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment