GitLab · chenpangpang / transformers · Commits

Commit 1dc9b3c7
Authored Apr 22, 2020 by Julien Chaumond
Fixes #3877
parent dd9d483d

Showing 2 changed files with 49 additions and 49 deletions:

    docs/source/bertology.rst    +1  -1
    examples/run_bertology.py    +48 -48
docs/source/bertology.rst (view file @ 1dc9b3c7)

examples/run_bertology.py (view file @ 1dc9b3c7)
@@ -30,10 +30,17 @@ from torch.utils.data import DataLoader, SequentialSampler, Subset
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm

-from run_glue import ALL_MODELS, MODEL_CLASSES, load_and_cache_examples, set_seed
-from transformers import glue_compute_metrics as compute_metrics
-from transformers import glue_output_modes as output_modes
-from transformers import glue_processors as processors
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    DefaultDataCollator,
+    GlueDataset,
+    glue_compute_metrics,
+    glue_output_modes,
+    glue_processors,
+    set_seed,
+)


 logger = logging.getLogger(__name__)
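Note: everything the script previously pulled from the sibling examples/run_glue.py is, as of the transformers version this commit targets (circa v2.8), exposed on the transformers package itself. A hypothetical sanity check, not part of the commit:

import transformers

# Hedged sketch: confirm the helpers the script now imports exist on the
# transformers public API of that era.
for name in (
    "AutoConfig", "AutoModelForSequenceClassification", "AutoTokenizer",
    "DefaultDataCollator", "GlueDataset", "glue_compute_metrics",
    "glue_output_modes", "glue_processors", "set_seed",
):
    assert hasattr(transformers, name), name + " missing from transformers"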
@@ -64,7 +71,7 @@ def compute_heads_importance(
         - head importance scores according to http://arxiv.org/abs/1905.10650
     """
     # Prepare our tensors
-    n_layers, n_heads = model.bert.config.num_hidden_layers, model.bert.config.num_attention_heads
+    n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
     head_importance = torch.zeros(n_layers, n_heads).to(args.device)
     attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
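Note: the rename from model.bert.config to model.config is what makes the script architecture-agnostic — every transformers PreTrainedModel exposes its configuration as model.config, so the code no longer assumes a BERT-specific `.bert` submodule. A minimal sketch (model name illustrative):

from transformers import AutoConfig

# Read layer/head counts from any checkpoint's config, no .bert attribute needed.
config = AutoConfig.from_pretrained("bert-base-cased")
n_layers, n_heads = config.num_hidden_layers, config.num_attention_heads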
@@ -75,14 +82,12 @@ def compute_heads_importance(
     labels = None
     tot_tokens = 0.0

-    for step, batch in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
-        batch = tuple(t.to(args.device) for t in batch)
-        input_ids, input_mask, segment_ids, label_ids = batch
+    for step, inputs in enumerate(tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
+        for k, v in inputs.items():
+            inputs[k] = v.to(args.device)

         # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
-        outputs = model(
-            input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids, head_mask=head_mask
-        )
+        outputs = model(**inputs, head_mask=head_mask)
         loss, logits, all_attentions = (
             outputs[0],
             outputs[1],
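Note: with GlueDataset plus DefaultDataCollator (see the dataloader change at the end of this diff), each batch is a dict of named tensors rather than a positional tuple, so the loop moves every value to the device and keyword-splats the dict into the model. A sketch of that pattern, assuming dict batches:

# Sketch, assuming the dataloader yields dicts of tensors.
def forward_with_head_mask(model, inputs, head_mask, device):
    for k, v in inputs.items():
        inputs[k] = v.to(device)  # move every named tensor to the target device
    # Keyword-splat so each architecture receives exactly the fields it expects.
    return model(**inputs, head_mask=head_mask)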
@@ -92,7 +97,7 @@ def compute_heads_importance(

         if compute_entropy:
             for layer, attn in enumerate(all_attentions):
-                masked_entropy = entropy(attn.detach()) * input_mask.float().unsqueeze(1)
+                masked_entropy = entropy(attn.detach()) * inputs["attention_mask"].float().unsqueeze(1)
                 attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

         if compute_importance:
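Note: entropy here is the small helper defined at the top of run_bertology.py; multiplying by the attention mask zeroes out padded positions before the per-layer sums. A hedged sketch of an equivalent helper:

import torch

def entropy(p):
    # Shannon entropy over the last axis; equivalent in spirit to the script's
    # helper (a sketch, not the committed code). Defines 0 * log(0) = 0.
    plogp = p * torch.log(p)
    plogp[p == 0] = 0
    return -plogp.sum(dim=-1)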
@@ -101,12 +106,12 @@ def compute_heads_importance(
         # Also store our logits/labels if we want to compute metrics afterwards
         if preds is None:
             preds = logits.detach().cpu().numpy()
-            labels = label_ids.detach().cpu().numpy()
+            labels = inputs["labels"].detach().cpu().numpy()
         else:
             preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
-            labels = np.append(labels, label_ids.detach().cpu().numpy(), axis=0)
+            labels = np.append(labels, inputs["labels"].detach().cpu().numpy(), axis=0)

-        tot_tokens += input_mask.float().detach().sum().data
+        tot_tokens += inputs["attention_mask"].float().detach().sum().data

     # Normalize
     attn_entropy /= tot_tokens
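Note: the accumulation itself is unchanged — only the label source moves into the batch dict. The pattern is: start at None, then grow along the batch axis. Sketch with hypothetical per-batch arrays:

import numpy as np

# `batches` is a hypothetical iterable of (logits, labels) numpy pairs standing
# in for the eval loop above.
preds = labels = None
for batch_logits, batch_labels in batches:
    if preds is None:
        preds, labels = batch_logits, batch_labels
    else:
        preds = np.append(preds, batch_logits, axis=0)
        labels = np.append(labels, batch_labels, axis=0)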
@@ -145,7 +150,7 @@ def mask_heads(args, model, eval_dataloader):
     """
     _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
     preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
-    original_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
+    original_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
     logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

     new_head_mask = torch.ones_like(head_importance)
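Note: glue_compute_metrics(task_name, preds, labels) is the un-aliased public function; this and the three call sites below change only the name. A usage sketch on toy arrays (values illustrative):

import numpy as np
from transformers import glue_compute_metrics

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 1, 1, 0])
# For MRPC this returns a dict of metrics, e.g. accuracy and F1.
print(glue_compute_metrics("mrpc", preds, labels))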
@@ -174,7 +179,7 @@ def mask_heads(args, model, eval_dataloader):
             args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
         )
         preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
-        current_score = compute_metrics(args.task_name, preds, labels)[args.metric_name]
+        current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
         logger.info(
             "Masking: current score: %f, remaning heads %d (%.1f percents)",
             current_score,
@@ -200,7 +205,7 @@ def prune_heads(args, model, eval_dataloader, head_mask):
         args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
     )
     preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
-    score_masking = compute_metrics(args.task_name, preds, labels)[args.metric_name]
+    score_masking = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
     original_time = datetime.now() - before_time

     original_num_params = sum(p.numel() for p in model.parameters())
@@ -214,7 +219,7 @@ def prune_heads(args, model, eval_dataloader, head_mask):
         args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=None
     )
     preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
-    score_pruning = compute_metrics(args.task_name, preds, labels)[args.metric_name]
+    score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
     new_time = datetime.now() - before_time

     logger.info(
@@ -242,14 +247,14 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--task_name",
         default=None,
         type=str,
         required=True,
-        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
+        help="The name of the task to train selected in the list: " + ", ".join(glue_processors.keys()),
     )
     parser.add_argument(
         "--output_dir",
@@ -274,7 +279,7 @@ def main():
     )
     parser.add_argument(
         "--cache_dir",
-        default="",
+        default=None,
         type=str,
         help="Where do you want to store the pre-trained models downloaded from s3",
     )
@@ -350,48 +355,40 @@ def main():
     logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))

     # Set seeds
-    set_seed(args)
+    set_seed(args.seed)

     # Prepare GLUE task
     args.task_name = args.task_name.lower()
-    if args.task_name not in processors:
+    if args.task_name not in glue_processors:
         raise ValueError("Task not found: %s" % (args.task_name))
-    processor = processors[args.task_name]()
-    args.output_mode = output_modes[args.task_name]
+    processor = glue_processors[args.task_name]()
+    args.output_mode = glue_output_modes[args.task_name]
     label_list = processor.get_labels()
     num_labels = len(label_list)

     # Load pretrained model and tokenizer
-    if args.local_rank not in [-1, 0]:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-
-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name_or_path.lower():
-            args.model_type = key  # take the first match in model types
-            break
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    #
+    # Distributed training:
+    # The .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         num_labels=num_labels,
         finetuning_task=args.task_name,
         output_attentions=True,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )
-    tokenizer = tokenizer_class.from_pretrained(
-        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, cache_dir=args.cache_dir,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForSequenceClassification.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )
-
-    if args.local_rank == 0:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

     # Distributed and parallel training
     model.to(args.device)
     if args.local_rank != -1:
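Note: the explicit MODEL_CLASSES lookup and the manual torch.distributed.barrier() calls are dropped — the Auto* classes resolve the architecture from the checkpoint's config, and from_pretrained serializes concurrent downloads on its own. A condensed sketch of the loading triple (model name and task illustrative, not from the commit):

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

config = AutoConfig.from_pretrained(
    "bert-base-cased",
    num_labels=2,
    finetuning_task="mrpc",
    output_attentions=True,  # forward passes return per-layer attentions for the entropy code
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", config=config)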
@@ -402,15 +399,18 @@ def main():
         model = torch.nn.DataParallel(model)

     # Print/save training arguments
+    os.makedirs(args.output_dir, exist_ok=True)
     torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
     logger.info("Training/evaluation parameters %s", args)

     # Prepare dataset for the GLUE task
-    eval_data = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=True)
+    eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True, local_rank=args.local_rank)
     if args.data_subset > 0:
-        eval_data = Subset(eval_data, list(range(min(args.data_subset, len(eval_data)))))
-    eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
-    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
+        eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
+    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
+    eval_dataloader = DataLoader(
+        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=DefaultDataCollator().collate_batch
+    )

     # Compute head entropy and importance score
     compute_heads_importance(args, model, eval_dataloader)
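Note: DefaultDataCollator().collate_batch is what turns a list of GlueDataset features into the dict-of-tensors batches that the **inputs forward pass earlier in the diff consumes. A sketch of the assembled evaluation pipeline under this transformers version (~v2.8), with args and tokenizer as set up in the script:

from torch.utils.data import DataLoader, SequentialSampler
from transformers import DefaultDataCollator, GlueDataset

# dataset -> sampler -> dataloader yielding dicts of stacked tensors
eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True, local_rank=args.local_rank)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
    eval_dataset,
    sampler=eval_sampler,
    batch_size=args.batch_size,
    collate_fn=DefaultDataCollator().collate_batch,  # list of features -> dict batch
)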