Unverified Commit 172f42c5 authored by tju_skywalker, committed by GitHub

Save space when converting a HF model to a Megatron model. (#25950)

* fix convert megatron model too large

* fix convert megatron model too large
parent b8def689
@@ -737,7 +737,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
         word_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "model.language_model.embedding.word_embeddings"
         )
-        word_emb_dict["weight"] = out_word_embed[i]
+        word_emb_dict["weight"] = out_word_embed[i].clone()

     # Transformer layers
     print("converting transformer layers")
@@ -845,7 +845,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
                 for i in range(args.target_tensor_model_parallel_size):
                     params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.language_model.encoder")
                     params_dict[layer_name] = (
-                        params[i] if (op_name + "." + weight_or_bias in tensor_parallel_params) else params
+                        params[i].clone() if (op_name + "." + weight_or_bias in tensor_parallel_params) else params
                     )

         if pp_rank == args.target_pipeline_model_parallel_size - 1:
@@ -860,7 +860,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
             # add the LM head
             for i in range(args.target_tensor_model_parallel_size):
                 params_dict = get_element_from_dict_by_path(output_state_dict[i], "model.word_embeddings_for_head")
-                params_dict["weight"] = out_word_embed[i]
+                params_dict["weight"] = out_word_embed[i].clone()

     # saving the state dict as per the tp_rank and pp_rank
     for tp_rank in range(args.target_tensor_model_parallel_size):
...
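
Why `.clone()` shrinks the saved shards: `out_word_embed[i]` and `params[i]` are slices taken from a larger tensor, so they are views, and `torch.save` serializes a view's entire underlying storage rather than just the slice. Cloning copies only the shard's own data before it is placed into the per-rank state dict. Below is a minimal sketch of the effect, not code from this PR; the shapes and the names `full_embed`, `shards`, and `saved_size` are illustrative assumptions.

```python
import io

import torch


def saved_size(tensor: torch.Tensor) -> int:
    """Serialize a tensor with torch.save and return the byte count."""
    buf = io.BytesIO()
    torch.save({"weight": tensor}, buf)
    return buf.getbuffer().nbytes


# Stand-in for a full (padded) word embedding matrix; real shapes depend
# on the checkpoint being converted.
full_embed = torch.randn(50304, 1024)

# torch.chunk returns views that share full_embed's storage.
shards = torch.chunk(full_embed, 4, dim=0)

# Saving a view drags the whole backing storage into the file ...
print(saved_size(shards[0]))           # roughly the size of the full matrix
# ... while saving a clone keeps only the shard's own data.
print(saved_size(shards[0].clone()))   # roughly a quarter of that
```

The same reasoning applies to the tensor-parallel transformer-layer parameters in the second hunk, where each rank's slice is cloned before being stored.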