Unverified Commit 33cd4be5 authored by Zhengqiang Yin's avatar Zhengqiang Yin Committed by GitHub
Browse files

fix megatron bert convert state dict naming (#15820)

parent 9a2995ee
@@ -155,6 +155,7 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The simple map of names for "automated" rules.
     megatron_to_transformers = {
         "attention.dense": ".attention.output.dense.",
+        "self_attention.dense": ".attention.output.dense.",
         "mlp.dense_h_to_4h": ".intermediate.dense.",
         "mlp.dense_4h_to_h": ".output.dense.",
     }
@@ -188,7 +189,9 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
             output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

         # Transpose the QKV matrix.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "weight":

             # Make sure the QKV pointer is nil.
             assert attention_qkv_weight is None, ""
@@ -198,7 +201,9 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
             attention_qkv_weight = out_val

         # Transpose the bias.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "bias":

             # Make sure we read the weight tensor.
             assert attention_qkv_weight is not None, ""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment