Commit c85b5db6 authored by LysandreJik

Conditional append/init + fixed warning

parent 5c2b94c8
@@ -379,11 +379,15 @@ class PreTrainedModel(nn.Module):
             for head in heads:
                 if head not in self.config.pruned_heads[int(layer)]:
                     self.config.pruned_heads[int(layer)].append(head)
-                    to_be_pruned[int(layer)].append(head)
+                    if int(layer) in to_be_pruned:
+                        to_be_pruned[int(layer)].append(head)
+                    else:
+                        to_be_pruned[int(layer)] = [head]
                 else:
-                    logger.warning("Tried to remove head " + head +
-                                   " of layer " + layer +
-                                   " but it was already removed. The current removed heads are " + heads_to_prune)
+                    logger.warning("Tried to remove head " + str(head) +
+                                   " of layer " + str(layer) +
+                                   " but it was already removed. The current removed heads are " + str(heads_to_prune))
         base_model._prune_heads(to_be_pruned)
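As a standalone illustration (not part of the commit), here is a minimal sketch of the conditional append/init pattern introduced above, assuming a plain dict that maps a layer index to the list of head indices to prune; the input dict below is hypothetical:

# Standalone sketch: append to an existing entry, otherwise initialize it.
to_be_pruned = {}
heads_to_prune = {"0": [1, 2], "2": [3]}  # hypothetical input: {layer: [head, ...]}

for layer, heads in heads_to_prune.items():
    for head in heads:
        if int(layer) in to_be_pruned:
            to_be_pruned[int(layer)].append(head)
        else:
            to_be_pruned[int(layer)] = [head]

print(to_be_pruned)  # {0: [1, 2], 2: [3]}

# The same behaviour can be written with dict.setdefault:
#     to_be_pruned.setdefault(int(layer), []).append(head)

The str() calls added around head, layer, and heads_to_prune are the "fixed warning" part of the commit: concatenating an integer index to a string with + would raise a TypeError instead of logging the message.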