Commit 7766ce66 authored by thomwolf's avatar thomwolf
Browse files

update bertology

parent 7f00a36e
...@@ -281,9 +281,11 @@ def run_model(): ...@@ -281,9 +281,11 @@ def run_model():
score_masking = compute_metrics(task_name, preds, labels)[args.metric_name] score_masking = compute_metrics(task_name, preds, labels)[args.metric_name]
original_time = datetime.now() - before_time original_time = datetime.now() - before_time
original_num_params = sum(p.numel() for p in model.parameters())
heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask))) heads_to_prune = dict((layer, (1 - head_mask[layer].long()).nonzero().tolist()) for layer in range(len(head_mask)))
assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item() assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
model.bert.prune_heads(heads_to_prune) model.bert.prune_heads(heads_to_prune)
pruned_num_params = sum(p.numel() for p in model.parameters())
before_time = datetime.now() before_time = datetime.now()
_, _, preds, labels = compute_heads_importance(args, model, eval_dataloader, _, _, preds, labels = compute_heads_importance(args, model, eval_dataloader,
...@@ -292,6 +294,7 @@ def run_model(): ...@@ -292,6 +294,7 @@ def run_model():
score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name] score_pruning = compute_metrics(task_name, preds, labels)[args.metric_name]
new_time = datetime.now() - before_time new_time = datetime.now() - before_time
logger.info("Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)", original_num_params, pruned_num_params, pruned_num_params/original_num_params * 100)
logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning) logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time/new_time * 100) logger.info("Pruning: speed ratio (new timing / original timing): %f percents", original_time/new_time * 100)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment