Commit e79ceb15 authored by thomwolf's avatar thomwolf
Browse files

gpt-2 special tokens

parent 1f5fc95b
...@@ -547,7 +547,7 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -547,7 +547,7 @@ class GPT2Model(GPT2PreTrainedModel):
def __init__(self, config): def __init__(self, config):
super(GPT2Model, self).__init__(config) super(GPT2Model, self).__init__(config)
self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd)
block = Block(config.n_ctx, config, scale=True) block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment