Commit 25a31953 authored by Lysandre's avatar Lysandre Committed by Lysandre Debut
Browse files

Output Attentions + output hidden states

parent ce9eade2
......@@ -105,6 +105,7 @@ class AlbertAttention(BertSelfAttention):
def __init__(self, config):
super(AlbertAttention, self).__init__(config)
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.attention_head_size = config.hidden_size // config.num_attention_heads
......@@ -177,7 +178,7 @@ class AlbertAttention(BertSelfAttention):
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
projected_context_layer_dropout = self.dropout(projected_context_layer)
layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
return layernormed_context_layer
return (layernormed_context_layer, attention_probs) if self.output_attentions else (layernormed_context_layer,)
class AlbertLayer(nn.Module):
......@@ -193,25 +194,45 @@ class AlbertLayer(nn.Module):
def forward(self, hidden_states, attention_mask=None, head_mask=None):
attention_output = self.attention(hidden_states, attention_mask)
ffn_output = self.ffn(attention_output)
ffn_output = self.ffn(attention_output[0])
ffn_output = self.activation(ffn_output)
ffn_output = self.ffn_output(ffn_output)
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
return hidden_states
return (hidden_states,) + attention_output[1:] # add attentions if we output them
class AlbertLayerGroup(nn.Module):
def __init__(self, config):
super(AlbertLayerGroup, self).__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
def forward(self, hidden_states, attention_mask=None, head_mask=None):
layer_hidden_states = ()
layer_attentions = ()
for albert_layer in self.albert_layers:
hidden_states = albert_layer(hidden_states, attention_mask, head_mask)
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
layer_output = albert_layer(hidden_states, attention_mask, head_mask)
hidden_states = layer_output[0]
if self.output_attentions:
layer_attentions = layer_attentions + (layer_output[1],)
if self.output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
return hidden_states
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (layer_hidden_states,)
if self.output_attentions:
outputs = outputs + (layer_attentions,)
return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
class AlbertTransformer(nn.Module):
......@@ -227,11 +248,30 @@ class AlbertTransformer(nn.Module):
def forward(self, hidden_states, attention_mask=None, head_mask=None):
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_attentions = ()
if self.output_hidden_states:
all_hidden_states = (hidden_states,)
for layer_idx in range(self.config.num_hidden_layers):
group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups)
hidden_states = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask)
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask)
hidden_states = layer_group_output[0]
if self.output_attentions:
all_attentions = all_attentions + layer_group_output[1]
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment