chenpangpang / transformers / Commits

Commit a690edab, authored Aug 20, 2019 by thomwolf

    various fix and clean up on run_lm_finetuning

Parent: f94f1c60

Showing 3 changed files with 113 additions and 102 deletions:

  .gitignore                       +4    -1
  examples/run_lm_finetuning.py    +109  -50
  examples/utils_lm.py             +0    -51
.gitignore

@@ -127,4 +127,7 @@ proc_data
 # examples
 runs
-examples/runs
\ No newline at end of file
+examples/runs
+
+# data
+data
\ No newline at end of file
examples/run_generative_finetuning.py → examples/run_lm_finetuning.py
@@ -25,33 +25,75 @@ import argparse
 import glob
 import logging
 import os
 import pickle
 import random
 
 import numpy as np
 import torch
-from torch.utils.data import (DataLoader, SequentialSampler,)
+from torch.utils.data import DataLoader, Dataset, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange
 
-from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                  BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
-from pytorch_transformers import AdamW, WarmupLinearSchedule
-
-logger = logging.getLogger(__name__)
+from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
+                                  BertConfig, BertForMaskedLM, BertTokenizer,
+                                  GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
+                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
 
-from utils_lm import WikiTextDataset
+logger = logging.getLogger(__name__)
 
 MODEL_CLASSES = {
     'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
     'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
+    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
 }
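Editor's note (not part of the commit): each MODEL_CLASSES entry bundles the config, model, and tokenizer classes for one architecture, and main() later looks the tuple up by the new --model_type argument. A minimal sketch of that lookup, assuming pytorch_transformers is installed and using 'gpt2' as an example checkpoint name:

# Editor's sketch, not from the diff: how a MODEL_CLASSES tuple is consumed.
# Assumes pytorch_transformers is installed and the 'gpt2' weights can be downloaded.
from pytorch_transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

MODEL_CLASSES = {'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)}

model_type = 'gpt2'              # would come from --model_type
model_name_or_path = 'gpt2'      # would come from --model_name_or_path

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
config = config_class.from_pretrained(model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
model = model_class.from_pretrained(model_name_or_path, config=config)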
+class TextDataset(Dataset):
+    def __init__(self, tokenizer, file_path='train', block_size=512):
+        assert os.path.isfile(file_path)
+        directory, filename = os.path.split(file_path)
+        cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
+
+        if os.path.exists(cached_features_file):
+            logger.info("Loading features from cached file %s", cached_features_file)
+            with open(cached_features_file, 'rb') as handle:
+                self.examples = pickle.load(handle)
+        else:
+            logger.info("Creating features from dataset file at %s", directory)
+
+            self.examples = []
+            with open(file_path, encoding="utf-8") as f:
+                text = f.read()
+
+            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
+
+            while len(tokenized_text) >= block_size:  # Truncate in block of block_size
+                self.examples.append(tokenized_text[:block_size])
+                tokenized_text = tokenized_text[block_size:]
+            # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
+            # If your dataset is small, first you should loook for a bigger one :-) and second you
+            # can change this behavior by adding (model specific) padding.
+
+            logger.info("Saving features into cached file %s", cached_features_file)
+            with open(cached_features_file, 'wb') as handle:
+                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, item):
+        return torch.tensor(self.examples[item])
+
+
+def load_and_cache_examples(args, tokenizer, evaluate=False):
+    dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file,
+                          block_size=args.block_size)
+    return dataset
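A toy illustration of the truncation loop in TextDataset above (editor's sketch, not part of the commit): with 10 token ids and block_size=4, two full blocks are kept and the trailing 2-token remainder is dropped, which is exactly the "no padding" behavior the code comment describes.

# Editor's sketch: the same block-truncation logic with fake token ids.
tokenized_text = list(range(10))   # stand-in for tokenizer output
block_size = 4

examples = []
while len(tokenized_text) >= block_size:
    examples.append(tokenized_text[:block_size])
    tokenized_text = tokenized_text[block_size:]

print(examples)          # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(tokenized_text)    # [8, 9] -- the last partial block is discarded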
 def set_seed(args):
     random.seed(args.seed)
     np.random.seed(args.seed)

@@ -59,20 +101,27 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)
 
-# Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original
 def mask_tokens(inputs, tokenizer, args):
+    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
-    labels[~masked_indices.bool()] = -1  # We only compute loss on masked tokens
+    labels[~masked_indices] = -1  # We only compute loss on masked tokens
 
+    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
     indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
-    inputs[indices_replaced.bool()] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)  # 80% of the time, replace masked input tokens with [MASK]
-    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool()
-    random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long)
-    inputs[indices_random] = random_words[indices_random]  # 10% of the time, replace masked input tokens with random word
+    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
+
+    # 10% of the time, we replace masked input tokens with random word
+    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
+    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
+    inputs[indices_random] = random_words[indices_random]
 
+    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
     return inputs, labels
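A quick sanity check of the 80/10/10 split implemented in mask_tokens (editor's arithmetic, not part of the commit): positions are first selected with probability mlm_probability; of the selected positions, 80% receive the mask token, half of the remaining 20% (10% overall) receive a random token, and the last 10% are left unchanged.

# Editor's sketch: empirically check the masking proportions on a large tensor.
# Assumes the default mlm_probability of 0.15 mentioned in the code comment.
import torch

mlm_probability = 0.15
shape = (1000, 128)

masked_indices = torch.bernoulli(torch.full(shape, mlm_probability)).bool()
indices_replaced = torch.bernoulli(torch.full(shape, 0.8)).bool() & masked_indices
indices_random = torch.bernoulli(torch.full(shape, 0.5)).bool() & masked_indices & ~indices_replaced

n_masked = masked_indices.sum().item()
print(indices_replaced.sum().item() / n_masked)   # ~0.80 -> [MASK]
print(indices_random.sum().item() / n_masked)     # ~0.10 -> random token
# the remaining ~0.10 of selected positions keep their original token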
 
 def train(args, train_dataset, model, tokenizer):
     """ Train the model """
     if args.local_rank in [-1, 0]:
@@ -146,13 +195,15 @@ def train(args, train_dataset, model, tokenizer):
             if args.fp16:
                 with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
-                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
             else:
                 loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 
             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
+                if args.fp16:
+                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+                else:
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                 optimizer.step()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
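The hunk above moves gradient clipping out of the backward branches and into the accumulation boundary, so the norm is computed once on the fully accumulated gradients right before optimizer.step(). A stripped-down illustration of that ordering (editor's sketch: plain fp32 without apex, with made-up hyperparameters):

# Editor's sketch: clip once per optimizer step, after gradients have accumulated.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4
max_grad_norm = 1.0

for step in range(8):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / gradient_accumulation_steps
    loss.backward()                                   # gradients accumulate across steps

    if (step + 1) % gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        model.zero_grad()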
@@ -240,24 +291,22 @@ def evaluate(args, model, tokenizer, prefix=""):
     return results
 
-def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = WikiTextDataset(args, tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
-    return dataset
 
 def main():
     parser = argparse.ArgumentParser()
 
     ## Required parameters
-    parser.add_argument("--data_dir", default=None, type=str, required=True,
-                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
+    parser.add_argument("--train_data_file", default=None, type=str, required=True,
+                        help="The input training data file (a text file).")
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")
 
     ## Other parameters
-    parser.add_argument("--model_name", default="bert", type=str,
+    parser.add_argument("--eval_data_file", default=None, type=str,
+                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
+
+    parser.add_argument("--model_type", default="bert", type=str,
                         help="The model architecture to be fine-tuned.")
-    parser.add_argument("--model_checkpoint", default="bert-base-cased", type=str,
+    parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
                         help="The model checkpoint for weights initialization.")
     parser.add_argument("--mlm", action='store_true',
@@ -266,20 +315,21 @@ def main():
                         help="Ratio of tokens to mask for masked language modeling loss")
 
     parser.add_argument("--config_name", default="", type=str,
-                        help="Pretrained config name or path if not the same as model_name")
+                        help="Optional pretrained config name or path if not the same as model_name_or_path")
     parser.add_argument("--tokenizer_name", default="", type=str,
-                        help="Pretrained tokenizer name or path if not the same as model_name")
+                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
     parser.add_argument("--cache_dir", default="", type=str,
-                        help="Where do you want to store the pre-trained models downloaded from s3")
-    parser.add_argument("--max_seq_length", default=128, type=int,
-                        help="The maximum total input sequence length after tokenization. Sequences longer "
-                             "than this will be truncated, sequences shorter will be padded.")
+                        help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
+    parser.add_argument("--block_size", default=-1, type=int,
+                        help="Optional input sequence length after tokenization."
+                             "The training dataset will be truncated in block of this size for training."
+                             "Default to the model max input length.")
     parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
     parser.add_argument("--do_eval", action='store_true',
                         help="Whether to run eval on the dev set.")
     parser.add_argument("--evaluate_during_training", action='store_true',
-                        help="Rul evaluation during training at each logging step.")
+                        help="Run evaluation during training at each logging step.")
     parser.add_argument("--do_lower_case", action='store_true',
                         help="Set this flag if you are using an uncased model.")
@@ -309,7 +359,7 @@ def main():
     parser.add_argument('--save_steps', type=int, default=50,
                         help="Save checkpoint every X updates steps.")
     parser.add_argument("--eval_all_checkpoints", action='store_true',
-                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
+                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
     parser.add_argument("--no_cuda", action='store_true',
                         help="Avoid using CUDA when available")
     parser.add_argument('--overwrite_output_dir', action='store_true',
@@ -330,9 +380,12 @@ def main():
     parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
     args = parser.parse_args()
 
-    if args.model_name in ["bert", "roberta"] and not args.mlm:
+    if args.model_type in ["bert", "roberta"] and not args.mlm:
         raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
                          "flag (masked language modeling).")
+    if args.eval_data_file is None and args.do_eval:
+        raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
+                         "or remove the --do_eval argument.")
 
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
         raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
@@ -368,30 +421,36 @@ def main():
     # Load pretrained model and tokenizer
     if args.local_rank not in [-1, 0]:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_name]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_checkpoint)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_checkpoint, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_checkpoint, from_tf=bool('.ckpt' in args.model_checkpoint), config=config)
-    args.num_embeddings = config.vocab_size  # We need this to create the model at next line (number of embeddings to use)
+        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab
+
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    if args.block_size <= 0:
+        args.block_size = tokenizer.max_len  # Our input block size will be the max possible for the model
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+    model.to(args.device)
 
     if args.local_rank == 0:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-
-    model.to(args.device)
+        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab
 
     logger.info("Training/evaluation parameters %s", args)
 
     # Training
     if args.do_train:
+        if args.local_rank not in [-1, 0]:
+            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
+
         train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
+
+        if args.local_rank == 0:
+            torch.distributed.barrier()
+
         global_step, tr_loss = train(args, train_dataset, model, tokenizer)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
-    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
+    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
         if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
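The barrier pattern used in the hunk above (the first process downloads or preprocesses, the others wait, and a closing barrier releases them) can be hard to follow inside the diff. A condensed sketch of the idiom (editor's illustration; assumes torch.distributed has already been initialized):

# Editor's sketch of the first-process-does-the-work barrier idiom.
# Assumes torch.distributed.init_process_group(...) has already run.
import torch

def do_expensive_setup(local_rank):
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()   # non-first processes wait here

    # Only local_rank in [-1, 0] reaches this point first:
    # download weights / build the dataset cache exactly once.
    setup_result = "cached artifacts"

    if local_rank == 0:
        torch.distributed.barrier()   # release the waiting processes

    return setup_result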
@@ -409,7 +468,7 @@ def main():
         # Load a trained model and vocabulary that you have fine-tuned
         model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)
examples/utils_lm.py  (deleted, file mode 100644 → 0)
-from torch.utils.data import Dataset, DataLoader
-import os
-import random
-import torch
-import torch.nn.functional as F
-import logging
-import pickle
-
-logger = logging.getLogger(__name__)
-
-
-class WikiTextDataset(Dataset):
-    def __init__(self, args, tokenizer, file='train', directory='wikitext', max_context_length=512, cache=None):
-        if args.local_rank not in [-1, 0]:
-            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
-
-        cached_features_file = os.path.join(args.data_dir, f'cached_lm_{file}_{args.max_seq_length}')
-        if os.path.exists(cached_features_file):
-            logger.info("Loading features from cached file %s", cached_features_file)
-            with open(cached_features_file, 'rb') as handle:
-                self.examples = pickle.load(handle)
-        else:
-            logger.info("Creating features from dataset file at %s", args.data_dir)
-            self.max_context_length = max_context_length
-            self.examples = []
-            with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
-                text = f.read()
-            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
-            while len(tokenized_text) > max_context_length:
-                self.examples.append(tokenized_text[:max_context_length])
-                tokenized_text = tokenized_text[max_context_length:]
-
-            if args.local_rank in [-1, 0]:
-                logger.info("Saving features into cached file %s", cached_features_file)
-                with open(cached_features_file, 'wb') as handle:
-                    pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
-
-        if args.local_rank == 0:
-            torch.distributed.barrier()
-
-    def __len__(self):
-        return len(self.examples)
-
-    def __getitem__(self, item):
-        return torch.tensor(self.examples[item])