Commit edfe91c3 authored by thomwolf's avatar thomwolf
Browse files

first version bertology ok

parent 7766ce66
...@@ -25,17 +25,20 @@ def entropy(p): ...@@ -25,17 +25,20 @@ def entropy(p):
plogp[p == 0] = 0 plogp[p == 0] = 0
return -plogp.sum(dim=-1) return -plogp.sum(dim=-1)
def print_1d_tensor(tensor, prefix=""): def print_1d_tensor(tensor, prefix=""):
if tensor.dtype != torch.long: if tensor.dtype != torch.long:
logger.info(prefix + "\t".join(f"{x:.5f}" for x in tensor.cpu().data)) logger.info(prefix + "\t".join(f"{x:.5f}" for x in tensor.cpu().data))
else: else:
logger.info(prefix + "\t".join(f"{x:d}" for x in tensor.cpu().data)) logger.info(prefix + "\t".join(f"{x:d}" for x in tensor.cpu().data))
def print_2d_tensor(tensor): def print_2d_tensor(tensor):
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor)))) logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
for row in range(len(tensor)): for row in range(len(tensor)):
print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t") print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None): def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None):
""" Example on how to use model outputs to compute: """ Example on how to use model outputs to compute:
- head attention entropy (activated by setting output_attentions=True when we created the model - head attention entropy (activated by setting output_attentions=True when we created the model
...@@ -54,7 +57,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, ...@@ -54,7 +57,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch input_ids, input_mask, segment_ids, label_ids = batch
# Do a forward pass (not in torch.no_grad() since we need gradients for importance score - see below) # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask) all_attentions, logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, head_mask=head_mask)
if compute_entropy: if compute_entropy:
...@@ -103,6 +106,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, ...@@ -103,6 +106,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
return attn_entropy, head_importance, preds, labels return attn_entropy, head_importance, preds, labels
def run_model(): def run_model():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint') parser.add_argument('--model_name_or_path', type=str, default='bert-base-cased-finetuned-mrpc', help='pretrained model name or path to local checkpoint')
...@@ -212,7 +216,7 @@ def run_model(): ...@@ -212,7 +216,7 @@ def run_model():
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
if args.data_subset > 0: if args.data_subset > 0:
eval_data = Subset(eval_data, list(range(args.data_subset))) eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))
eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data) eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
...@@ -246,14 +250,14 @@ def run_model(): ...@@ -246,14 +250,14 @@ def run_model():
logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold) logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)
new_head_mask = torch.ones_like(head_importance) new_head_mask = torch.ones_like(head_importance)
num_to_mask = int(new_head_mask.numel() * args.masking_amount) num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))
current_score = original_score current_score = original_score
while current_score >= original_score * args.masking_threshold: while current_score >= original_score * args.masking_threshold:
head_mask = new_head_mask # save current head mask head_mask = new_head_mask.clone() # save current head mask
# heads from most important to least - keep only not-masked heads # heads from least important to most - keep only not-masked heads
head_importance = head_importance.view(-1)[head_mask.view(-1).nonzero()][:, 0] head_importance[head_mask == 0.0] = float('Inf')
current_heads_to_mask = head_importance.sort()[1] current_heads_to_mask = head_importance.view(-1).sort()[1]
if len(current_heads_to_mask) <= num_to_mask: if len(current_heads_to_mask) <= num_to_mask:
break break
...@@ -261,7 +265,7 @@ def run_model(): ...@@ -261,7 +265,7 @@ def run_model():
# mask heads # mask heads
current_heads_to_mask = current_heads_to_mask[:num_to_mask] current_heads_to_mask = current_heads_to_mask[:num_to_mask]
logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist())) logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
new_head_mask = head_mask.view(-1) new_head_mask = new_head_mask.view(-1)
new_head_mask[current_heads_to_mask] = 0.0 new_head_mask[current_heads_to_mask] = 0.0
new_head_mask = new_head_mask.view_as(head_mask) new_head_mask = new_head_mask.view_as(head_mask)
print_2d_tensor(new_head_mask) print_2d_tensor(new_head_mask)
...@@ -272,6 +276,10 @@ def run_model(): ...@@ -272,6 +276,10 @@ def run_model():
current_score = compute_metrics(task_name, preds, labels)[args.metric_name] current_score = compute_metrics(task_name, preds, labels)[args.metric_name]
logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100) logger.info("Masking: current score: %f, remaning heads %d (%.1f percents)", current_score, new_head_mask.sum(), new_head_mask.sum()/new_head_mask.numel() * 100)
logger.info("Final head mask")
print_2d_tensor(head_mask)
np.save(os.path.join(args.output_dir, 'head_mask.npy'), head_mask.detach().cpu().numpy())
# Try pruning and test time speedup # Try pruning and test time speedup
# Pruning is like masking but we actually remove the masked weights # Pruning is like masking but we actually remove the masked weights
before_time = datetime.now() before_time = datetime.now()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment