Unverified Commit 98d5b727 authored by Thomas Wang's avatar Thomas Wang Committed by GitHub
Browse files

Update OPT conversion script to work for OPT-IML (#21519)

parent fe616f35
......@@ -53,6 +53,27 @@ def load_checkpoint(checkpoint_path):
if old_key in sd:
sd[new_key] = sd.pop(old_key)
keys = list(sd.keys())
for key in keys:
if ".qkv_proj." in key:
value = sd[key]
# We split QKV into separate Q, K, V
q_name = key.replace(".qkv_proj.", ".q_proj.")
k_name = key.replace(".qkv_proj.", ".k_proj.")
v_name = key.replace(".qkv_proj.", ".v_proj.")
depth = value.shape[0]
assert depth % 3 == 0
# `SequenceParallelTransformerBlock` stores its QKV weight split as K, V, Q despite the naming:
# https://cs.github.com/facebookresearch/metaseq/blob/51871bd73cd04c038f239ea2a26db1d7f6b37927/metaseq/modules/sequence_parallel_transformer_layer.py#L97
k, v, q = torch.split(value, depth // 3, dim=0)
sd[q_name] = q
sd[k_name] = k
sd[v_name] = v
del sd[key]
return sd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment