"tests/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "033cc4cfad07e8467ca6cb9e4401752c9c05d32c"
Unverified commit f6cdafde, authored by Patrick von Platen, committed by GitHub
Browse files

fix load weights (#8528)

* fix load weights

* delete line
parent f6f4da8d
......@@ -108,6 +108,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
continue
pointer = model
array = tf_weights[txt_name]
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
......@@ -115,12 +116,30 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
scope_names = [m_name]
if scope_names[0] in ["kernel", "scale", "embedding"]:
pointer = getattr(pointer, "weight")
elif scope_names[0] == "self_attention":
pointer = getattr(pointer, "layer")
pointer = pointer[0]
elif scope_names[0] == "enc_dec_attention":
pointer = getattr(pointer, "layer")
pointer = pointer[1]
elif scope_names[0] == "dense_relu_dense":
pointer = getattr(pointer, "layer")
pointer = pointer[2]
elif scope_names[0] == "rms_norm":
if hasattr(pointer, "layer_norm"):
pointer = getattr(pointer, "layer_norm")
elif hasattr(pointer, "final_layer_norm"):
pointer = getattr(pointer, "final_layer_norm")
elif scope_names[0] == "scale":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
elif scope_names[0] == "decoder" and name[1] == "logits":
continue
elif scope_names[0] == "logits":
pointer = getattr(pointer, "lm_head")
else:
try:
pointer = getattr(pointer, scope_names[0])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment