Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ebd45980
Commit
ebd45980
authored
Jan 08, 2020
by
Victor SANH
Committed by
Lysandre Debut
Jan 10, 2020
Browse files
Align with `run_squad` + fix some errors
parent
45634f87
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
122 additions
and
32 deletions
+122
-32
examples/distillation/run_squad_w_distillation.py
examples/distillation/run_squad_w_distillation.py
+122
-32
No files found.
examples/distillation/run_squad_w_distillation.py
View file @
ebd45980
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" This is the exact same script as `examples/run_squad.py` (as of 20
19, October 4
th) with an additional and optional step of distillation."""
""" This is the exact same script as `examples/run_squad.py` (as of 20
20, January 8
th) with an additional and optional step of distillation."""
import
argparse
import
argparse
import
glob
import
glob
...
@@ -60,6 +60,7 @@ try:
...
@@ -60,6 +60,7 @@ try:
except
ImportError
:
except
ImportError
:
from
tensorboardX
import
SummaryWriter
from
tensorboardX
import
SummaryWriter
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
ALL_MODELS
=
sum
(
ALL_MODELS
=
sum
(
...
@@ -114,11 +115,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
...
@@ -114,11 +115,21 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
scheduler
=
get_linear_schedule_with_warmup
(
scheduler
=
get_linear_schedule_with_warmup
(
optimizer
,
num_warmup_steps
=
args
.
warmup_steps
,
num_training_steps
=
t_total
optimizer
,
num_warmup_steps
=
args
.
warmup_steps
,
num_training_steps
=
t_total
)
)
if
args
.
fp16
:
# Check if saved optimizer or scheduler states exist
if
os
.
path
.
isfile
(
os
.
path
.
join
(
args
.
model_name_or_path
,
"optimizer.pt"
))
and
os
.
path
.
isfile
(
os
.
path
.
join
(
args
.
model_name_or_path
,
"scheduler.pt"
)
):
# Load in optimizer and scheduler states
optimizer
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
args
.
model_name_or_path
,
"optimizer.pt"
)))
scheduler
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
args
.
model_name_or_path
,
"scheduler.pt"
)))
if
args
.
fp16
:
try
:
try
:
from
apex
import
amp
from
apex
import
amp
except
ImportError
:
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
)
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
)
model
,
optimizer
=
amp
.
initialize
(
model
,
optimizer
,
opt_level
=
args
.
fp16_opt_level
)
model
,
optimizer
=
amp
.
initialize
(
model
,
optimizer
,
opt_level
=
args
.
fp16_opt_level
)
# multi-gpu training (should be after apex fp16 initialization)
# multi-gpu training (should be after apex fp16 initialization)
...
@@ -145,18 +156,47 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
...
@@ -145,18 +156,47 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
logger
.
info
(
" Gradient Accumulation steps = %d"
,
args
.
gradient_accumulation_steps
)
logger
.
info
(
" Gradient Accumulation steps = %d"
,
args
.
gradient_accumulation_steps
)
logger
.
info
(
" Total optimization steps = %d"
,
t_total
)
logger
.
info
(
" Total optimization steps = %d"
,
t_total
)
global_step
=
0
global_step
=
1
epochs_trained
=
0
steps_trained_in_current_epoch
=
0
# Check if continuing training from a checkpoint
if
os
.
path
.
exists
(
args
.
model_name_or_path
):
try
:
# set global_step to gobal_step of last saved checkpoint from model path
checkpoint_suffix
=
args
.
model_name_or_path
.
split
(
"-"
)[
-
1
].
split
(
"/"
)[
0
]
global_step
=
int
(
checkpoint_suffix
)
epochs_trained
=
global_step
//
(
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
)
steps_trained_in_current_epoch
=
global_step
%
(
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
)
logger
.
info
(
" Continuing training from checkpoint, will skip to saved global_step"
)
logger
.
info
(
" Continuing training from epoch %d"
,
epochs_trained
)
logger
.
info
(
" Continuing training from global step %d"
,
global_step
)
logger
.
info
(
" Will skip the first %d steps in the first epoch"
,
steps_trained_in_current_epoch
)
except
ValueError
:
logger
.
info
(
" Starting fine-tuning."
)
tr_loss
,
logging_loss
=
0.0
,
0.0
tr_loss
,
logging_loss
=
0.0
,
0.0
model
.
zero_grad
()
model
.
zero_grad
()
train_iterator
=
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
train_iterator
=
trange
(
set_seed
(
args
)
# Added here for reproductibility
epochs_trained
,
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
]
)
# Added here for reproductibility
set_seed
(
args
)
for
_
in
train_iterator
:
for
_
in
train_iterator
:
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
for
step
,
batch
in
enumerate
(
epoch_iterator
):
for
step
,
batch
in
enumerate
(
epoch_iterator
):
# Skip past any already trained steps if resuming training
if
steps_trained_in_current_epoch
>
0
:
steps_trained_in_current_epoch
-=
1
continue
model
.
train
()
model
.
train
()
if
teacher
is
not
None
:
if
teacher
is
not
None
:
teacher
.
eval
()
teacher
.
eval
()
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
inputs
=
{
inputs
=
{
"input_ids"
:
batch
[
0
],
"input_ids"
:
batch
[
0
],
"attention_mask"
:
batch
[
1
],
"attention_mask"
:
batch
[
1
],
...
@@ -167,6 +207,8 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
...
@@ -167,6 +207,8 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
inputs
[
"token_type_ids"
]
=
None
if
args
.
model_type
==
"xlm"
else
batch
[
2
]
inputs
[
"token_type_ids"
]
=
None
if
args
.
model_type
==
"xlm"
else
batch
[
2
]
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
inputs
.
update
({
"cls_index"
:
batch
[
5
],
"p_mask"
:
batch
[
6
]})
inputs
.
update
({
"cls_index"
:
batch
[
5
],
"p_mask"
:
batch
[
6
]})
if
args
.
version_2_with_negative
:
inputs
.
update
({
"is_impossible"
:
batch
[
7
]})
outputs
=
model
(
**
inputs
)
outputs
=
model
(
**
inputs
)
loss
,
start_logits_stu
,
end_logits_stu
=
outputs
loss
,
start_logits_stu
,
end_logits_stu
=
outputs
...
@@ -219,11 +261,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
...
@@ -219,11 +261,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model
.
zero_grad
()
model
.
zero_grad
()
global_step
+=
1
global_step
+=
1
# Log metrics
if
args
.
local_rank
in
[
-
1
,
0
]
and
args
.
logging_steps
>
0
and
global_step
%
args
.
logging_steps
==
0
:
if
args
.
local_rank
in
[
-
1
,
0
]
and
args
.
logging_steps
>
0
and
global_step
%
args
.
logging_steps
==
0
:
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
if
(
if
args
.
local_rank
==
-
1
and
args
.
evaluate_during_training
:
args
.
local_rank
==
-
1
and
args
.
evaluate_during_training
):
# Only evaluate when single GPU otherwise metrics may not average well
results
=
evaluate
(
args
,
model
,
tokenizer
)
results
=
evaluate
(
args
,
model
,
tokenizer
)
for
key
,
value
in
results
.
items
():
for
key
,
value
in
results
.
items
():
tb_writer
.
add_scalar
(
"eval_{}"
.
format
(
key
),
value
,
global_step
)
tb_writer
.
add_scalar
(
"eval_{}"
.
format
(
key
),
value
,
global_step
)
...
@@ -240,9 +281,15 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
...
@@ -240,9 +281,15 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model
.
module
if
hasattr
(
model
,
"module"
)
else
model
model
.
module
if
hasattr
(
model
,
"module"
)
else
model
)
# Take care of distributed/parallel training
)
# Take care of distributed/parallel training
model_to_save
.
save_pretrained
(
output_dir
)
model_to_save
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
torch
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
torch
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
torch
.
save
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
"optimizer.pt"
))
torch
.
save
(
scheduler
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
"scheduler.pt"
))
logger
.
info
(
"Saving optimizer and scheduler states to %s"
,
output_dir
)
if
args
.
max_steps
>
0
and
global_step
>
args
.
max_steps
:
if
args
.
max_steps
>
0
and
global_step
>
args
.
max_steps
:
epoch_iterator
.
close
()
epoch_iterator
.
close
()
break
break
...
@@ -263,18 +310,27 @@ def evaluate(args, model, tokenizer, prefix=""):
...
@@ -263,18 +310,27 @@ def evaluate(args, model, tokenizer, prefix=""):
os
.
makedirs
(
args
.
output_dir
)
os
.
makedirs
(
args
.
output_dir
)
args
.
eval_batch_size
=
args
.
per_gpu_eval_batch_size
*
max
(
1
,
args
.
n_gpu
)
args
.
eval_batch_size
=
args
.
per_gpu_eval_batch_size
*
max
(
1
,
args
.
n_gpu
)
# Note that DistributedSampler samples randomly
# Note that DistributedSampler samples randomly
eval_sampler
=
SequentialSampler
(
dataset
)
eval_sampler
=
SequentialSampler
(
dataset
)
eval_dataloader
=
DataLoader
(
dataset
,
sampler
=
eval_sampler
,
batch_size
=
args
.
eval_batch_size
)
eval_dataloader
=
DataLoader
(
dataset
,
sampler
=
eval_sampler
,
batch_size
=
args
.
eval_batch_size
)
# multi-gpu evaluate
if
args
.
n_gpu
>
1
and
not
isinstance
(
model
,
torch
.
nn
.
DataParallel
):
model
=
torch
.
nn
.
DataParallel
(
model
)
# Eval!
# Eval!
logger
.
info
(
"***** Running evaluation {} *****"
.
format
(
prefix
))
logger
.
info
(
"***** Running evaluation {} *****"
.
format
(
prefix
))
logger
.
info
(
" Num examples = %d"
,
len
(
dataset
))
logger
.
info
(
" Num examples = %d"
,
len
(
dataset
))
logger
.
info
(
" Batch size = %d"
,
args
.
eval_batch_size
)
logger
.
info
(
" Batch size = %d"
,
args
.
eval_batch_size
)
all_results
=
[]
all_results
=
[]
start_time
=
timeit
.
default_timer
()
for
batch
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
for
batch
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
model
.
eval
()
model
.
eval
()
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
inputs
=
{
"input_ids"
:
batch
[
0
],
"attention_mask"
:
batch
[
1
]}
inputs
=
{
"input_ids"
:
batch
[
0
],
"attention_mask"
:
batch
[
1
]}
if
args
.
model_type
!=
"distilbert"
:
if
args
.
model_type
!=
"distilbert"
:
...
@@ -282,6 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""):
...
@@ -282,6 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""):
example_indices
=
batch
[
3
]
example_indices
=
batch
[
3
]
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
inputs
.
update
({
"cls_index"
:
batch
[
4
],
"p_mask"
:
batch
[
5
]})
inputs
.
update
({
"cls_index"
:
batch
[
4
],
"p_mask"
:
batch
[
5
]})
outputs
=
model
(
**
inputs
)
outputs
=
model
(
**
inputs
)
for
i
,
example_index
in
enumerate
(
example_indices
):
for
i
,
example_index
in
enumerate
(
example_indices
):
...
@@ -314,9 +371,13 @@ def evaluate(args, model, tokenizer, prefix=""):
...
@@ -314,9 +371,13 @@ def evaluate(args, model, tokenizer, prefix=""):
all_results
.
append
(
result
)
all_results
.
append
(
result
)
evalTime
=
timeit
.
default_timer
()
-
start_time
logger
.
info
(
" Evaluation done in total %f secs (%f sec per example)"
,
evalTime
,
evalTime
/
len
(
dataset
))
# Compute predictions
# Compute predictions
output_prediction_file
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions_{}.json"
.
format
(
prefix
))
output_prediction_file
=
os
.
path
.
join
(
args
.
output_dir
,
"predictions_{}.json"
.
format
(
prefix
))
output_nbest_file
=
os
.
path
.
join
(
args
.
output_dir
,
"nbest_predictions_{}.json"
.
format
(
prefix
))
output_nbest_file
=
os
.
path
.
join
(
args
.
output_dir
,
"nbest_predictions_{}.json"
.
format
(
prefix
))
if
args
.
version_2_with_negative
:
if
args
.
version_2_with_negative
:
output_null_log_odds_file
=
os
.
path
.
join
(
args
.
output_dir
,
"null_odds_{}.json"
.
format
(
prefix
))
output_null_log_odds_file
=
os
.
path
.
join
(
args
.
output_dir
,
"null_odds_{}.json"
.
format
(
prefix
))
else
:
else
:
...
@@ -333,7 +394,6 @@ def evaluate(args, model, tokenizer, prefix=""):
...
@@ -333,7 +394,6 @@ def evaluate(args, model, tokenizer, prefix=""):
output_prediction_file
,
output_prediction_file
,
output_nbest_file
,
output_nbest_file
,
output_null_log_odds_file
,
output_null_log_odds_file
,
args
.
predict_file
,
model
.
config
.
start_n_top
,
model
.
config
.
start_n_top
,
model
.
config
.
end_n_top
,
model
.
config
.
end_n_top
,
args
.
version_2_with_negative
,
args
.
version_2_with_negative
,
...
@@ -364,7 +424,8 @@ def evaluate(args, model, tokenizer, prefix=""):
...
@@ -364,7 +424,8 @@ def evaluate(args, model, tokenizer, prefix=""):
def
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
):
def
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
):
if
args
.
local_rank
not
in
[
-
1
,
0
]
and
not
evaluate
:
if
args
.
local_rank
not
in
[
-
1
,
0
]
and
not
evaluate
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch
.
distributed
.
barrier
()
# Load data features from cache or dataset file
# Load data features from cache or dataset file
input_file
=
args
.
predict_file
if
evaluate
else
args
.
train_file
input_file
=
args
.
predict_file
if
evaluate
else
args
.
train_file
...
@@ -395,9 +456,9 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
...
@@ -395,9 +456,9 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
logger
.
info
(
"Creating features from dataset file at %s"
,
input_file
)
logger
.
info
(
"Creating features from dataset file at %s"
,
input_file
)
processor
=
SquadV2Processor
()
if
args
.
version_2_with_negative
else
SquadV1Processor
()
processor
=
SquadV2Processor
()
if
args
.
version_2_with_negative
else
SquadV1Processor
()
if
evaluate
:
if
evaluate
:
examples
=
processor
.
get_dev_examples
(
None
,
filename
=
args
.
predict_file
)
examples
=
processor
.
get_dev_examples
(
args
.
data_dir
,
filename
=
args
.
predict_file
)
else
:
else
:
examples
=
processor
.
get_train_examples
(
None
,
filename
=
args
.
train_file
)
examples
=
processor
.
get_train_examples
(
args
.
data_dir
,
filename
=
args
.
train_file
)
features
,
dataset
=
squad_convert_examples_to_features
(
features
,
dataset
=
squad_convert_examples_to_features
(
examples
=
examples
,
examples
=
examples
,
...
@@ -407,13 +468,16 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
...
@@ -407,13 +468,16 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
max_query_length
=
args
.
max_query_length
,
max_query_length
=
args
.
max_query_length
,
is_training
=
not
evaluate
,
is_training
=
not
evaluate
,
return_dataset
=
"pt"
,
return_dataset
=
"pt"
,
threads
=
args
.
threads
,
)
)
if
args
.
local_rank
in
[
-
1
,
0
]:
if
args
.
local_rank
in
[
-
1
,
0
]:
logger
.
info
(
"Saving features into cached file %s"
,
cached_features_file
)
logger
.
info
(
"Saving features into cached file %s"
,
cached_features_file
)
torch
.
save
({
"features"
:
features
,
"dataset"
:
dataset
,
"examples"
:
examples
},
cached_features_file
)
torch
.
save
({
"features"
:
features
,
"dataset"
:
dataset
,
"examples"
:
examples
},
cached_features_file
)
if
args
.
local_rank
==
0
and
not
evaluate
:
if
args
.
local_rank
==
0
and
not
evaluate
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch
.
distributed
.
barrier
()
if
output_examples
:
if
output_examples
:
return
dataset
,
examples
,
features
return
dataset
,
examples
,
features
...
@@ -424,16 +488,6 @@ def main():
...
@@ -424,16 +488,6 @@ def main():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
# Required parameters
# Required parameters
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"SQuAD json for training. E.g., train-v1.1.json"
)
parser
.
add_argument
(
"--predict_file"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--model_type"
,
"--model_type"
,
default
=
None
,
default
=
None
,
...
@@ -480,6 +534,27 @@ def main():
...
@@ -480,6 +534,27 @@ def main():
)
)
# Other parameters
# Other parameters
parser
.
add_argument
(
"--data_dir"
,
default
=
None
,
type
=
str
,
help
=
"The input data dir. Should contain the .json files for the task."
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
,
)
parser
.
add_argument
(
"--train_file"
,
default
=
None
,
type
=
str
,
help
=
"The input training file. If a data dir is specified, will look for the file there"
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
,
)
parser
.
add_argument
(
"--predict_file"
,
default
=
None
,
type
=
str
,
help
=
"The input evaluation file. If a data dir is specified, will look for the file there"
+
"If no data dir or train/predict files are specified, will run with tensorflow_datasets."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
)
...
@@ -548,7 +623,7 @@ def main():
...
@@ -548,7 +623,7 @@ def main():
default
=
1
,
default
=
1
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
)
)
parser
.
add_argument
(
"--weight_decay"
,
default
=
0.0
,
type
=
float
,
help
=
"Weight deay if we apply some."
)
parser
.
add_argument
(
"--weight_decay"
,
default
=
0.0
,
type
=
float
,
help
=
"Weight de
c
ay if we apply some."
)
parser
.
add_argument
(
"--adam_epsilon"
,
default
=
1e-8
,
type
=
float
,
help
=
"Epsilon for Adam optimizer."
)
parser
.
add_argument
(
"--adam_epsilon"
,
default
=
1e-8
,
type
=
float
,
help
=
"Epsilon for Adam optimizer."
)
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -612,6 +687,8 @@ def main():
...
@@ -612,6 +687,8 @@ def main():
)
)
parser
.
add_argument
(
"--server_ip"
,
type
=
str
,
default
=
""
,
help
=
"Can be used for distant debugging."
)
parser
.
add_argument
(
"--server_ip"
,
type
=
str
,
default
=
""
,
help
=
"Can be used for distant debugging."
)
parser
.
add_argument
(
"--server_port"
,
type
=
str
,
default
=
""
,
help
=
"Can be used for distant debugging."
)
parser
.
add_argument
(
"--server_port"
,
type
=
str
,
default
=
""
,
help
=
"Can be used for distant debugging."
)
parser
.
add_argument
(
"--threads"
,
type
=
int
,
default
=
1
,
help
=
"multiple threads for converting example to features"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
(
if
(
...
@@ -666,7 +743,8 @@ def main():
...
@@ -666,7 +743,8 @@ def main():
# Load pretrained model and tokenizer
# Load pretrained model and tokenizer
if
args
.
local_rank
not
in
[
-
1
,
0
]:
if
args
.
local_rank
not
in
[
-
1
,
0
]:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training will download model & vocab
# Make sure only the first process in distributed training will download model & vocab
torch
.
distributed
.
barrier
()
args
.
model_type
=
args
.
model_type
.
lower
()
args
.
model_type
=
args
.
model_type
.
lower
()
config_class
,
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_type
]
config_class
,
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_type
]
...
@@ -703,12 +781,24 @@ def main():
...
@@ -703,12 +781,24 @@ def main():
teacher
=
None
teacher
=
None
if
args
.
local_rank
==
0
:
if
args
.
local_rank
==
0
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training will download model & vocab
# Make sure only the first process in distributed training will download model & vocab
torch
.
distributed
.
barrier
()
model
.
to
(
args
.
device
)
model
.
to
(
args
.
device
)
logger
.
info
(
"Training/evaluation parameters %s"
,
args
)
logger
.
info
(
"Training/evaluation parameters %s"
,
args
)
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
# remove the need for this code, but it is still valid.
if
args
.
fp16
:
try
:
import
apex
apex
.
amp
.
register_half_function
(
torch
,
"einsum"
)
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
)
# Training
# Training
if
args
.
do_train
:
if
args
.
do_train
:
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
)
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
False
,
output_examples
=
False
)
...
@@ -734,15 +824,15 @@ def main():
...
@@ -734,15 +824,15 @@ def main():
torch
.
save
(
args
,
os
.
path
.
join
(
args
.
output_dir
,
"training_args.bin"
))
torch
.
save
(
args
,
os
.
path
.
join
(
args
.
output_dir
,
"training_args.bin"
))
# Load a trained model and vocabulary that you have fine-tuned
# Load a trained model and vocabulary that you have fine-tuned
model
=
model_class
.
from_pretrained
(
args
.
output_dir
,
cache_dir
=
args
.
cache_dir
if
args
.
cache_dir
else
None
)
model
=
model_class
.
from_pretrained
(
args
.
output_dir
)
tokenizer
=
tokenizer_class
.
from_pretrained
(
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
output_dir
,
do_lower_case
=
args
.
do_lower_case
)
args
.
output_dir
,
do_lower_case
=
args
.
do_lower_case
,
cache_dir
=
args
.
cache_dir
if
args
.
cache_dir
else
None
)
model
.
to
(
args
.
device
)
model
.
to
(
args
.
device
)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results
=
{}
results
=
{}
if
args
.
do_eval
and
args
.
local_rank
in
[
-
1
,
0
]:
if
args
.
do_eval
and
args
.
local_rank
in
[
-
1
,
0
]:
if
args
.
do_train
:
logger
.
info
(
"Loading checkpoints saved during training for evaluation"
)
checkpoints
=
[
args
.
output_dir
]
checkpoints
=
[
args
.
output_dir
]
if
args
.
eval_all_checkpoints
:
if
args
.
eval_all_checkpoints
:
checkpoints
=
list
(
checkpoints
=
list
(
...
@@ -755,7 +845,7 @@ def main():
...
@@ -755,7 +845,7 @@ def main():
for
checkpoint
in
checkpoints
:
for
checkpoint
in
checkpoints
:
# Reload the model
# Reload the model
global_step
=
checkpoint
.
split
(
"-"
)[
-
1
]
if
len
(
checkpoints
)
>
1
else
""
global_step
=
checkpoint
.
split
(
"-"
)[
-
1
]
if
len
(
checkpoints
)
>
1
else
""
model
=
model_class
.
from_pretrained
(
checkpoint
,
cache_dir
=
args
.
cache_dir
if
args
.
cache_dir
else
None
)
model
=
model_class
.
from_pretrained
(
checkpoint
)
model
.
to
(
args
.
device
)
model
.
to
(
args
.
device
)
# Evaluate
# Evaluate
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment