Commit 96d19aae authored by Jared Casper's avatar Jared Casper
Browse files

Don't require an even divide of layers in target model.

parent 4147bec2
......@@ -148,11 +148,6 @@ def save_checkpoint(queue, args):
# Transformer layers
#-------------------
if md.num_layers % args.target_pipeline_parallel_size != 0:
print("Source number of layers is not divisible by target pipeline parallel size")
exit(1)
layers_per_rank = md.num_layers // args.target_pipeline_parallel_size
assert layers_per_rank == len(models[0].language_model.encoder.layers)
for pp_rank in range(args.target_pipeline_parallel_size):
# For later pipeline parallel ranks, make the new models
if pp_rank > 0:
......@@ -160,7 +155,7 @@ def save_checkpoint(queue, args):
post_process = pp_rank == args.target_pipeline_parallel_size - 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
for layer in range(layers_per_rank):
for layer in range(len(models[0].language_model.encoder.layers)):
# get full tensors
input_layernorm_weight = queue_get()
input_layernorm_bias = queue_get()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment