Commit d9daad98 authored by Lysandre, committed by Lysandre Debut

Re-ordering of group_idx/layer_idx + Python 2 tests

parent 9d5c4954
@@ -281,11 +281,17 @@ class AlbertTransformer(nn.Module):
         if self.output_hidden_states:
             all_hidden_states = (hidden_states,)
 
-        for layer_idx in range(self.config.num_hidden_layers):
-            group_idx = int(layer_idx / self.config.num_hidden_layers * self.config.num_hidden_groups)
-            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
-            layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
+        for i in range(self.config.num_hidden_layers):
+            # Number of layers in a hidden group
+            layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
+
+            # Index of the hidden group
+            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
+
+            # Index of the layer inside the group
+            layer_idx = int(i - group_idx * layers_per_group)
+
+            layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group])
             hidden_states = layer_group_output[0]
 
             if self.output_attentions:
...
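For reference, a minimal standalone sketch of the re-ordered indexing above. The config values (12 hidden layers, 4 hidden groups) are illustrative assumptions, not taken from this commit; the point is that each absolute layer index i is first mapped to a hidden group, then to a position inside that group.

    # Standalone sketch of the index computation shown in the diff above.
    # num_hidden_layers=12 and num_hidden_groups=4 are assumed example values.
    num_hidden_layers = 12
    num_hidden_groups = 4

    # Number of layers in a hidden group
    layers_per_group = int(num_hidden_layers / num_hidden_groups)  # -> 3

    for i in range(num_hidden_layers):
        # Index of the hidden group that absolute layer i falls into
        group_idx = int(i / (num_hidden_layers / num_hidden_groups))
        # Index of the layer inside that group
        layer_idx = int(i - group_idx * layers_per_group)
        print(i, group_idx, layer_idx)

    # Mapping produced: i=0 -> (0, 0), i=1 -> (0, 1), i=2 -> (0, 2), i=3 -> (1, 0), ...

With this ordering, consecutive layers share a group and the slice head_mask[group_idx*layers_per_group:(group_idx+1)*layers_per_group] selects exactly the per-layer head masks belonging to that group.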