"torchvision/vscode:/vscode.git/clone" did not exist on "20414024d730677da60ccb7841f5e96aec6e1c9e"
Unverified Commit 4026a049 authored by twaka's avatar twaka Committed by GitHub
Browse files

expand coverage of gpt2 model loading (#271)

parent 43710e8d
...@@ -228,10 +228,12 @@ class GPT2LMHeadModel(nn.Module): ...@@ -228,10 +228,12 @@ class GPT2LMHeadModel(nn.Module):
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
# linear layer. # linear layer.
continue continue
if ".attn.bias" in name: if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask. # Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped. # NOTE: "c_attn.bias" should not be skipped.
continue continue
if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
# The HF's GPT-2 implementation uses Conv1D instead of Linear. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment