@@ -170,6 +197,9 @@ class PreTrainedModel(nn.Module):
ifself.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
# Tie weights if needed
self.tie_weights()
defprune_heads(self,heads_to_prune):
""" Prunes heads of the base model.
...
...
@@ -178,14 +208,12 @@ class PreTrainedModel(nn.Module):
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
base_model=getattr(self,self.base_model_prefix,self)# get the base model if needed
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads