Commit 98a5b9a0 authored by Jared Casper
Browse files

Improve comments around layer regex replacement.

parent 7cabbe67
...@@ -304,6 +304,7 @@ def main(): ...@@ -304,6 +304,7 @@ def main():
mpu.initialize.set_tensor_model_parallel_rank(0) mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size) mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
# regex to parse out layer number from param name
layer_re = re.compile('layers\.([0-9]+)') layer_re = re.compile('layers\.([0-9]+)')
if args.pipeline_model_parallel_size > 1: if args.pipeline_model_parallel_size > 1:
...@@ -314,7 +315,7 @@ def main(): ...@@ -314,7 +315,7 @@ def main():
for rank in range(args.pipeline_model_parallel_size): for rank in range(args.pipeline_model_parallel_size):
mpu.initialize.set_pipeline_model_parallel_rank(rank) mpu.initialize.set_pipeline_model_parallel_rank(rank)
model = get_model(model_type) model = get_model(model_type)
def repl_layer(m): def update_layer_num(m):
# TODO! This assumes no interleaved pipeline execution # TODO! This assumes no interleaved pipeline execution
layer = int(m.group(1)) layer = int(m.group(1))
layer += rank * layers_per_part layer += rank * layers_per_part
...@@ -325,8 +326,10 @@ def main(): ...@@ -325,8 +326,10 @@ def main():
# See comment in MegatronModule.initialize_word_embeddings() # See comment in MegatronModule.initialize_word_embeddings()
src_name = "language_model.embedding.word_embeddings.weight" src_name = "language_model.embedding.word_embeddings.weight"
else: else:
src_name = re.sub(layer_re, repl_layer, dst_name) # Translate destination layer number (0-N for each partition)
print(f" > copying {src_name} to {dst_name} rank {rank}'s model") # to source layer number (single-model layer number)
src_name = re.sub(layer_re, update_layer_num, dst_name)
print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
partition_param.data.copy_(merged_params[src_name].data) partition_param.data.copy_(merged_params[src_name].data)
partitions.append(model) partitions.append(model)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment