Commit 16263f96 authored by Lysandre, committed by Lysandre Debut

Headmasking

parent abb23a78
@@ -224,7 +224,7 @@ class AlbertLayer(nn.Module):
         self.activation = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        attention_output = self.attention(hidden_states, attention_mask)
+        attention_output = self.attention(hidden_states, attention_mask, head_mask)
         ffn_output = self.ffn(attention_output[0])
         ffn_output = self.activation(ffn_output)
         ffn_output = self.ffn_output(ffn_output)
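The diff only threads head_mask down into self.attention; for context, a minimal standalone sketch of how an attention module conventionally consumes such a mask (this mirrors the convention used elsewhere in the library, it is not code from this commit, and the shapes below are assumed for illustration):

import torch

# Hypothetical shapes: batch 1, 2 attention heads, sequence length 3.
attention_probs = torch.softmax(torch.randn(1, 2, 3, 3), dim=-1)

# head_mask broadcasts over the attention probabilities; a zero silences a whole head.
head_mask = torch.tensor([1.0, 0.0]).view(1, 2, 1, 1)   # keep head 0, mask head 1

# Masked heads contribute nothing to the weighted sum over the value vectors.
masked_probs = attention_probs * head_mask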
@@ -245,8 +245,8 @@ class AlbertLayerGroup(nn.Module):
         layer_hidden_states = ()
         layer_attentions = ()
 
-        for albert_layer in self.albert_layers:
-            layer_output = albert_layer(hidden_states, attention_mask, head_mask)
+        for layer_index, albert_layer in enumerate(self.albert_layers):
+            layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
             hidden_states = layer_output[0]
 
             if self.output_attentions:
@@ -283,7 +283,8 @@ class AlbertTransformer(nn.Module):
         for layer_idx in range(self.config.num_hidden_layers):
             group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups)
-            layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask)
+            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
+            layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
             hidden_states = layer_group_output[0]
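The slice hands each shared layer group only the per-layer masks for the layers it actually runs in this iteration. A small standalone sketch of the index arithmetic, using assumed config values rather than anything from the diff:

num_hidden_layers = 12
num_hidden_groups = 3   # assumed values for illustration only
layers_per_group = int(num_hidden_layers / num_hidden_groups)   # 4

for layer_idx in range(num_hidden_layers):
    group_idx = int(layer_idx / num_hidden_layers * num_hidden_groups)
    start, end = group_idx * layers_per_group, (group_idx + 1) * layers_per_group
    # layers 0..3  -> group 0, head_mask[0:4]
    # layers 4..7  -> group 1, head_mask[4:8]
    # layers 8..11 -> group 2, head_mask[8:12]
    print(layer_idx, group_idx, (start, end))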
@@ -544,7 +545,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None):
-        outputs = self.albert(input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None)
+        outputs = self.albert(input_ids, attention_mask, token_type_ids, position_ids, head_mask)
         sequence_outputs = outputs[0]
         prediction_scores = self.predictions(sequence_outputs)
@@ -35,7 +35,6 @@ else:
 class AlbertModelTest(CommonTestCases.CommonModelTester):
 
     all_model_classes = (AlbertModel, AlbertForMaskedLM) if is_torch_available() else ()
-    test_head_masking = False
 
     class AlbertModelTester(object):
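With the test_head_masking = False opt-out removed, the common head-masking tests now exercise ALBERT. A hedged end-to-end sketch of the argument this commit enables: the head_mask convention (shape (num_hidden_layers, num_attention_heads), 1.0 keeps a head, 0.0 masks it) and the 2D-mask preprocessing are assumed to match the library's other models, and the input ids below are dummy values.

import torch
from transformers import AlbertConfig, AlbertModel

config = AlbertConfig()          # library defaults
model = AlbertModel(config)
model.eval()

input_ids = torch.tensor([[2, 15, 2004, 3]])      # dummy token ids
head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads)
head_mask[0, :4] = 0.0           # silence the first four heads of the first layer

with torch.no_grad():
    outputs = model(input_ids, head_mask=head_mask)
sequence_output = outputs[0]     # (1, 4, hidden_size)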