Commit 872e38ea authored by Mohammad Shoeybi

Merge branch 'merge_bugfix' into 'main'

Fix bug in merge_mp_partitions for handling recent checkpoints.

See merge request ADLR/megatron-lm!226
parents c601d751 72105ef0
@@ -240,6 +240,11 @@ def main():
     tokenizer = rebuild_tokenizer(args)
     mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
     for rank in range(args.tensor_model_parallel_size):
+        # Reset these since load_checkpoint asserts they are 0, but we are loading
+        # multiple checkpoints in the same process and they get set each time
+        args.consumed_train_samples = 0
+        args.consumed_valid_samples = 0
         mpu.initialize.set_tensor_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         model_ = get_model(model_type)
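For readers without the surrounding file, the failure mode this diff fixes can be reproduced in miniature. The sketch below is a minimal, self-contained illustration, not Megatron-LM code: the Args holder and the toy load_checkpoint are hypothetical stand-ins, and only the assert-then-set behavior described by the committed comment is taken from the source.

    class Args:
        # Stand-in for Megatron-LM's parsed args namespace (hypothetical).
        consumed_train_samples = 0
        consumed_valid_samples = 0

    def load_checkpoint(args, rank):
        # Mirrors the behavior the committed comment describes: the loader
        # asserts the counters are 0 before loading, then sets them from
        # the checkpoint's recorded progress.
        assert args.consumed_train_samples == 0
        assert args.consumed_valid_samples == 0
        args.consumed_train_samples = 1000  # pretend checkpoint progress
        args.consumed_valid_samples = 100

    args = Args()
    for rank in range(2):
        # Without these two resets, the second iteration would trip the
        # asserts, since the first load already populated the counters.
        # This is exactly what the added lines in the diff do.
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0
        load_checkpoint(args, rank)

The design point is that merge_mp_partitions loads one checkpoint per tensor-parallel rank in a single process, so any per-load state written into args must be cleared before the next load.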