"...git@developer.sourcefind.cn:dcuai/dlexamples.git" did not exist on "82496fd438242f3904c61d2f2254913eaeb4b8e9"
Commit e4b46d86 authored by thomwolf
Browse files

update head pruning

parent 0f40e8d6
...@@ -92,7 +92,13 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, ...@@ -92,7 +92,13 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
# Normalize # Normalize
attn_entropy /= tot_tokens attn_entropy /= tot_tokens
head_importance /= tot_tokens head_importance /= tot_tokens
if args.normalize_importance: # Layerwise importance normalization
if not args.dont_normalize_importance_by_layer:
exponent = 2
norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1/exponent)
head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
if not args.dont_normalize_global_importance:
head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min()) head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
return attn_entropy, head_importance, preds, labels return attn_entropy, head_importance, preds, labels
...@@ -106,7 +112,8 @@ def run_model(): ...@@ -106,7 +112,8 @@ def run_model():
parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.") parser.add_argument("--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances.")
parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory") parser.add_argument("--overwrite_output_dir", action='store_true', help="Whether to overwrite data in output directory")
parser.add_argument("--normalize_importance", action='store_true', help="Whether to normalize importance score between 0 and 1") parser.add_argument("--dont_normalize_importance_by_layer", action='store_true', help="Don't normalize importance score by layers")
parser.add_argument("--dont_normalize_global_importance", action='store_true', help="Don't normalize all importance scores between 0 and 1")
parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.") parser.add_argument("--try_masking", action='store_true', help="Whether to try to mask head until a threshold of accuracy.")
parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics" parser.add_argument("--masking_threshold", default=0.9, type=float, help="masking threshold in term of metrics"
...@@ -243,21 +250,20 @@ def run_model(): ...@@ -243,21 +250,20 @@ def run_model():
current_score = original_score current_score = original_score
while current_score >= original_score * args.masking_threshold: while current_score >= original_score * args.masking_threshold:
head_mask = new_head_mask head_mask = new_head_mask # save current head mask
# heads from most important to least # heads from most important to least - keep only not-masked heads
heads_to_mask = head_importance.view(-1).sort(descending=True)[1] head_importance = head_importance.view(-1)[head_mask.view(-1).nonzero()][:, 0]
# keep only not-masked heads current_heads_to_mask = head_importance.sort()[1]
heads_to_mask = heads_to_mask[head_mask.view(-1).nonzero()][:, 0]
if len(heads_to_mask) <= num_to_mask: if len(current_heads_to_mask) <= num_to_mask:
break break
# mask heads # mask heads
heads_to_mask = heads_to_mask[-num_to_mask:] current_heads_to_mask = current_heads_to_mask[:num_to_mask]
logger.info("Heads to mask: %s", str(heads_to_mask.tolist())) logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
new_head_mask = head_mask.view(-1) new_head_mask = head_mask.view(-1)
new_head_mask[heads_to_mask] = 0.0 new_head_mask[current_heads_to_mask] = 0.0
new_head_mask = new_head_mask.view_as(head_importance) new_head_mask = new_head_mask.view_as(head_mask)
print_2d_tensor(new_head_mask) print_2d_tensor(new_head_mask)
# Compute metric and head importance again # Compute metric and head importance again
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment