Unverified Commit 9419f144 authored by larekrow's avatar larekrow Committed by GitHub
Browse files

Fix convert_opt_original_pytorch_checkpoint_to_pytorch.py typo (#22526)

`load_checkpoint()` silently fails because `".qkj_proj." in key` is always `False`, but will eventually cause an error at `model.load_state_dict(state_dict)`.
parent a55a822a
...@@ -55,9 +55,9 @@ def load_checkpoint(checkpoint_path): ...@@ -55,9 +55,9 @@ def load_checkpoint(checkpoint_path):
keys = list(sd.keys()) keys = list(sd.keys())
for key in keys: for key in keys:
if ".qkj_proj." in key: if ".qkv_proj." in key:
value = sd[key] value = sd[key]
# We split QKV in seperate Q,K,V # We split QKV in separate Q,K,V
q_name = key.replace(".qkv_proj.", ".q_proj.") q_name = key.replace(".qkv_proj.", ".q_proj.")
k_name = key.replace(".qkv_proj.", ".k_proj.") k_name = key.replace(".qkv_proj.", ".k_proj.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment