"comfy/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "1938f5c5fe479996802c46d5c2233887e3598a40"
Unverified Commit f6cdafde authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

fix load weights (#8528)

* fix load weights

* delete line
parent f6f4da8d
...@@ -108,6 +108,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): ...@@ -108,6 +108,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
continue continue
pointer = model pointer = model
array = tf_weights[txt_name] array = tf_weights[txt_name]
for m_name in name: for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name): if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name) scope_names = re.split(r"_(\d+)", m_name)
...@@ -115,12 +116,30 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): ...@@ -115,12 +116,30 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
scope_names = [m_name] scope_names = [m_name]
if scope_names[0] in ["kernel", "scale", "embedding"]: if scope_names[0] in ["kernel", "scale", "embedding"]:
pointer = getattr(pointer, "weight") pointer = getattr(pointer, "weight")
elif scope_names[0] == "self_attention":
pointer = getattr(pointer, "layer")
pointer = pointer[0]
elif scope_names[0] == "enc_dec_attention":
pointer = getattr(pointer, "layer")
pointer = pointer[1]
elif scope_names[0] == "dense_relu_dense":
pointer = getattr(pointer, "layer")
pointer = pointer[2]
elif scope_names[0] == "rms_norm":
if hasattr(pointer, "layer_norm"):
pointer = getattr(pointer, "layer_norm")
elif hasattr(pointer, "final_layer_norm"):
pointer = getattr(pointer, "final_layer_norm")
elif scope_names[0] == "scale": elif scope_names[0] == "scale":
pointer = getattr(pointer, "weight") pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta": elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias") pointer = getattr(pointer, "bias")
elif scope_names[0] == "squad": elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier") pointer = getattr(pointer, "classifier")
elif scope_names[0] == "decoder" and name[1] == "logits":
continue
elif scope_names[0] == "logits":
pointer = getattr(pointer, "lm_head")
else: else:
try: try:
pointer = getattr(pointer, scope_names[0]) pointer = getattr(pointer, scope_names[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment