Commit 977efdfb authored by Lawrence McAfee

added backwards compatibility to checkpointing.py.

parent ef63acd2
@@ -91,19 +91,23 @@ def get_checkpoint_names(checkpoints_path, iteration, use_distributed_optimizer,
     # Use both the tensor and pipeline MP rank. If using the distributed
     # optimizer, then the optimizer's path must additionally include the
     # data parallel rank.
-    common_path = os.path.join(
-        checkpoints_path,
-        directory,
-        "mp_rank_%02d_%03d" % (
-            mpu.get_tensor_model_parallel_rank(),
-            mpu.get_pipeline_model_parallel_rank()))
-    model_name = os.path.join(common_path, "model_rng.pt")
+    if mpu.get_pipeline_model_parallel_world_size() == 1:
+        common_path = os.path.join(checkpoints_path, directory,
+                                   'mp_rank_{:02d}'.format(
+                                       mpu.get_tensor_model_parallel_rank()))
+    else:
+        common_path = os.path.join(checkpoints_path, directory,
+                                   'mp_rank_{:02d}_{:03d}'.format(
+                                       mpu.get_tensor_model_parallel_rank(),
+                                       mpu.get_pipeline_model_parallel_rank()))
     if use_distributed_optimizer:
+        model_name = os.path.join(common_path, "model_rng.pt")
         optim_name = os.path.join(
             common_path + "_%03d" % mpu.get_data_parallel_rank(),
             "optim.pt")
     else:
-        optim_name = os.path.join(common_path, "optim.pt")
+        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
     return model_name, optim_name
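To make the resulting checkpoint layout concrete, here is a minimal, self-contained sketch of the paths the updated logic yields. It is not part of the commit: `sketch_checkpoint_names` and the literal rank values are illustrative stand-ins for the `mpu.get_*_rank()` and world-size calls used in the real function.

```python
import os

def sketch_checkpoint_names(checkpoints_path, directory,
                            tp_rank, pp_rank, pp_world_size, dp_rank,
                            use_distributed_optimizer):
    # Pipeline-parallel size 1 keeps the legacy 'mp_rank_XX' directory name,
    # which is the backwards compatibility this commit restores.
    if pp_world_size == 1:
        common_path = os.path.join(checkpoints_path, directory,
                                   'mp_rank_{:02d}'.format(tp_rank))
    else:
        common_path = os.path.join(checkpoints_path, directory,
                                   'mp_rank_{:02d}_{:03d}'.format(tp_rank, pp_rank))
    if use_distributed_optimizer:
        # Model and optimizer state live in separate files; the optimizer
        # path additionally carries the data-parallel rank.
        model_name = os.path.join(common_path, "model_rng.pt")
        optim_name = os.path.join(common_path + "_%03d" % dp_rank, "optim.pt")
    else:
        # Legacy layout: one combined file per rank.
        model_name = optim_name = os.path.join(common_path, "model_optim_rng.pt")
    return model_name, optim_name

print(sketch_checkpoint_names("/ckpts", "iter_0001000", 0, 0, 1, 0, False))
# ('/ckpts/iter_0001000/mp_rank_00/model_optim_rng.pt',
#  '/ckpts/iter_0001000/mp_rank_00/model_optim_rng.pt')
print(sketch_checkpoint_names("/ckpts", "iter_0001000", 1, 2, 4, 3, True))
# ('/ckpts/iter_0001000/mp_rank_01_002/model_rng.pt',
#  '/ckpts/iter_0001000/mp_rank_01_002_003/optim.pt')
```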