Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
edfe91c3
Commit
edfe91c3
authored
Jun 19, 2019
by
thomwolf
Browse files
first version bertology ok
parent
7766ce66
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
8 deletions
+16
-8
examples/bertology.py
examples/bertology.py
+16
-8
No files found.
examples/bertology.py
View file @
edfe91c3
...
...
@@ -25,17 +25,20 @@ def entropy(p):
plogp
[
p
==
0
]
=
0
return
-
plogp
.
sum
(
dim
=-
1
)
def print_1d_tensor(tensor, prefix=""):
    """Log every element of a tensor on a single tab-separated line.

    Args:
        tensor: tensor whose elements are iterated one by one; it is moved
            to CPU before formatting. Presumably 1D - verify against callers.
        prefix: text written in front of the joined values.

    Elements of a ``torch.long`` tensor are formatted as integers; any other
    dtype is formatted with five decimal places.
    """
    values = tensor.cpu().data
    if tensor.dtype == torch.long:
        cells = [f"{value:d}" for value in values]
    else:
        cells = [f"{value:.5f}" for value in values]
    logger.info(prefix + "\t".join(cells))
def print_2d_tensor(tensor):
    """Log a 2D tensor as a tab-separated table, one logged line per row.

    A header line numbers the columns starting at 1, then each row is logged
    through ``print_1d_tensor`` with a ``layer <n>:`` prefix (rows are also
    numbered from 1).

    Args:
        tensor: indexable collection of rows; each row is handed to
            ``print_1d_tensor``. Assumed to be a 2D tensor whose rows all
            have the same length - TODO confirm against callers.
    """
    num_rows = len(tensor)
    # The header labels the columns, so it must be sized from the length of
    # a row rather than from the number of rows; the two only coincide for
    # square tensors.
    num_cols = len(tensor[0]) if num_rows else 0
    logger.info("lv, h >\t" + "\t".join(f"{col + 1}" for col in range(num_cols)))
    for row in range(num_rows):
        print_1d_tensor(tensor[row], prefix=f"layer {row + 1}:\t")
def
compute_heads_importance
(
args
,
model
,
eval_dataloader
,
compute_entropy
=
True
,
compute_importance
=
True
,
head_mask
=
None
):
""" Example on how to use model outputs to compute:
- head attention entropy (activated by setting output_attentions=True when we created the model
...
...
@@ -54,7 +57,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
# Do a forward pass (not
in
torch.no_grad() since we need gradients for importance score - see below)
# Do a forward pass (not
with
torch.no_grad() since we need gradients for importance score - see below)
all_attentions
,
logits
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
,
head_mask
=
head_mask
)
if
compute_entropy
:
...
...
@@ -103,6 +106,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
return
attn_entropy
,
head_importance
,
preds
,
labels
def
run_model
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name_or_path'
,
type
=
str
,
default
=
'bert-base-cased-finetuned-mrpc'
,
help
=
'pretrained model name or path to local checkpoint'
)
...
...
@@ -212,7 +216,7 @@ def run_model():
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_label_ids
)
if
args
.
data_subset
>
0
:
eval_data
=
Subset
(
eval_data
,
list
(
range
(
args
.
data_subset
)))
eval_data
=
Subset
(
eval_data
,
list
(
range
(
min
(
args
.
data_subset
,
len
(
eval_data
))
)))
eval_sampler
=
SequentialSampler
(
eval_data
)
if
args
.
local_rank
==
-
1
else
DistributedSampler
(
eval_data
)
eval_dataloader
=
DataLoader
(
eval_data
,
sampler
=
eval_sampler
,
batch_size
=
args
.
batch_size
)
...
...
@@ -246,14 +250,14 @@ def run_model():
logger
.
info
(
"Pruning: original score: %f, threshold: %f"
,
original_score
,
original_score
*
args
.
masking_threshold
)
new_head_mask
=
torch
.
ones_like
(
head_importance
)
num_to_mask
=
int
(
new_head_mask
.
numel
()
*
args
.
masking_amount
)
num_to_mask
=
max
(
1
,
int
(
new_head_mask
.
numel
()
*
args
.
masking_amount
)
)
current_score
=
original_score
while
current_score
>=
original_score
*
args
.
masking_threshold
:
head_mask
=
new_head_mask
# save current head mask
# heads from
mo
st important to
lea
st - keep only not-masked heads
head_importance
=
head_importance
.
view
(
-
1
)[
head_mask
.
view
(
-
1
).
nonzero
()][:,
0
]
current_heads_to_mask
=
head_importance
.
sort
()[
1
]
head_mask
=
new_head_mask
.
clone
()
# save current head mask
# heads from
lea
st important to
mo
st - keep only not-masked heads
head_importance
[
head_mask
==
0.0
]
=
float
(
'Inf'
)
current_heads_to_mask
=
head_importance
.
view
(
-
1
).
sort
()[
1
]
if
len
(
current_heads_to_mask
)
<=
num_to_mask
:
break
...
...
@@ -261,7 +265,7 @@ def run_model():
# mask heads
current_heads_to_mask
=
current_heads_to_mask
[:
num_to_mask
]
logger
.
info
(
"Heads to mask: %s"
,
str
(
current_heads_to_mask
.
tolist
()))
new_head_mask
=
head_mask
.
view
(
-
1
)
new_head_mask
=
new_
head_mask
.
view
(
-
1
)
new_head_mask
[
current_heads_to_mask
]
=
0.0
new_head_mask
=
new_head_mask
.
view_as
(
head_mask
)
print_2d_tensor
(
new_head_mask
)
...
...
@@ -272,6 +276,10 @@ def run_model():
current_score
=
compute_metrics
(
task_name
,
preds
,
labels
)[
args
.
metric_name
]
logger
.
info
(
"Masking: current score: %f, remaning heads %d (%.1f percents)"
,
current_score
,
new_head_mask
.
sum
(),
new_head_mask
.
sum
()
/
new_head_mask
.
numel
()
*
100
)
logger
.
info
(
"Final head mask"
)
print_2d_tensor
(
head_mask
)
np
.
save
(
os
.
path
.
join
(
args
.
output_dir
,
'head_mask.npy'
),
head_mask
.
detach
().
cpu
().
numpy
())
# Try pruning and test time speedup
# Pruning is like masking but we actually remove the masked weights
before_time
=
datetime
.
now
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment