Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e4b46d86
"...git@developer.sourcefind.cn:dcuai/dlexamples.git" did not exist on "82496fd438242f3904c61d2f2254913eaeb4b8e9"
Commit
e4b46d86
authored
Jun 19, 2019
by
thomwolf
Browse files
update head pruning
parent
0f40e8d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
12 deletions
+18
-12
examples/bertology.py
examples/bertology.py
+18
-12
No files found.
examples/bertology.py
View file @
e4b46d86
...
@@ -92,7 +92,13 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
...
@@ -92,7 +92,13 @@ def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True,
# Normalize
# Normalize
attn_entropy
/=
tot_tokens
attn_entropy
/=
tot_tokens
head_importance
/=
tot_tokens
head_importance
/=
tot_tokens
if
args
.
normalize_importance
:
# Layerwise importance normalization
if
not
args
.
dont_normalize_importance_by_layer
:
exponent
=
2
norm_by_layer
=
torch
.
pow
(
torch
.
pow
(
head_importance
,
exponent
).
sum
(
-
1
),
1
/
exponent
)
head_importance
/=
norm_by_layer
.
unsqueeze
(
-
1
)
+
1e-20
if
not
args
.
dont_normalize_global_importance
:
head_importance
=
(
head_importance
-
head_importance
.
min
())
/
(
head_importance
.
max
()
-
head_importance
.
min
())
head_importance
=
(
head_importance
-
head_importance
.
min
())
/
(
head_importance
.
max
()
-
head_importance
.
min
())
return
attn_entropy
,
head_importance
,
preds
,
labels
return
attn_entropy
,
head_importance
,
preds
,
labels
...
@@ -106,7 +112,8 @@ def run_model():
...
@@ -106,7 +112,8 @@ def run_model():
parser
.
add_argument
(
"--data_subset"
,
type
=
int
,
default
=-
1
,
help
=
"If > 0: limit the data to a subset of data_subset instances."
)
parser
.
add_argument
(
"--data_subset"
,
type
=
int
,
default
=-
1
,
help
=
"If > 0: limit the data to a subset of data_subset instances."
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
'store_true'
,
help
=
"Whether to overwrite data in output directory"
)
parser
.
add_argument
(
"--overwrite_output_dir"
,
action
=
'store_true'
,
help
=
"Whether to overwrite data in output directory"
)
parser
.
add_argument
(
"--normalize_importance"
,
action
=
'store_true'
,
help
=
"Whether to normalize importance score between 0 and 1"
)
parser
.
add_argument
(
"--dont_normalize_importance_by_layer"
,
action
=
'store_true'
,
help
=
"Don't normalize importance score by layers"
)
parser
.
add_argument
(
"--dont_normalize_global_importance"
,
action
=
'store_true'
,
help
=
"Don't normalize all importance scores between 0 and 1"
)
parser
.
add_argument
(
"--try_masking"
,
action
=
'store_true'
,
help
=
"Whether to try to mask head until a threshold of accuracy."
)
parser
.
add_argument
(
"--try_masking"
,
action
=
'store_true'
,
help
=
"Whether to try to mask head until a threshold of accuracy."
)
parser
.
add_argument
(
"--masking_threshold"
,
default
=
0.9
,
type
=
float
,
help
=
"masking threshold in term of metrics"
parser
.
add_argument
(
"--masking_threshold"
,
default
=
0.9
,
type
=
float
,
help
=
"masking threshold in term of metrics"
...
@@ -243,21 +250,20 @@ def run_model():
...
@@ -243,21 +250,20 @@ def run_model():
current_score
=
original_score
current_score
=
original_score
while
current_score
>=
original_score
*
args
.
masking_threshold
:
while
current_score
>=
original_score
*
args
.
masking_threshold
:
head_mask
=
new_head_mask
head_mask
=
new_head_mask
# save current head mask
# heads from most important to least
# heads from most important to least - keep only not-masked heads
heads_to_mask
=
head_importance
.
view
(
-
1
).
sort
(
descending
=
True
)[
1
]
head_importance
=
head_importance
.
view
(
-
1
)[
head_mask
.
view
(
-
1
).
nonzero
()][:,
0
]
# keep only not-masked heads
current_heads_to_mask
=
head_importance
.
sort
()[
1
]
heads_to_mask
=
heads_to_mask
[
head_mask
.
view
(
-
1
).
nonzero
()][:,
0
]
if
len
(
heads_to_mask
)
<=
num_to_mask
:
if
len
(
current_
heads_to_mask
)
<=
num_to_mask
:
break
break
# mask heads
# mask heads
heads_to_mask
=
heads_to_mask
[
-
num_to_mask
:
]
current_
heads_to_mask
=
current_
heads_to_mask
[
:
num_to_mask
]
logger
.
info
(
"Heads to mask: %s"
,
str
(
heads_to_mask
.
tolist
()))
logger
.
info
(
"Heads to mask: %s"
,
str
(
current_
heads_to_mask
.
tolist
()))
new_head_mask
=
head_mask
.
view
(
-
1
)
new_head_mask
=
head_mask
.
view
(
-
1
)
new_head_mask
[
heads_to_mask
]
=
0.0
new_head_mask
[
current_
heads_to_mask
]
=
0.0
new_head_mask
=
new_head_mask
.
view_as
(
head_
importance
)
new_head_mask
=
new_head_mask
.
view_as
(
head_
mask
)
print_2d_tensor
(
new_head_mask
)
print_2d_tensor
(
new_head_mask
)
# Compute metric and head importance again
# Compute metric and head importance again
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment