Commit 0e1e8128 authored by thomwolf's avatar thomwolf
Browse files

more logging

parent 909d4f1a
......@@ -227,7 +227,7 @@ def run_model():
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
original_score = compute_metrics(task_name, preds, labels)[args.metric_name]
logger.info("Pruning: original score: %f", original_score)
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
new_head_mask = torch.ones_like(head_importance)
num_to_mask = int(new_head_mask.numel() * args.masking_amount)
......@@ -245,6 +245,7 @@ def run_model():
# mask heads
heads_to_mask = heads_to_mask[-num_to_mask:]
logger.info("Heads to mask: %s", str(heads_to_mask.tolist()))
new_head_mask = head_mask.view(-1)
new_head_mask[heads_to_mask] = 0.0
new_head_mask = new_head_mask.view_as(head_importance)
......@@ -254,7 +255,7 @@ def run_model():
_, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask)
preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
logger.info("Masking: current score: %f, remaning heads %.1f percents", current_score, head_mask.sum()/head_mask.numel() * 100)
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
# Try pruning and test time speedup
# Pruning is like masking but we actually remove the masked weights
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment