Commit 275179a0 authored by thomwolf

output attentions in GPT-2

parent 366a3b02
@@ -223,7 +223,7 @@ class Conv1D(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -232,6 +232,7 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+        self.output_attentions = output_attentions
         self.c_attn = Conv1D(n_state * 3, nx)
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -247,6 +248,8 @@ class Attention(nn.Module):
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
+        if self.output_attentions:
+            return w, torch.matmul(w, v)
         return torch.matmul(w, v)
 
     def merge_heads(self, x):
@@ -274,9 +277,13 @@ class Attention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         a = self._attn(query, key, value)
+        if self.output_attentions:
+            attentions, a = a
         a = self.merge_heads(a)
         a = self.c_proj(a)
         a = self.resid_dropout(a)
+        if self.output_attentions:
+            return attentions, a, present
         return a, present
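Note (not part of the diff): with `output_attentions` enabled, the weights returned by `_attn` are the post-softmax, post-dropout attention probabilities, one map per head. A minimal sketch of the expected shapes, assuming the usual GPT-2 head split; the sizes below are illustrative only:

```python
import torch

# Illustrative sizes (GPT-2 small uses 12 heads of 64 dims; sequence length is arbitrary)
batch, n_head, seq_len, head_dim = 2, 12, 5, 64

# Stand-in for w = dropout(softmax(q @ k / sqrt(d))): one [seq_len, seq_len] map per head
w = torch.softmax(torch.randn(batch, n_head, seq_len, seq_len), dim=-1)
v = torch.randn(batch, n_head, seq_len, head_dim)

# _attn now returns (attention weights, context) instead of just the context
attentions, context = w, torch.matmul(w, v)
print(attentions.shape)  # torch.Size([2, 12, 5, 5])
print(context.shape)     # torch.Size([2, 12, 5, 64])
```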
@@ -296,19 +303,26 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
         super(Block, self).__init__()
         nx = config.n_embd
+        self.output_attentions = output_attentions
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
 
     def forward(self, x, layer_past=None):
-        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
+        output_attn = self.attn(self.ln_1(x), layer_past=layer_past)
+        if self.output_attentions:
+            attentions, a, present = output_attn
+        else:
+            a, present = output_attn
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
+        if self.output_attentions:
+            return attentions, x, present
         return x, present
@@ -567,12 +581,13 @@ class GPT2Model(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2Model, self).__init__(config)
+        self.output_attentions = output_attentions
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -617,11 +632,18 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = self.drop(hidden_states)
         presents = []
+        all_attentions = []
         for block, layer_past in zip(self.h, past):
-            hidden_states, present = block(hidden_states, layer_past)
+            if self.output_attentions:
+                attentions, hidden_states, present = block(hidden_states, layer_past)
+                all_attentions.append(attentions)
+            else:
+                hidden_states, present = block(hidden_states, layer_past)
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
+        if self.output_attentions:
+            return all_attentions, hidden_states.view(*output_shape), presents
         return hidden_states.view(*output_shape), presents
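A hedged usage sketch of the modified `GPT2Model`: when constructed with `output_attentions=True`, `forward` returns the list of per-layer attention maps in front of the usual hidden states and presents. The import path and default `GPT2Config` hyper-parameters are assumptions about this version of the library, not something the diff itself establishes:

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2Model  # assumed package name for this era

config = GPT2Config()  # GPT-2 small defaults (12 layers, 12 heads)
model = GPT2Model(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 8))  # [batch, seq_len] of token ids
with torch.no_grad():
    all_attentions, hidden_states, presents = model(input_ids)

print(len(all_attentions))       # one tensor per layer (config.n_layer)
print(all_attentions[0].shape)   # expected [batch, n_head, seq_len, seq_len]
print(hidden_states.shape)       # [batch, seq_len, n_embd]
```

With `output_attentions=False` (the default) the call still returns the familiar `(hidden_states, presents)` pair, so existing callers are unaffected.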
@@ -669,9 +691,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2LMHeadModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)
@@ -684,7 +706,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
 
     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
@@ -695,6 +721,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, presents
         return lm_logits, presents
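Similarly for `GPT2LMHeadModel`: with `output_attentions` on and no `lm_labels` given, `forward` returns `(all_attentions, lm_logits, presents)`; when `lm_labels` are supplied it still returns only the loss, a path this commit leaves untouched. A hedged sketch, under the same import-path assumption as above:

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2LMHeadModel  # assumed package name

config = GPT2Config()
model = GPT2LMHeadModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 8))
with torch.no_grad():
    all_attentions, lm_logits, presents = model(input_ids)

print(lm_logits.shape)      # [batch, seq_len, vocab (+ special tokens)]
print(len(all_attentions))  # one attention map per layer
```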
@@ -747,9 +775,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2DoubleHeadsModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -763,7 +791,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
 
     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
@@ -777,4 +809,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
         if losses:
             return losses
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, mc_logits, presents
         return lm_logits, mc_logits, presents
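And for `GPT2DoubleHeadsModel`, the attention maps are prepended to the usual `(lm_logits, mc_logits, presents)` triple when no labels are passed. A hedged sketch with a two-choice input; the shapes follow the model's documented `[batch, num_choices, seq_len]` convention, and since the hidden states are flattened over choices before the blocks run, each attention map should come back with `batch * num_choices` in its first dimension:

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2DoubleHeadsModel  # assumed package name

config = GPT2Config()
model = GPT2DoubleHeadsModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 2, 7))          # [batch, num_choices, seq_len]
mc_token_ids = torch.full((1, 2), 6, dtype=torch.long)  # index of the classification token per choice

with torch.no_grad():
    all_attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids)

print(mc_logits.shape)  # [batch, num_choices]
print(lm_logits.shape)  # [batch, num_choices, seq_len, vocab (+ special tokens)]
```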