"tests/experimental/vscode:/vscode.git/clone" did not exist on "7d7edf6d37576fb6eda65db6db43fda54a7f06ba"
Commit 96d19aae authored by Jared Casper

Don't require an even division of layers in the target model.

parent 4147bec2
@@ -148,11 +148,6 @@ def save_checkpoint(queue, args):
 
     # Transformer layers
     #-------------------
-    if md.num_layers % args.target_pipeline_parallel_size != 0:
-        print("Source number of layers is not divisible by target pipeline parallel size")
-        exit(1)
-    layers_per_rank = md.num_layers // args.target_pipeline_parallel_size
-    assert layers_per_rank == len(models[0].language_model.encoder.layers)
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
@@ -160,7 +155,7 @@ def save_checkpoint(queue, args):
             post_process = pp_rank == args.target_pipeline_parallel_size - 1
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
-        for layer in range(layers_per_rank):
+        for layer in range(len(models[0].language_model.encoder.layers)):
             # get full tensors
             input_layernorm_weight = queue_get()
             input_layernorm_bias = queue_get()
...
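For context: the removed check forced md.num_layers to divide evenly by args.target_pipeline_parallel_size, and the loop previously assumed the same fixed layers_per_rank on every rank. After this change, each pipeline rank simply iterates over however many layers its constructed model actually holds. Below is a minimal sketch of one way an uneven split could be assigned across ranks; the layers_for_rank helper is hypothetical for illustration and not part of Megatron-LM.

    def layers_for_rank(total_layers, pp_size, pp_rank):
        # Hypothetical helper (not Megatron-LM code): spread total_layers
        # over pp_size pipeline ranks, giving the first `remainder` ranks
        # one extra layer when the split is uneven.
        base, remainder = divmod(total_layers, pp_size)
        return base + (1 if pp_rank < remainder else 0)

    # Example: 13 layers across 4 pipeline ranks -> [4, 3, 3, 3]
    print([layers_for_rank(13, 4, r) for r in range(4)])

Under this kind of scheme the per-rank counts differ, which is exactly why the loop now asks the model itself (len(models[0].language_model.encoder.layers)) rather than recomputing a single shared layers_per_rank.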