Unverified Commit 247bed38 authored by Leo Gao, committed by GitHub
Browse files

GPTNeo: handle padded wte (#11079)



* GPTNeo: handle padded wte

* Switch to config.vocab_size

* apply review suggestion
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 083ad7d4
@@ -112,6 +112,10 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
array = array.transpose()
if name == ["wte"]:
# if vocab is padded, then trim off the padding embeddings
array = array[: config.vocab_size]
try:
assert (
pointer.shape == array.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment