Unverified Commit 4026a049 authored by twaka's avatar twaka Committed by GitHub
Browse files

expand coverage of gpt2 model loading (#271)

parent 43710e8d
......@@ -228,11 +228,13 @@ class GPT2LMHeadModel(nn.Module):
# GPT-2 ties the weights of the embedding layer and the final
# linear layer.
continue
if ".attn.bias" in name:
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
name = "transformer." + name
if not name.startswith("transformer."):
name = "transformer." + name
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment