Commit 72105ef0 authored by Jared Casper

Fix bug in merge_mp_partitions for handling recent checkpoints.

parent c601d751
@@ -240,6 +240,11 @@ def main():
     tokenizer = rebuild_tokenizer(args)
     mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
     for rank in range(args.tensor_model_parallel_size):
+        # Reset these since load_checkpoint asserts they are 0, but we are loading
+        # multiple checkpoints in the same process and they get set each time
+        args.consumed_train_samples = 0
+        args.consumed_valid_samples = 0
         mpu.initialize.set_tensor_model_parallel_rank(rank)
         checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
         model_ = get_model(model_type)
...
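The sketch below illustrates the interaction this commit works around. The load_checkpoint here is a hypothetical stand-in for Megatron's real function, kept only to show the assert-then-populate pattern the code comment describes: the loader asserts the sample counters start at zero and then overwrites them from the checkpoint, so a second call in the same process trips the assert unless the caller resets the counters first.

class Args:
    consumed_train_samples = 0
    consumed_valid_samples = 0

def load_checkpoint(args, state):
    # Hypothetical stand-in mirroring the assert-then-populate pattern:
    # counters must be zero on entry, then are filled from the checkpoint.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    args.consumed_train_samples = state['consumed_train_samples']
    args.consumed_valid_samples = state['consumed_valid_samples']

args = Args()
# Two fake per-rank checkpoint states, loaded in one process as
# merge_mp_partitions does when merging tensor-parallel partitions.
checkpoints = [{'consumed_train_samples': 1000,
                'consumed_valid_samples': 100}] * 2

for rank, state in enumerate(checkpoints):
    # Without this reset, the second iteration would raise AssertionError,
    # since the first load left the counters nonzero.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0
    load_checkpoint(args, state)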