Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
edfe91c3
Commit
edfe91c3
authored
Jun 19, 2019
by
thomwolf
Browse files
first version bertology ok
parent
7766ce66
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
8 deletions
+16
-8
examples/bertology.py
examples/bertology.py
+16
-8
No files found.
examples/bertology.py
View file @
edfe91c3
...
@@ -25,17 +25,20 @@ def entropy(p):
...
@@ -25,17 +25,20 @@ def entropy(p):
plogp
[
p
==
0
]
=
0
plogp
[
p
==
0
]
=
0
return
-
plogp
.
sum
(
dim
=-
1
)
return
-
plogp
.
sum
(
dim
=-
1
)
def
print_1d_tensor
(
tensor
,
prefix
=
""
):
def
print_1d_tensor
(
tensor
,
prefix
=
""
):
if
tensor
.
dtype
!=
torch
.
long
:
if
tensor
.
dtype
!=
torch
.
long
:
logger
.
info
(
prefix
+
"
\t
"
.
join
(
f
"
{
x
:.
5
f
}
"
for
x
in
tensor
.
cpu
().
data
))
logger
.
info
(
prefix
+
"
\t
"
.
join
(
f
"
{
x
:.
5
f
}
"
for
x
in
tensor
.
cpu
().
data
))
else
:
else
:
logger
.
info
(
prefix
+
"
\t
"
.
join
(
f
"
{
x
:
d
}
"
for
x
in
tensor
.
cpu
().
data
))
logger
.
info
(
prefix
+
"
\t
"
.
join
(
f
"
{
x
:
d
}
"
for
x
in
tensor
.
cpu
().
data
))
def
print_2d_tensor
(
tensor
):
def
print_2d_tensor
(
tensor
):
logger
.
info
(
"lv, h >
\t
"
+
"
\t
"
.
join
(
f
"
{
x
+
1
}
"
for
x
in
range
(
len
(
tensor
))))
logger
.
info
(
"lv, h >
\t
"
+
"
\t
"
.
join
(
f
"
{
x
+
1
}
"
for
x
in
range
(
len
(
tensor
))))
for
row
in
range
(
len
(
tensor
)):
for
row
in
range
(
len
(
tensor
)):
print_1d_tensor
(
tensor
[
row
],
prefix
=
f
"layer
{
row
+
1
}
:
\t
"
)
print_1d_tensor
(
tensor
[
row
],
prefix
=
f
"layer
{
row
+
1
}
:
\t
"
)
def
compute_heads_importance
(
args
,
model
,
eval_dataloader
,
compute_entropy
=
True
,
compute_importance
=
True
,
head_mask
=
None
):
def
compute_heads_importance
(
args
,
model
,
eval_dataloader
,
compute_entropy
=
True
,
compute_importance
=
True
,
head_mask
=
None
):
""" Example on how to use model outputs to compute:
""" Example on how to use model outputs to compute:
- head attention entropy (activated by setting output_attentions=True when we created the model
- head attention entropy (activated by setting output_attentions=True when we created the model
...
@@ -54,7 +57,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
...
@@ -54,7 +57,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
# Do a forward pass (not
in
torch.no_grad() since we need gradients for importance score - see below)
# Do a forward pass (not
with
torch.no_grad() since we need gradients for importance score - see below)
all_attentions
,
logits
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
,
head_mask
=
head_mask
)
all_attentions
,
logits
=
model
(
input_ids
,
token_type_ids
=
segment_ids
,
attention_mask
=
input_mask
,
head_mask
=
head_mask
)
if
compute_entropy
:
if
compute_entropy
:
...
@@ -103,6 +106,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
...
@@ -103,6 +106,7 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
return
attn_entropy
,
head_importance
,
preds
,
labels
return
attn_entropy
,
head_importance
,
preds
,
labels
def
run_model
():
def
run_model
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--model_name_or_path'
,
type
=
str
,
default
=
'bert-base-cased-finetuned-mrpc'
,
help
=
'pretrained model name or path to local checkpoint'
)
parser
.
add_argument
(
'--model_name_or_path'
,
type
=
str
,
default
=
'bert-base-cased-finetuned-mrpc'
,
help
=
'pretrained model name or path to local checkpoint'
)
...
@@ -212,7 +216,7 @@ def run_model():
...
@@ -212,7 +216,7 @@ def run_model():
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_label_ids
)
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_label_ids
)
if
args
.
data_subset
>
0
:
if
args
.
data_subset
>
0
:
eval_data
=
Subset
(
eval_data
,
list
(
range
(
args
.
data_subset
)))
eval_data
=
Subset
(
eval_data
,
list
(
range
(
min
(
args
.
data_subset
,
len
(
eval_data
))
)))
eval_sampler
=
SequentialSampler
(
eval_data
)
if
args
.
local_rank
==
-
1
else
DistributedSampler
(
eval_data
)
eval_sampler
=
SequentialSampler
(
eval_data
)
if
args
.
local_rank
==
-
1
else
DistributedSampler
(
eval_data
)
eval_dataloader
=
DataLoader
(
eval_data
,
sampler
=
eval_sampler
,
batch_size
=
args
.
batch_size
)
eval_dataloader
=
DataLoader
(
eval_data
,
sampler
=
eval_sampler
,
batch_size
=
args
.
batch_size
)
...
@@ -246,14 +250,14 @@ def run_model():
...
@@ -246,14 +250,14 @@ def run_model():
logger
.
info
(
"Pruning: original score: %f, threshold: %f"
,
original_score
,
original_score
*
args
.
masking_threshold
)
logger
.
info
(
"Pruning: original score: %f, threshold: %f"
,
original_score
,
original_score
*
args
.
masking_threshold
)
new_head_mask
=
torch
.
ones_like
(
head_importance
)
new_head_mask
=
torch
.
ones_like
(
head_importance
)
num_to_mask
=
int
(
new_head_mask
.
numel
()
*
args
.
masking_amount
)
num_to_mask
=
max
(
1
,
int
(
new_head_mask
.
numel
()
*
args
.
masking_amount
)
)
current_score
=
original_score
current_score
=
original_score
while
current_score
>=
original_score
*
args
.
masking_threshold
:
while
current_score
>=
original_score
*
args
.
masking_threshold
:
head_mask
=
new_head_mask
# save current head mask
head_mask
=
new_head_mask
.
clone
()
# save current head mask
# heads from
mo
st important to
lea
st - keep only not-masked heads
# heads from
lea
st important to
mo
st - keep only not-masked heads
head_importance
=
head_importance
.
view
(
-
1
)[
head_mask
.
view
(
-
1
).
nonzero
()][:,
0
]
head_importance
[
head_mask
==
0.0
]
=
float
(
'Inf'
)
current_heads_to_mask
=
head_importance
.
sort
()[
1
]
current_heads_to_mask
=
head_importance
.
view
(
-
1
).
sort
()[
1
]
if
len
(
current_heads_to_mask
)
<=
num_to_mask
:
if
len
(
current_heads_to_mask
)
<=
num_to_mask
:
break
break
...
@@ -261,7 +265,7 @@ def run_model():
...
@@ -261,7 +265,7 @@ def run_model():
# mask heads
# mask heads
current_heads_to_mask
=
current_heads_to_mask
[:
num_to_mask
]
current_heads_to_mask
=
current_heads_to_mask
[:
num_to_mask
]
logger
.
info
(
"Heads to mask: %s"
,
str
(
current_heads_to_mask
.
tolist
()))
logger
.
info
(
"Heads to mask: %s"
,
str
(
current_heads_to_mask
.
tolist
()))
new_head_mask
=
head_mask
.
view
(
-
1
)
new_head_mask
=
new_
head_mask
.
view
(
-
1
)
new_head_mask
[
current_heads_to_mask
]
=
0.0
new_head_mask
[
current_heads_to_mask
]
=
0.0
new_head_mask
=
new_head_mask
.
view_as
(
head_mask
)
new_head_mask
=
new_head_mask
.
view_as
(
head_mask
)
print_2d_tensor
(
new_head_mask
)
print_2d_tensor
(
new_head_mask
)
...
@@ -272,6 +276,10 @@ def run_model():
...
@@ -272,6 +276,10 @@ def run_model():
current_score
=
compute_metrics
(
task_name
,
preds
,
labels
)[
args
.
metric_name
]
current_score
=
compute_metrics
(
task_name
,
preds
,
labels
)[
args
.
metric_name
]
logger
.
info
(
"Masking: current score: %f, remaning heads %d (%.1f percents)"
,
current_score
,
new_head_mask
.
sum
(),
new_head_mask
.
sum
()
/
new_head_mask
.
numel
()
*
100
)
logger
.
info
(
"Masking: current score: %f, remaning heads %d (%.1f percents)"
,
current_score
,
new_head_mask
.
sum
(),
new_head_mask
.
sum
()
/
new_head_mask
.
numel
()
*
100
)
logger
.
info
(
"Final head mask"
)
print_2d_tensor
(
head_mask
)
np
.
save
(
os
.
path
.
join
(
args
.
output_dir
,
'head_mask.npy'
),
head_mask
.
detach
().
cpu
().
numpy
())
# Try pruning and test time speedup
# Try pruning and test time speedup
# Pruning is like masking but we actually remove the masked weights
# Pruning is like masking but we actually remove the masked weights
before_time
=
datetime
.
now
()
before_time
=
datetime
.
now
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment