Commit fc1fbae4 authored by LysandreJik

XLM can be pruned

parent 42e00cf9
@@ -559,6 +559,12 @@ class XLMModel(XLMPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
+        if hasattr(config, "pruned_heads"):
+            pruned_heads = config.pruned_heads.copy().items()
+            for layer, heads in pruned_heads:
+                if self.attentions[int(layer)].n_heads == config.n_heads:
+                    self.prune_heads({int(layer): list(map(int, heads))})
+
         self.apply(self.init_weights)
 
     def _resize_token_embeddings(self, new_num_tokens):
...
@@ -269,7 +269,6 @@ class CommonTestCases:
             shutil.rmtree(directory)
 
     def test_head_pruning_save_load_from_config_init(self):
-        print(self.test_pruning)
         if not self.test_pruning:
             return
...
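
For context: the new block in XLMModel.__init__ re-applies any head pruning recorded in the configuration when a model is built directly from that config; the n_heads check skips layers whose attention modules were already pruned before the weights were saved. The test change simply drops a leftover debug print. A rough usage sketch follows (not part of this commit; the top-level import and the exact pruned_heads format are assumptions inferred from the diff):

from pytorch_transformers import XLMConfig, XLMModel

config = XLMConfig()
# Heads to prune, recorded as {layer_index: [head_indices]} -- assumed format,
# matching the loop in the diff above.
config.pruned_heads = {0: [0, 1], 5: [2]}

# XLMModel.__init__ now calls self.prune_heads(...) for each recorded layer,
# so the pruning is re-applied automatically at construction time.
model = XLMModel(config)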