"...resnet50_tensorflow.git" did not exist on "44fa1d377c81371a85256db57563d3e2016c7730"
Commit de5e5682 authored by VictorSanh

add output_attentions for BertModel

parent 275179a0
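
For context, here is a minimal usage sketch of the new flag once this change is applied. It is not part of the commit: the import path, the BertConfig values, and the dummy inputs are assumptions chosen for illustration.

```python
# Illustrative sketch only (not from this commit): build a BertModel with the
# new output_attentions flag and inspect what the forward pass returns.
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel  # assumed import path

config = BertConfig(vocab_size_or_config_json_file=30522)  # BERT-base-like defaults
model = BertModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))  # dummy (batch, seq_len) ids
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # With output_attentions=True the model returns three values instead of two.
    all_attentions, encoded_layers, pooled_output = model(
        input_ids, token_type_ids, attention_mask,
        output_all_encoded_layers=False)

# all_attentions holds one entry per transformer layer; encoded_layers is the
# final hidden state here (output_all_encoded_layers=False), and pooled_output
# is the pooled [CLS] representation.
print(len(all_attentions), encoded_layers.shape, pooled_output.shape)
```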
@@ -275,7 +275,7 @@ class BertEmbeddings(nn.Module):
 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
@@ -291,6 +291,8 @@ class BertSelfAttention(nn.Module):
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
+        self.output_attentions = output_attentions
+
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
@@ -322,7 +324,10 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        return context_layer
+        if self.output_attentions:
+            return attention_probs, context_layer
+        else:
+            return context_layer
 
 
 class BertSelfOutput(nn.Module):
@@ -381,33 +386,43 @@ class BertOutput(nn.Module):
 
 class BertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
         self.attention = BertAttention(config)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
+        self.output_attentions = output_attentions
 
     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
+        if self.output_attentions:
+            return attention_output, layer_output
         return layer_output
 
 
 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config)
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+        self.output_attentions = output_attentions
 
     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
+        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
+            if self.output_attentions:
+                attentions, hidden_states = hidden_states
+                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if self.output_attentions:
+            return all_attentions, all_encoder_layers
         return all_encoder_layers
@@ -699,12 +714,13 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
+        self.output_attentions = output_attentions
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
         if attention_mask is None:
@@ -731,10 +747,14 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        if self.output_attentions:
+            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
+        if self.output_attentions:
+            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
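
Since the return arity of BertEncoder and BertModel now depends on output_attentions, callers that consume these modules directly have to unpack conditionally. The helper below is a small hedged sketch of that pattern; the function name and structure are hypothetical and not part of this repository.

```python
# Hypothetical convenience helper (not in the repository): normalize BertModel
# outputs so downstream code does not care whether output_attentions was set.
def unpack_bert_outputs(outputs, output_attentions):
    """Return (all_attentions_or_None, encoded_layers, pooled_output)."""
    if output_attentions:
        all_attentions, encoded_layers, pooled_output = outputs
        return all_attentions, encoded_layers, pooled_output
    encoded_layers, pooled_output = outputs
    return None, encoded_layers, pooled_output

# Usage sketch:
# outputs = model(input_ids, token_type_ids, attention_mask)
# attentions, encoded_layers, pooled_output = unpack_bert_outputs(
#     outputs, model.output_attentions)
```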