Commit 98a5b9a0 authored by Jared Casper's avatar Jared Casper
Browse files

Improve comments around layer regex replacement.

parent 7cabbe67
......@@ -304,6 +304,7 @@ def main():
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
# regex to parse out layer number from param name
layer_re = re.compile('layers\.([0-9]+)')
if args.pipeline_model_parallel_size > 1:
......@@ -314,7 +315,7 @@ def main():
for rank in range(args.pipeline_model_parallel_size):
mpu.initialize.set_pipeline_model_parallel_rank(rank)
model = get_model(model_type)
def repl_layer(m):
def update_layer_num(m):
# TODO! This assumes no interleaved pipeline execution
layer = int(m.group(1))
layer += rank * layers_per_part
......@@ -325,8 +326,10 @@ def main():
# See comment in MegatronModule.initialize_word_embeddings()
src_name = "language_model.embedding.word_embeddings.weight"
else:
src_name = re.sub(layer_re, repl_layer, dst_name)
print(f" > copying {src_name} to {dst_name} rank {rank}'s model")
# Translate destination layer number (0-N for each partition)
# to source layer number (single-model layer number)
src_name = re.sub(layer_re, update_layer_num, dst_name)
print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
partition_param.data.copy_(merged_params[src_name].data)
partitions.append(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment