Unverified Commit 4026a049 authored by twaka's avatar twaka Committed by GitHub
Browse files

expand coverage of gpt2 model loading (#271)

parent 43710e8d
@@ -228,11 +228,13 @@ class GPT2LMHeadModel(nn.Module):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue
-            if ".attn.bias" in name:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
                 # Skip attention mask.
                 # NOTE: "c_attn.bias" should not be skipped.
                 continue
-            name = "transformer." + name
+            if not name.startswith("transformer."):
+                name = "transformer." + name
             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
             # Because of this, we need to transpose the weights.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment