Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0e1e8128
Commit
0e1e8128
authored
Jun 19, 2019
by
thomwolf
Browse files
more logging
parent
909d4f1a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
examples/bertology.py
examples/bertology.py
+3
-2
No files found.
examples/bertology.py
View file @
0e1e8128
...
...
@@ -227,7 +227,7 @@ def run_model():
_
,
head_importance
,
preds
,
labels
=
compute_heads_importance
(
args
,
model
,
eval_dataloader
,
compute_entropy
=
False
)
preds
=
np
.
argmax
(
preds
,
axis
=
1
)
if
args
.
output_mode
==
"classification"
else
np
.
squeeze
(
preds
)
original_score
=
compute_metrics
(
task_name
,
preds
,
labels
)[
args
.
metric_name
]
logger
.
info
(
"Pruning: original score: %f"
,
original_score
)
logger
.
info
(
"Pruning: original score:
%f, threshold:
%f"
,
original_score
,
original_score
*
args
.
masking_threshold
)
new_head_mask
=
torch
.
ones_like
(
head_importance
)
num_to_mask
=
int
(
new_head_mask
.
numel
()
*
args
.
masking_amount
)
...
...
@@ -245,6 +245,7 @@ def run_model():
# mask heads
heads_to_mask
=
heads_to_mask
[
-
num_to_mask
:]
logger
.
info
(
"Heads to mask: %s"
,
str
(
heads_to_mask
.
tolist
()))
new_head_mask
=
head_mask
.
view
(
-
1
)
new_head_mask
[
heads_to_mask
]
=
0.0
new_head_mask
=
new_head_mask
.
view_as
(
head_importance
)
...
...
@@ -254,7 +255,7 @@ def run_model():
_
,
head_importance
,
preds
,
labels
=
compute_heads_importance
(
args
,
model
,
eval_dataloader
,
compute_entropy
=
False
,
head_mask
=
new_head_mask
)
preds
=
np
.
argmax
(
preds
,
axis
=
1
)
if
args
.
output_mode
==
"classification"
else
np
.
squeeze
(
preds
)
current_score
=
compute_metrics
(
task_name
,
preds
,
labels
)[
args
.
metric_name
]
logger
.
info
(
"Masking: current score: %f, remaning heads %.1f percents"
,
current_score
,
head_mask
.
sum
()
/
head_mask
.
numel
()
*
100
)
logger
.
info
(
"Masking: current score: %f, remaning heads
%d (
%.1f percents
)
"
,
current_score
,
new_head_mask
.
sum
(),
new_
head_mask
.
sum
()
/
new_
head_mask
.
numel
()
*
100
)
# Try pruning and test time speedup
# Pruning is like masking but we actually remove the masked weights
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment