Commit 0efc4ab6 authored by thomwolf

adding dropout to GPT-2 and embedding dropout to GPT

parent ea9dbea9
@@ -113,6 +113,9 @@ class GPT2Config(object):
                  n_embd=768,
                  n_layer=12,
                  n_head=12,
+                 resid_pdrop=0.1,
+                 embd_pdrop=0.1,
+                 attn_pdrop=0.1,
                  layer_norm_epsilon=1e-5,
                  initializer_range=0.02,
                  predict_special_tokens=True
@@ -129,6 +132,11 @@ class GPT2Config(object):
             n_head: Number of attention heads for each attention layer in
                 the Transformer encoder.
             layer_norm_epsilon: epsilon to use in the layer norm layers
+            resid_pdrop: The dropout probability for all fully connected
+                layers in the embeddings, encoder, and pooler.
+            attn_pdrop: The dropout ratio for the attention
+                probabilities.
+            embd_pdrop: The dropout ratio for the embeddings.
             initializer_range: The stddev of the truncated_normal_initializer for
                 initializing all weight matrices.
             predict_special_tokens: should we predict special tokens (when the model has a LM head)
@@ -147,6 +155,9 @@ class GPT2Config(object):
             self.n_embd = n_embd
             self.n_layer = n_layer
             self.n_head = n_head
+            self.resid_pdrop = resid_pdrop
+            self.embd_pdrop = embd_pdrop
+            self.attn_pdrop = attn_pdrop
             self.layer_norm_epsilon = layer_norm_epsilon
             self.initializer_range = initializer_range
             self.predict_special_tokens = predict_special_tokens
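For reference, a minimal sketch of how the three new probabilities can be set when building a config and model. The import path is an assumption based on the pytorch_pretrained_bert package layout around this commit; all three default to 0.1, as in the signature above.

# Minimal sketch: the dropout probabilities added by this commit, passed
# explicitly when building a GPT-2 config and model. Import path is an
# assumption based on the pytorch_pretrained_bert package of this era.
from pytorch_pretrained_bert import GPT2Config, GPT2Model

config = GPT2Config(
    resid_pdrop=0.1,  # dropout after the attention and MLP output projections
    embd_pdrop=0.1,   # dropout on the summed input embeddings
    attn_pdrop=0.1,   # dropout on the attention probabilities
)
model = GPT2Model(config)
model.train()  # the new dropout layers are only active in training mode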
@@ -221,6 +232,8 @@ class Attention(nn.Module):
         self.scale = scale
         self.c_attn = Conv1D(n_state * 3, nx)
         self.c_proj = Conv1D(n_state, nx)
+        self.attn_dropout = nn.Dropout(config.attn_pdrop)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)

     def _attn(self, q, k, v):
         w = torch.matmul(q, k)
@@ -231,6 +244,7 @@ class Attention(nn.Module):
         w = w * b - 1e4 * (1 - b)
         w = nn.Softmax(dim=-1)(w)
+        w = self.attn_dropout(w)
         return torch.matmul(w, v)

     def merge_heads(self, x):
@@ -260,6 +274,7 @@ class Attention(nn.Module):
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
+        a = self.resid_dropout(a)
         return a, present
@@ -270,11 +285,12 @@ class MLP(nn.Module):
         self.c_fc = Conv1D(n_state, nx)
         self.c_proj = Conv1D(nx, n_state)
         self.act = gelu
+        self.dropout = nn.Dropout(config.resid_pdrop)

     def forward(self, x):
         h = self.act(self.c_fc(x))
         h2 = self.c_proj(h)
-        return h2
+        return self.dropout(h2)


 class Block(nn.Module):
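To see where the three probabilities act, here is a self-contained, simplified single-head sketch (illustrative only, not the repo's Attention/MLP classes): attn_pdrop masks the softmaxed attention weights, and resid_pdrop masks the outputs of both projection layers, mirroring the hunks above.

import torch
import torch.nn as nn

# Illustrative single-head sketch (not the repo's Attention/MLP classes):
# attn_pdrop masks attention probabilities, resid_pdrop masks the outputs
# of the attention and MLP projections, as in the hunks above.
class TinyAttention(nn.Module):
    def __init__(self, n_embd, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()
        self.qkv = nn.Linear(n_embd, 3 * n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        w = torch.softmax(q @ k.transpose(-2, -1) / (k.size(-1) ** 0.5), dim=-1)
        w = self.attn_dropout(w)                     # attn_pdrop
        return self.resid_dropout(self.proj(w @ v))  # resid_pdrop

class TinyMLP(nn.Module):
    def __init__(self, n_embd, resid_pdrop=0.1):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(resid_pdrop)       # same resid_pdrop as attention

    def forward(self, x):
        return self.dropout(self.c_proj(self.act(self.c_fc(x))))

x = torch.randn(2, 5, 16)
print(TinyAttention(16)(x).shape, TinyMLP(16)(x).shape)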
@@ -323,6 +339,7 @@ class GPT2MultipleChoiceHead(nn.Module):
     def __init__(self, config):
         super(GPT2MultipleChoiceHead, self).__init__()
         self.n_embd = config.n_embd
+        self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
         self.linear = nn.Linear(config.n_embd, 1)

         nn.init.normal_(self.linear.weight, std=0.02)
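The nn.Dropout2d comment refers to TensorFlow's noise_shape argument, which lets a single dropout mask be broadcast across a dimension. Dropout2d zeroes whole channels, so if the pooled classification states are transposed to put the hidden dimension in the channel slot, the same hidden units are dropped for every choice. A hedged illustration of that effect (the head's actual forward pass is outside this hunk; the shapes below are made up):

import torch
import torch.nn as nn

# Illustration of the shared-mask effect (hypothetical shapes; the head's
# real forward pass is not shown in this diff). nn.Dropout2d zeroes whole
# channels, so moving the hidden dimension into the channel slot shares one
# dropout mask across the num_choices dimension, like TF's noise_shape.
drop = nn.Dropout2d(p=0.5)
drop.train()

pooled = torch.randn(2, 4, 8)                      # (batch, num_choices, hidden)
x = pooled.transpose(1, 2).unsqueeze(-1)           # (batch, hidden, num_choices, 1)
masked = drop(x).squeeze(-1).transpose(1, 2)       # back to (batch, num_choices, hidden)

# A hidden unit that was dropped is zero across all choices of that batch item.
shared = (masked == 0).all(dim=1)                  # (batch, hidden) dropped-unit mask
print(shared.sum().item(), "hidden units dropped consistently across choices")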
@@ -552,6 +569,7 @@ class GPT2Model(GPT2PreTrainedModel):
         super(GPT2Model, self).__init__(config)
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
+        self.drop = nn.Dropout(config.embd_pdrop)
         block = Block(config.n_ctx, config, scale=True)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -594,6 +612,8 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             token_type_embeds = 0
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
+        hidden_states = self.drop(hidden_states)
         presents = []
         for block, layer_past in zip(self.h, past):
             hidden_states, present = block(hidden_states, layer_past)
...
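The GPT2Model change applies a single embedding dropout to the summed token, position, and token-type embeddings before the block stack. A standalone sketch of that path (sizes are illustrative, not taken from the config):

import torch
import torch.nn as nn

# Standalone sketch of the embedding path after this commit (illustrative
# sizes): token, position and optional token-type embeddings are summed,
# then embd_pdrop dropout is applied once before the block stack.
wte = nn.Embedding(50257, 768)        # token embeddings
wpe = nn.Embedding(1024, 768)         # position embeddings
drop = nn.Dropout(0.1)                # embd_pdrop

input_ids = torch.randint(0, 50257, (1, 16))
position_ids = torch.arange(16).unsqueeze(0)

hidden_states = wte(input_ids) + wpe(position_ids)   # + token_type_embeds if given
hidden_states = drop(hidden_states)                  # new: embedding dropout
# hidden_states now feeds the first transformer block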
@@ -383,7 +383,6 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
     def __init__(self, config):
         super(OpenAIGPTMultipleChoiceHead, self).__init__()
         self.n_embd = config.n_embd
-        # self.multiple_choice_token = multiple_choice_token
         self.dropout = nn.Dropout2d(config.resid_pdrop)  # To reproduce the noise_shape parameter of TF implementation
         self.linear = nn.Linear(config.n_embd, 1)
@@ -651,9 +650,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
             token_type_embeds = self.tokens_embed(token_type_ids)
         else:
             token_type_embeds = 0
-        # Add the position information to the input embeddings
-        # h = e.sum(dim=2)
         hidden_states = inputs_embeds + position_embeds + token_type_embeds
+        hidden_states = self.drop(hidden_states)
         all_attentions = []
         for block in self.h:
             if self.output_attentions:
...