"vscode:/vscode.git/clone" did not exist on "c69ea5efc4eac65b183e8d07b1bf91d20bbe0c8c"
Commit 12290c0d authored by Lysandre, committed by Lysandre Debut

Handles multi layer and multi groups

parent 139affaa
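
The commit message refers to ALBERT's two weight-sharing knobs: `num_hidden_groups` (how many distinct parameter groups the hidden layers are partitioned into) and `inner_group_num` (how many layers run inside each group). As a minimal sketch of the layer-to-group mapping, assuming the evenly-split convention used by the ALBERT config (the helper below is illustrative only and not part of this commit):

```python
def layer_to_group(layer_idx: int, num_hidden_layers: int, num_hidden_groups: int) -> int:
    """Map an absolute layer index to the parameter group that serves it.

    Layers are split into `num_hidden_groups` contiguous, equally sized blocks,
    so consecutive layers reuse the same group's weights.
    """
    layers_per_group = num_hidden_layers // num_hidden_groups
    return layer_idx // layers_per_group


# With 12 hidden layers and 1 group (the ALBERT default), every layer resolves
# to group 0, i.e. full cross-layer parameter sharing.
print([layer_to_group(i, 12, 1) for i in range(12)])  # [0, 0, ..., 0]

# With 12 hidden layers and 3 groups, layers 0-3 share group 0, 4-7 group 1, 8-11 group 2.
print([layer_to_group(i, 12, 3) for i in range(12)])  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```

With the default `num_hidden_groups=1`, every layer reuses the same weights, which is what keeps ALBERT's parameter count small regardless of depth.
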
@@ -136,7 +136,6 @@ class AlbertModel(BertModel):
                                        head_mask=head_mask)
         sequence_output = encoder_outputs[0]
 
-        print(sequence_output.shape, sequence_output[:, 0].shape, self.pooler(sequence_output[:, 0]).shape)
         pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
 
         outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
@@ -260,7 +259,6 @@ class AlbertLayer(nn.Module):
         self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
 
     def forward(self, hidden_states, attention_mask=None, head_mask=None):
-        for _ in range(self.config.inner_group_num):
         attention_output = self.attention(hidden_states, attention_mask)[0]
         ffn_output = self.ffn(attention_output)
         ffn_output = gelu_new(ffn_output)
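
The hunk above only shows the `inner_group_num` loop being dropped from `AlbertLayer.forward`; a plausible reading, consistent with the ALBERT structure later shipped in `transformers`, is that the per-group iteration moves up into a layer-group container so each `AlbertLayer` stays a single attention + FFN block. A hedged sketch of such a container (the class and attribute names are assumptions, not taken from this diff; `AlbertLayer` is the module defined in the file being modified):

```python
import torch.nn as nn


class AlbertLayerGroup(nn.Module):
    """Runs `inner_group_num` full AlbertLayer blocks back to back."""

    def __init__(self, config):
        super().__init__()
        self.albert_layers = nn.ModuleList(
            [AlbertLayer(config) for _ in range(config.inner_group_num)]
        )

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        # Each inner layer consumes the previous layer's hidden states.
        for layer in self.albert_layers:
            hidden_states = layer(hidden_states, attention_mask, head_mask)[0]
        return (hidden_states,)
```
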
@@ -303,16 +301,16 @@ class AlbertTransformer(nn.Module):
         return (hidden_states,)
 
-model_size = 'base'
-hidden_groups = 1
-inner_groups = 1
-config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
-model = AlbertModel(config)
-print(model)
-model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
-model.eval()
-print(sum(p.numel() for p in model.parameters() if p.requires_grad))
+# model_size = 'base'
+# hidden_groups = 1
+# inner_groups = 2
+# config = AlbertConfig.from_json_file("/home/hf/google-research/albert/config_{}-{}-hg-{}-ig.json".format(model_size, hidden_groups, inner_groups))
+# model = AlbertModel(config)
+# # print(model)
+# model = load_tf_weights_in_albert(model, config, "/home/hf/transformers/albert-{}-{}-hg-{}-ig/albert-{}-{}-hg-{}-ig".format(model_size, hidden_groups, inner_groups, model_size, hidden_groups, inner_groups))
+# # model.eval()
+# # print(sum(p.numel() for p in model.parameters() if p.requires_grad))
 # input_ids = [[31, 51, 99, 88, 54, 34, 23, 23, 12], [15, 5, 0, 88, 54, 34, 23, 23, 12]]
...
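
The last hunk comments out the ad-hoc conversion driver; the hard-coded `/home/hf/...` config and checkpoint paths are machine-specific and are left as-is above. A rough equivalent of the same smoke test against the public API, using made-up hyperparameters rather than the author's JSON configs, might look like this:

```python
from transformers import AlbertConfig, AlbertModel  # assumes the released package layout

# Hypothetical small config exercising both sharing knobs touched by this commit.
config = AlbertConfig(
    hidden_size=768,
    num_hidden_layers=12,
    num_hidden_groups=1,
    inner_group_num=2,
)
model = AlbertModel(config)
model.eval()

# Parameter count is a quick sanity check: thanks to cross-layer sharing it should
# not grow with num_hidden_layers, only with the number of distinct groups and the
# number of layers inside each group.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
```
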