chenpangpang/transformers, commit 56d4ba8d
Authored Jan 17, 2020 by Julien Chaumond
Parent: c7f79815

[run_lm_finetuning] Train from scratch

1 changed file with 93 additions and 56 deletions:
examples/run_lm_finetuning.py (+93, -56)
@@ -28,7 +28,7 @@ import pickle
 import random
 import re
 import shutil
-from typing import Tuple
+from typing import Dict, List, Tuple

 import numpy as np
 import torch
@@ -54,6 +54,7 @@ from transformers import (
     OpenAIGPTConfig,
     OpenAIGPTLMHeadModel,
     OpenAIGPTTokenizer,
+    PreTrainedModel,
     PreTrainedTokenizer,
     RobertaConfig,
     RobertaForMaskedLM,
@@ -82,11 +83,11 @@ MODEL_CLASSES = {

 class TextDataset(Dataset):
-    def __init__(self, tokenizer, args, file_path="train", block_size=512):
+    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path="train", block_size=512):
         assert os.path.isfile(file_path)

         directory, filename = os.path.split(file_path)
         cached_features_file = os.path.join(
-            directory, args.model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename
+            directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
         )

         if os.path.exists(cached_features_file) and not args.overwrite_cache:
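With this change the cached-features file is keyed on the model architecture (args.model_type) rather than on the checkpoint path, so runs that start from scratch (no --model_name_or_path) still produce a valid cache name. A minimal sketch of the resulting path, using a hypothetical args namespace and data file:

import os
from argparse import Namespace

# Hypothetical values, only to illustrate the new cache naming scheme.
args = Namespace(model_type="roberta")
file_path = "data/wikitext-2/wiki.train.tokens"
block_size = 512

directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(
    directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
)
print(cached_features_file)  # data/wikitext-2/roberta_cached_lm_512_wiki.train.tokens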
@@ -120,13 +121,12 @@ class TextDataset(Dataset):

 def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = TextDataset(
+    return TextDataset(
         tokenizer,
         args,
         file_path=args.eval_data_file if evaluate else args.train_data_file,
         block_size=args.block_size,
     )
-    return dataset


 def set_seed(args):
@@ -137,18 +137,11 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)


-def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
-    if not args.save_total_limit:
-        return
-    if args.save_total_limit <= 0:
-        return
+def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
+    ordering_and_checkpoint_path = []

-    # Check if we should delete older checkpoint(s)
     glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
-    if len(glob_checkpoints) <= args.save_total_limit:
-        return

-    ordering_and_checkpoint_path = []
     for path in glob_checkpoints:
         if use_mtime:
             ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
@@ -159,6 +152,20 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):

     checkpoints_sorted = sorted(ordering_and_checkpoint_path)
     checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+    return checkpoints_sorted
+
+
+def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
+    if not args.save_total_limit:
+        return
+    if args.save_total_limit <= 0:
+        return
+
+    # Check if we should delete older checkpoint(s)
+    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
+    if len(checkpoints_sorted) <= args.save_total_limit:
+        return
+
     number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
     checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
     for checkpoint in checkpoints_to_be_deleted:
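The checkpoint-rotation logic is now split in two: _sorted_checkpoints returns checkpoint directories ordered oldest to newest, and _rotate_checkpoints only trims that list down to save_total_limit. A self-contained sketch of the ordering and trimming steps in the mtime branch, with hypothetical paths and a plain Namespace standing in for the script's parsed args:

from argparse import Namespace

# Pretend these (mtime, path) pairs came from glob.glob + os.path.getmtime.
ordering_and_checkpoint_path = [
    (1579270000.0, "output/checkpoint-1500"),
    (1579260000.0, "output/checkpoint-500"),
    (1579265000.0, "output/checkpoint-1000"),
]

# Same ordering step as _sorted_checkpoints: sort by mtime, keep only the paths.
checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

# Same trimming step as _rotate_checkpoints, here with save_total_limit=2.
args = Namespace(save_total_limit=2)
number_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
checkpoints_to_be_deleted = checkpoints_sorted[:number_to_delete]
print(checkpoints_to_be_deleted)  # ['output/checkpoint-500'], the oldest goes first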
@@ -191,7 +198,7 @@ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> T
     return inputs, labels


-def train(args, train_dataset, model, tokenizer):
+def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -221,7 +228,7 @@ def train(args, train_dataset, model, tokenizer):
     )

     # Check if saved optimizer or scheduler states exist
-    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
+    if args.model_name_or_path and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
         os.path.join(args.model_name_or_path, "scheduler.pt")
     ):
         # Load in optimizer and scheduler states
@@ -263,7 +270,7 @@ def train(args, train_dataset, model, tokenizer):
     epochs_trained = 0
     steps_trained_in_current_epoch = 0
     # Check if continuing training from a checkpoint
-    if os.path.exists(args.model_name_or_path):
+    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
         try:
             # set global_step to gobal_step of last saved checkpoint from model path
             checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
@@ -342,8 +349,7 @@ def train(args, train_dataset, model, tokenizer):
                     checkpoint_prefix = "checkpoint"
                     # Save model checkpoint
                     output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
-                    if not os.path.exists(output_dir):
-                        os.makedirs(output_dir)
+                    os.makedirs(output_dir, exist_ok=True)
                     model_to_save = (
                         model.module if hasattr(model, "module") else model
                     )  # Take care of distributed/parallel training
@@ -372,14 +378,14 @@ def train(args, train_dataset, model, tokenizer):
     return global_step, tr_loss / global_step


-def evaluate(args, model, tokenizer, prefix=""):
+def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
     # Loop to handle MNLI double evaluation (matched, mis-matched)
     eval_output_dir = args.output_dir

     eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
-    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
-        os.makedirs(eval_output_dir)
+    if args.local_rank in [-1, 0]:
+        os.makedirs(eval_output_dir, exist_ok=True)

     args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
     # Note that DistributedSampler samples randomly
@@ -433,11 +439,16 @@ def main():
     )
     parser.add_argument(
         "--output_dir",
-        default=None,
         type=str,
         required=True,
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument(
+        "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
+    )
+    parser.add_argument(
+        "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
+    )

     # Other parameters
     parser.add_argument(
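--model_type is now a required argument rather than defaulting to bert, and the new --should_continue flag asks the script to resume from the newest checkpoint in --output_dir. A stripped-down sketch of just these two options, parsed against a hypothetical command line:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, required=True,
                    help="The model architecture to be trained or fine-tuned.")
parser.add_argument("--should_continue", action="store_true",
                    help="Whether to continue from latest checkpoint in output_dir")

# Hypothetical invocation: resume a GPT-2 run from its latest checkpoint.
args = parser.parse_args(["--model_type", "gpt2", "--should_continue"])
print(args.model_type, args.should_continue)  # gpt2 True

# Omitting --model_type now exits with "the following arguments are required: --model_type".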
@@ -447,12 +458,11 @@ def main():
         help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
     )
-    parser.add_argument("--model_type", default="bert", type=str, help="The model architecture to be fine-tuned.")
     parser.add_argument(
         "--model_name_or_path",
-        default="bert-base-cased",
+        default=None,
         type=str,
-        help="The model checkpoint for weights initialization.",
+        help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
     )
     parser.add_argument(
@@ -464,19 +474,25 @@ def main():
     parser.add_argument(
         "--config_name",
-        default="",
+        default=None,
         type=str,
-        help="Optional pretrained config name or path if not the same as model_name_or_path",
+        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
     )
     parser.add_argument(
         "--tokenizer_name",
-        default="",
+        default=None,
         type=str,
-        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path",
+        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
+    )
+    parser.add_argument(
+        "--tokenizer_init_args",
+        default="",
+        type=str,
+        help="If instantiating a new tokenizer, comma-separated list of input args to feed the constructor.",
     )
     parser.add_argument(
         "--cache_dir",
-        default="",
+        default=None,
         type=str,
         help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
     )
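--tokenizer_init_args is a comma-separated string whose pieces are splatted into the tokenizer constructor when no pretrained tokenizer is given. A minimal sketch of that plumbing, with a hypothetical constructor standing in for a real tokenizer class:

# Hypothetical stand-in for a tokenizer class that takes vocab/merges files positionally.
class DummyTokenizer:
    def __init__(self, vocab_file, merges_file):
        self.vocab_file = vocab_file
        self.merges_file = merges_file

# This mirrors how the script feeds --tokenizer_init_args to the constructor.
tokenizer_init_args = "my-vocab.json,my-merges.txt"  # hypothetical paths
tokenizer = DummyTokenizer(*tokenizer_init_args.split(","))
print(tokenizer.vocab_file, tokenizer.merges_file)  # my-vocab.json my-merges.txt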
@@ -493,9 +509,6 @@ def main():
     parser.add_argument(
         "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
     )
-    parser.add_argument(
-        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
-    )

     parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
     parser.add_argument(
@@ -563,7 +576,7 @@ def main():
     if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
         raise ValueError(
-            "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
+            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
             "flag (masked language modeling)."
         )
     if args.eval_data_file is None and args.do_eval:
@@ -571,6 +584,14 @@ def main():
             "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
             "or remove the --do_eval argument."
         )
+    if args.should_continue:
+        sorted_checkpoints = _sorted_checkpoints(args)
+        if len(sorted_checkpoints) == 0:
+            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
+        else:
+            args.model_name_or_path = sorted_checkpoints[-1]
+
     if (
         os.path.exists(args.output_dir)
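When --should_continue is set, the newest entry returned by _sorted_checkpoints simply becomes model_name_or_path, so the rest of the script loads it like any other pretrained checkpoint. A self-contained sketch of that selection step, using hypothetical checkpoint names:

from argparse import Namespace

def pick_resume_checkpoint(sorted_checkpoints, args):
    # Mirrors the --should_continue handling: fail loudly if there is nothing to
    # resume from, otherwise continue from the most recent checkpoint.
    if len(sorted_checkpoints) == 0:
        raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
    args.model_name_or_path = sorted_checkpoints[-1]
    return args

args = Namespace(model_name_or_path=None)
args = pick_resume_checkpoint(["output/checkpoint-500", "output/checkpoint-1000"], args)
print(args.model_name_or_path)  # output/checkpoint-1000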
@@ -627,26 +648,42 @@ def main():
         torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
-        args.config_name if args.config_name else args.model_name_or_path,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
-    tokenizer = tokenizer_class.from_pretrained(
-        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-        do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
+
+    if args.config_name:
+        config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
+    elif args.model_name_or_path:
+        config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+    else:
+        config = config_class()
+
+    if args.tokenizer_name:
+        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
+    elif args.model_name_or_path:
+        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+    else:
+        logger.warning(
+            "You are instantiating a new {} tokenizer from scratch. Are you sure this is what you meant to do?"
+            "To specifiy a pretrained tokenizer name, use --tokenizer_name".format(tokenizer_class.__name__)
+        )
+        tokenizer = tokenizer_class(*args.tokenizer_init_args.split(","))
+
     if args.block_size <= 0:
-        args.block_size = (
-            tokenizer.max_len_single_sentence
-        )  # Our input block size will be the max possible for the model
+        args.block_size = tokenizer.max_len_single_sentence  # Our input block size will be the max possible for the model
     else:
         args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
-    model = model_class.from_pretrained(
-        args.model_name_or_path,
-        from_tf=bool(".ckpt" in args.model_name_or_path),
-        config=config,
-        cache_dir=args.cache_dir if args.cache_dir else None,
-    )
+
+    if args.model_name_or_path:
+        model = model_class.from_pretrained(
+            args.model_name_or_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+            cache_dir=args.cache_dir,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = model_class(config=config)
+
     model.to(args.device)

     if args.local_rank == 0:
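This hunk is the core of the train-from-scratch support: when --model_name_or_path is empty, the script falls back to a fresh config and a randomly initialized model instead of calling from_pretrained. A minimal sketch of that fallback path, using GPT2Config/GPT2LMHeadModel as an illustrative choice of classes rather than anything the script hard-codes:

from transformers import GPT2Config, GPT2LMHeadModel

# No --model_name_or_path: build a fresh config (optionally overriding sizes)
# and a randomly initialized model, matching the new `else` branch above.
config = GPT2Config(n_layer=6, n_head=8, n_embd=512)  # hypothetical small architecture
model = GPT2LMHeadModel(config)

print(sum(p.numel() for p in model.parameters()))  # untrained weights, ready to be passed to train()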
@@ -670,8 +707,8 @@ def main():
     # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
-        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
-            os.makedirs(args.output_dir)
+        if args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir, exist_ok=True)

         logger.info("Saving model checkpoint to %s", args.output_dir)
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
@@ -687,7 +724,7 @@ def main():
         # Load a trained model and vocabulary that you have fine-tuned
         model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
         model.to(args.device)

     # Evaluation