Commit e211785a authored by thomwolf's avatar thomwolf
Browse files

extract attention weights from GPT

parent db98a4a4
...@@ -253,7 +253,7 @@ class Conv1D(nn.Module): ...@@ -253,7 +253,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False): def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
super(Attention, self).__init__() super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...@@ -262,6 +262,7 @@ class Attention(nn.Module): ...@@ -262,6 +262,7 @@ class Attention(nn.Module):
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
self.output_attentions = output_attentions
self.c_attn = Conv1D(n_state * 3, 1, nx) self.c_attn = Conv1D(n_state * 3, 1, nx)
self.c_proj = Conv1D(n_state, 1, nx) self.c_proj = Conv1D(n_state, 1, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop) self.attn_dropout = nn.Dropout(config.attn_pdrop)
...@@ -278,6 +279,8 @@ class Attention(nn.Module): ...@@ -278,6 +279,8 @@ class Attention(nn.Module):
w = nn.Softmax(dim=-1)(w) w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w) w = self.attn_dropout(w)
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v) return torch.matmul(w, v)
def merge_heads(self, x): def merge_heads(self, x):
...@@ -300,9 +303,13 @@ class Attention(nn.Module): ...@@ -300,9 +303,13 @@ class Attention(nn.Module):
key = self.split_heads(key, k=True) key = self.split_heads(key, k=True)
value = self.split_heads(value) value = self.split_heads(value)
a = self._attn(query, key, value) a = self._attn(query, key, value)
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a) a = self.merge_heads(a)
a = self.c_proj(a) a = self.c_proj(a)
a = self.resid_dropout(a) a = self.resid_dropout(a)
if self.output_attentions:
return attentions, a
return a return a
...@@ -322,19 +329,24 @@ class MLP(nn.Module): ...@@ -322,19 +329,24 @@ class MLP(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False): def __init__(self, n_ctx, config, scale=False, output_attentions=False):
super(Block, self).__init__() super(Block, self).__init__()
nx = config.n_embd nx = config.n_embd
self.attn = Attention(nx, n_ctx, config, scale) self.output_attentions = output_attentions
self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config) self.mlp = MLP(4 * nx, config)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
def forward(self, x): def forward(self, x):
a = self.attn(x) a = self.attn(x)
if self.output_attentions:
attentions, a = a
n = self.ln_1(x + a) n = self.ln_1(x + a)
m = self.mlp(n) m = self.mlp(n)
h = self.ln_2(n + m) h = self.ln_2(n + m)
if self.output_attentions:
return attentions, h
return h return h
...@@ -591,12 +603,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -591,12 +603,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTModel, self).__init__(config) super(OpenAIGPTModel, self).__init__(config)
self.output_attentions = output_attentions
self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd) self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop) self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True) block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights) self.apply(self.init_weights)
...@@ -639,9 +652,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -639,9 +652,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Add the position information to the input embeddings # Add the position information to the input embeddings
# h = e.sum(dim=2) # h = e.sum(dim=2)
hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = inputs_embeds + position_embeds + token_type_embeds
all_attentions = []
for block in self.h: for block in self.h:
if self.output_attentions:
attentions, hidden_states = block(hidden_states)
all_attentions.append(attentions)
else:
hidden_states = block(hidden_states) hidden_states = block(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
if self.output_attentions:
return all_attentions, hidden_states.view(*output_shape)
return hidden_states.view(*output_shape) return hidden_states.view(*output_shape)
...@@ -701,9 +721,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -701,9 +721,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTLMHeadModel, self).__init__(config) super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights) self.apply(self.init_weights)
...@@ -716,6 +736,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -716,6 +736,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
if lm_labels is not None: if lm_labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
...@@ -726,6 +748,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -726,6 +748,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)) shift_labels.view(-1))
return loss return loss
if self.transformer.output_attentions:
return all_attentions, lm_logits
return lm_logits return lm_logits
...@@ -790,9 +814,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -790,9 +814,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config) super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config) self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights) self.apply(self.init_weights)
...@@ -806,6 +830,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -806,6 +830,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = [] losses = []
...@@ -819,4 +845,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -819,4 +845,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses: if losses:
return losses return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits
return lm_logits, mc_logits return lm_logits, mc_logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment