@@ -355,14 +359,30 @@ class PreTrainedModel(nn.Module):
returnmodel_embeds
definit_weights(self):
""" Initialize and prunes weights if needed. """
# Initialize weights
self.apply(self._init_weights)
# Prune heads if needed
ifself.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
defprune_heads(self,heads_to_prune):
""" Prunes heads of the base model.
Arguments:
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
base_model=getattr(self,self.base_model_prefix,self)# get the base model if needed
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads