chenpangpang / transformers · Commits

Commit 3cdb38a7
authored Jan 08, 2020 by Victor SANH, committed by Lysandre Debut on Jan 10, 2020

indents

parent ebd45980
Showing 1 changed file with 20 additions and 20 deletions (+20 -20)
examples/distillation/run_squad_w_distillation.py  View file @ 3cdb38a7
@@ -123,8 +123,8 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
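For context, a minimal sketch of the save/resume pattern the hunk above relies on; the toy model, optimizer choice, and checkpoint directory name are assumptions for illustration, not taken from this diff.

# Minimal sketch (assumptions noted above): persist and restore optimizer/scheduler
# state next to a checkpoint, mirroring the optimizer.pt / scheduler.pt files above.
import os

import torch

model = torch.nn.Linear(4, 2)  # toy model, for illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

ckpt_dir = "checkpoint-demo"  # hypothetical directory name
os.makedirs(ckpt_dir, exist_ok=True)
torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))

# When resuming, the states are loaded back exactly as the training loop above does.
optimizer.load_state_dict(torch.load(os.path.join(ckpt_dir, "optimizer.pt")))
scheduler.load_state_dict(torch.load(os.path.join(ckpt_dir, "scheduler.pt")))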
@@ -157,7 +157,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
@@ -178,10 +178,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproductibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
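As a side note, a small sketch of the progress-bar pattern above, where bars are disabled on all but the main process; the local_rank value here is an assumption for illustration.

# Minimal sketch: progress bars are shown only on the main process
# (local_rank == -1 for single-process runs, 0 for the first distributed process).
from tqdm import tqdm, trange

local_rank = -1  # assumption: single-process run
for _ in trange(2, desc="Epoch", disable=local_rank not in [-1, 0]):
    for _ in tqdm(range(3), desc="Iteration", disable=local_rank not in [-1, 0]):
        pass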
@@ -207,7 +207,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
            outputs = model(**inputs)
            loss, start_logits_stu, end_logits_stu = outputs
@@ -261,7 +261,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
@@ -281,7 +281,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
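For context, a hedged sketch of the unwrap-before-save pattern above; the toy DataParallel model and the output filename are assumptions, and torch.save stands in for save_pretrained since no transformers model is constructed here.

# Minimal sketch: nn.DataParallel wraps the real model behind a .module attribute,
# so the underlying model is unwrapped before its weights are written to disk.
import torch

model = torch.nn.DataParallel(torch.nn.Linear(4, 2))  # toy wrapped model, assumption
model_to_save = model.module if hasattr(model, "module") else model
torch.save(model_to_save.state_dict(), "pytorch_model_demo.bin")  # hypothetical filename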
@@ -325,7 +325,7 @@ def evaluate(args, model, tokenizer, prefix=""):
    logger.info("  Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
@@ -425,7 +425,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
@@ -468,7 +468,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
@@ -476,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
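For context, a hedged sketch of the cache-behind-a-barrier pattern used in the two hunks above; it assumes torch.distributed has already been initialized, and the function name, cache path, and placeholder preprocessing are illustrative only.

# Minimal sketch (assumes torch.distributed is already initialized): rank 0 builds
# the cache while the other ranks wait at a barrier, then every rank loads the file.
import os

import torch
import torch.distributed as dist

def load_cached_features(cache_path, local_rank):
    if local_rank not in [-1, 0]:
        dist.barrier()  # non-first processes wait until the cache exists
    if local_rank in [-1, 0] and not os.path.exists(cache_path):
        features = list(range(10))  # placeholder for the real feature conversion
        torch.save(features, cache_path)
    features = torch.load(cache_path)
    if local_rank == 0:
        dist.barrier()  # first process releases the waiting ones
    return features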
@@ -541,11 +541,11 @@ def main():
        help="The input data dir. Should contain the .json files for the task."
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        help="The input training file. If a data dir is specified, will look for the file there"
        + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
    )
    parser.add_argument(
@@ -688,7 +688,7 @@ def main():
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")

    parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
    args = parser.parse_args()

    if (
@@ -743,7 +743,7 @@ def main():
    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
@@ -781,7 +781,7 @@ def main():
        teacher = None

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)