"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "132852203a02e320049457316a63cffb64968aa1"
Unverified Commit 247bed38 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

GPTNeo: handle padded wte (#11079)



* GPTNeo: handle padded wte

* Switch to config.vocab_size

* apply review suggestion
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 083ad7d4
......@@ -112,6 +112,10 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
if name[-1] == "w" and name[-2] in ["out_proj", "k_proj", "q_proj", "v_proj", "c_proj", "c_fc"]:
array = array.transpose()
if name == ["wte"]:
# if vocab is padded, then trim off the padding embeddings
array = array[: config.vocab_size]
try:
assert (
pointer.shape == array.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment