Commit 275179a0 authored by thomwolf

output attentions in GPT-2

parent 366a3b02
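The change threads an output_attentions flag from GPT2Model down through Block and Attention, so a forward pass can also return the per-layer attention weights. A minimal usage sketch, not part of the diff (import path, default config values, and the random input below are assumptions):

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2Model  # assumed import path

config = GPT2Config()                              # default GPT-2 hyperparameters
model = GPT2Model(config, output_attentions=True)  # the new flag added in this commit
model.eval()                                       # randomly initialized weights; shape check only

input_ids = torch.randint(0, 50257, (1, 8))        # (batch_size, seq_len), arbitrary token ids
with torch.no_grad():
    all_attentions, hidden_states, presents = model(input_ids)

# one attention tensor per layer, each (batch_size, n_head, seq_len, seq_len)
print(len(all_attentions), all_attentions[0].shape)
print(hidden_states.shape)                         # (1, 8, config.n_embd)
```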
@@ -223,7 +223,7 @@ class Conv1D(nn.Module):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -232,6 +232,7 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+        self.output_attentions = output_attentions
         self.c_attn = Conv1D(n_state * 3, nx)
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -247,6 +248,8 @@ class Attention(nn.Module):
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
+        if self.output_attentions:
+            return w, torch.matmul(w, v)
         return torch.matmul(w, v)

     def merge_heads(self, x):
@@ -274,9 +277,13 @@ class Attention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)
         present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         a = self._attn(query, key, value)
+        if self.output_attentions:
+            attentions, a = a
         a = self.merge_heads(a)
         a = self.c_proj(a)
         a = self.resid_dropout(a)
+        if self.output_attentions:
+            return attentions, a, present
         return a, present
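For reference, the weights that _attn now passes up are the scores after the softmax (and after attn_dropout, so during training some entries are zeroed and rescaled; in eval mode each row sums to 1). Below is a standalone sketch of the equivalent computation with made-up shapes and a simplified causal mask, not the module's exact code:

```python
import math
import torch

batch, n_head, seq, d_head = 1, 12, 8, 64          # illustrative sizes only
q = torch.randn(batch, n_head, seq, d_head)
k = torch.randn(batch, n_head, seq, d_head)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_head)  # scaled dot product
causal = torch.tril(torch.ones(seq, seq)).view(1, 1, seq, seq)     # lower-triangular mask
scores = scores * causal - 1e10 * (1 - causal)                     # hide future positions
w = torch.softmax(scores, dim=-1)                                  # (batch, n_head, seq, seq)

print(w.shape)
print(torch.allclose(w.sum(dim=-1), torch.ones(batch, n_head, seq)))  # rows sum to 1
```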
@@ -296,19 +303,26 @@ class MLP(nn.Module):
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
         super(Block, self).__init__()
         nx = config.n_embd
+        self.output_attentions = output_attentions
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)

     def forward(self, x, layer_past=None):
-        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
+        output_attn = self.attn(self.ln_1(x), layer_past=layer_past)
+        if self.output_attentions:
+            attentions, a, present = output_attn
+        else:
+            a, present = output_attn
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
+        if self.output_attentions:
+            return attentions, x, present
         return x, present
@@ -567,12 +581,13 @@ class GPT2Model(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2Model, self).__init__(config)
+        self.output_attentions = output_attentions
         self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
         self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -617,11 +632,18 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = self.drop(hidden_states)
         presents = []
+        all_attentions = []
         for block, layer_past in zip(self.h, past):
-            hidden_states, present = block(hidden_states, layer_past)
+            if self.output_attentions:
+                attentions, hidden_states, present = block(hidden_states, layer_past)
+                all_attentions.append(attentions)
+            else:
+                hidden_states, present = block(hidden_states, layer_past)
             presents.append(present)
         hidden_states = self.ln_f(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
+        if self.output_attentions:
+            return all_attentions, hidden_states.view(*output_shape), presents
         return hidden_states.view(*output_shape), presents
@@ -669,9 +691,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2LMHeadModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)
@@ -684,7 +706,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
@@ -695,6 +721,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             return loss
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, presents
         return lm_logits, presents
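A hedged sketch of consuming the LM-head model's new return signature (assumed import path and default config; weights are randomly initialized, so the logits are only good for a shape check):

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2LMHeadModel  # assumed import path

config = GPT2Config()
model = GPT2LMHeadModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 8))       # (batch_size, seq_len)
with torch.no_grad():
    # without lm_labels, the loss branch is skipped and the tuple below is returned
    all_attentions, lm_logits, presents = model(input_ids)

print(lm_logits.shape)      # (1, 8, vocab size): next-token logits
print(len(all_attentions))  # one attention map per layer (config.n_layer)
```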
@@ -747,9 +775,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(GPT2DoubleHeadsModel, self).__init__(config)
-        self.transformer = GPT2Model(config)
+        self.transformer = GPT2Model(config, output_attentions=output_attentions)
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -763,7 +791,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
-        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
+        transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
+        if self.transformer.output_attentions:
+            all_attentions, hidden_states, presents = transformer_output
+        else:
+            hidden_states, presents = transformer_output
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
@@ -777,4 +809,6 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
         if losses:
             return losses
+        if self.transformer.output_attentions:
+            return all_attentions, lm_logits, mc_logits, presents
         return lm_logits, mc_logits, presents
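And a similar sketch for the double-heads model, whose forward additionally takes mc_token_ids; shapes follow the (batch, num_choices, seq_len) convention from the existing docstrings, and all values are illustrative:

```python
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2DoubleHeadsModel  # assumed import path

config = GPT2Config()
model = GPT2DoubleHeadsModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 50257, (1, 2, 8))  # (batch_size, num_choices, seq_len)
mc_token_ids = torch.tensor([[7, 7]])           # classification-token position per choice
with torch.no_grad():
    # with no lm_labels/mc_labels, losses stay empty and the tuple below is returned
    all_attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids)

print(lm_logits.shape)  # (1, 2, 8, vocab size)
print(mc_logits.shape)  # (1, 2): one score per choice
```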