Unverified Commit aac4c799 authored by janEbert, committed by GitHub

Fix non-deterministic Megatron-LM checkpoint name (#24674)

Fix non-deterministic checkpoint name

`os.listdir`'s order is not deterministic, which is a problem when the code
takes the first listed file (`os.listdir(...)[0]`).

This can return a checkpoint file such as `distrib_optim.pt`, which does not
contain the desired information, such as the arguments originally passed to
Megatron-LM when the checkpoint was saved.
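
For illustration, a minimal, self-contained sketch of the failure mode described above. The directory layout and file contents are made up for the example; only the file names (`distrib_optim.pt`, `model_optim_rng.pt`, `model_rng.pt`) come from this commit.

```python
import os
import tempfile

# Hypothetical Megatron-LM rank directory containing both a model checkpoint
# and a distributed-optimizer shard (layout made up for illustration).
with tempfile.TemporaryDirectory() as sub_dir:
    for name in ("distrib_optim.pt", "model_optim_rng.pt"):
        open(os.path.join(sub_dir, name), "w").close()

    # Old behaviour: take whatever the filesystem happens to list first.
    # os.listdir gives no ordering guarantee, so this may be "distrib_optim.pt".
    first_listed = os.listdir(sub_dir)[0]

    # Fixed behaviour: probe known checkpoint names in a fixed preference order.
    checkpoint_name = None
    for candidate in ("model_optim_rng.pt", "model_rng.pt"):
        if os.path.isfile(os.path.join(sub_dir, candidate)):
            checkpoint_name = candidate
            break

    print("listdir picked:", first_listed)     # non-deterministic
    print("fix picks:", checkpoint_name)       # always "model_optim_rng.pt"
```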
parent 33aafc26
@@ -291,8 +291,10 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank):
     tp_state_dicts = []
     for i in range(tp_size):
         sub_dir_name = f"mp_rank_{i:02d}" if pp_size == 1 else f"mp_rank_{i:02d}_{pp_rank:03d}"
-        checkpoint_name = os.listdir(os.path.join(args.load_path, sub_dir_name))[0]
-        checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
+        for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
+            checkpoint_path = os.path.join(args.load_path, sub_dir_name, checkpoint_name)
+            if os.path.isfile(checkpoint_path):
+                break
         state_dict = torch.load(checkpoint_path, map_location="cpu")
         tp_state_dicts.append(state_dict)
     return tp_state_dicts
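
A small self-contained sketch of the fallback behaviour in the patched loop: when only `model_rng.pt` is present, the lookup falls through to it. The temporary directory is made up for illustration; the candidate names match the diff above.

```python
import os
import tempfile

# Made-up rank directory that contains only "model_rng.pt".
with tempfile.TemporaryDirectory() as sub_dir:
    open(os.path.join(sub_dir, "model_rng.pt"), "w").close()

    # Same candidate order as the patched code: prefer the full checkpoint,
    # fall back to the RNG-only file.
    for checkpoint_name in ["model_optim_rng.pt", "model_rng.pt"]:
        checkpoint_path = os.path.join(sub_dir, checkpoint_name)
        if os.path.isfile(checkpoint_path):
            break

    print(checkpoint_name)  # -> "model_rng.pt"
```

Note that, as in the patch, the loop leaves `checkpoint_path` pointing at the last candidate when neither file exists, so the subsequent `torch.load` would then fail with a file-not-found error.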