chenpangpang / transformers · Commits

Commit b80684b2, authored Feb 08, 2019 by thomwolf

    fixing run openai gpt example

parent 80607874

Showing 1 changed file with 26 additions and 22 deletions.

examples/run_openai_gpt.py  (view file @ b80684b2)   +26  -22
@@ -31,7 +31,9 @@ import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)

-from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam
+from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
+
+ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
@@ -63,7 +65,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
         n_batch = len(dataset)
         input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
         mc_token_mask = np.zeros((n_batch, 2, input_len), dtype=np.int64)
-        lm_labels = np.full((n_batch, 2, input_len), -1, dtype=np.int64)
+        lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
         mc_labels = np.zeros((n_batch,), dtype=np.int64)
         for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
             with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
@@ -71,6 +73,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
             input_ids[i, 0, :len(with_cont1)] = with_cont1
             input_ids[i, 1, :len(with_cont2)] = with_cont2
             mc_token_mask[i, 0, len(with_cont1) - 1] = 1
             mc_token_mask[i, 1, len(with_cont2) - 1] = 1
             lm_labels[i, 0, :len(with_cont1) - 1] = with_cont1[1:]
+            lm_labels[i, 1, :len(with_cont2) - 1] = with_cont2[1:]
             mc_labels[i] = mc_label
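A minimal sketch of the tensor layout the two hunks above build, using made-up sizes and token ids: padding positions keep the -1 fill value (conventionally skipped by the language-modeling loss), the LM targets are the input shifted left by one, and mc_token_mask flags the position of the classification token.

import numpy as np

# Toy sizes and token ids, purely illustrative.
n_batch, input_len = 1, 6
with_cont1 = [40478, 11, 12, 40480]   # e.g. [start_token] + story + [clf_token]

lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
lm_labels[0, 0, :len(with_cont1) - 1] = with_cont1[1:]    # pre-shifted next-token targets
print(lm_labels[0, 0])          # [   11    12 40480    -1    -1    -1]

mc_token_mask = np.zeros((n_batch, 2, input_len), dtype=np.int64)
mc_token_mask[0, 0, len(with_cont1) - 1] = 1              # marks the classification token
print(mc_token_mask[0, 0])      # [0 0 0 1 0 0]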
@@ -86,8 +89,8 @@ def main():
     parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")
-    parser.add_argument('--train_dataset', type=str, default='cloze_test_val__spring2016 - cloze_test_ALL_val.tsv')
-    parser.add_argument('--eval_dataset', type=str, default='test_spring2016.tsv')
+    parser.add_argument('--train_dataset', type=str, default='')
+    parser.add_argument('--eval_dataset', type=str, default='')
     parser.add_argument('--seed', type=int, default=42)
     parser.add_argument('--num_train_epochs', type=int, default=3)
     parser.add_argument('--train_batch_size', type=int, default=8)
@@ -97,7 +100,7 @@ def main():
     parser.add_argument('--warmup_proportion', type=float, default=0.002)
     parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
     parser.add_argument('--weight_decay', type=float, default=0.01)
-    parser.add_argument('--lm_coef', type=float, default=0.5)
+    parser.add_argument('--lm_coef', type=float, default=0.9)
     parser.add_argument('--n_valid', type=int, default=374)
     parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
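The only change here raises the default --lm_coef from 0.5 to 0.9. Elsewhere in the script (outside the hunks shown) the two losses returned by the double-heads model are combined, presumably as lm_coef * lm_loss + mc_loss; with toy numbers the new default weights the language-modeling term noticeably more.

# Toy loss values only; the weighting formula is an assumption about the rest
# of the script, not a line quoted from this diff.
lm_loss, mc_loss = 3.2, 0.7
for lm_coef in (0.5, 0.9):               # old default vs. new default
    print(lm_coef, round(lm_coef * lm_loss + mc_loss, 2))
# 0.5 2.3
# 0.9 3.58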
@@ -137,6 +140,8 @@ def main():
     model.to(device)

     # Load and encode the datasets
+    if not args.train_dataset and not args.eval_dataset:
+        roc_stories = cached_path(ROCSTORIES_URL)
     def tokenize_and_encode(obj):
         """ Tokenize and encode a nested object """
         if isinstance(obj, str):
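Together with the empty --train_dataset/--eval_dataset defaults from the earlier hunk, this fallback lets the example fetch ROCStories itself. A minimal sketch of what cached_path does, run outside the script and assuming network access plus the pytorch_pretrained_bert package: it downloads the archive on first use and returns the local cache path on later calls.

from pytorch_pretrained_bert import cached_path

ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

train_dataset, eval_dataset = "", ""      # the new argparse defaults
if not train_dataset and not eval_dataset:
    roc_stories = cached_path(ROCSTORIES_URL)   # local filesystem path to the cached archive
    print(roc_stories)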
@@ -144,7 +149,6 @@ def main():
         elif isinstance(obj, int):
             return obj
         return list(tokenize_and_encode(o) for o in obj)
     logger.info("Encoding dataset...")
     train_dataset = load_rocstories_dataset(args.train_dataset)
     eval_dataset = load_rocstories_dataset(args.eval_dataset)
@@ -152,13 +156,13 @@ def main():
     encoded_datasets = tokenize_and_encode(datasets)

     # Compute the max input length for the Transformer
-    input_length = max(len(story) + max(len(cont1), len(cont2)) + 3  \
+    max_length = model.config.n_positions // 2 - 2
+    input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3  \
                            for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
     input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model
-    max_sub_part_length = input_length // 2 - 2

     # Prepare inputs tensors and dataloaders
-    tensor_datasets = pre_process_datasets(encoded_datasets, input_length, max_sub_part_length, *special_tokens_ids)
+    tensor_datasets = pre_process_datasets(encoded_datasets, input_length, max_length, *special_tokens_ids)
     train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1]

     train_data = TensorDataset(*train_tensor_dataset)
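A quick check of the new length budget, assuming n_positions is 512 as for the pretrained OpenAI GPT (the script reads it from model.config rather than hard-coding it): truncating each segment to n_positions // 2 - 2 guarantees that even the longest assembled sequence, with its three special tokens, still fits.

n_positions = 512                        # assumed value; taken from model.config in the script
max_length = n_positions // 2 - 2        # 254 tokens kept per truncated segment
longest = max_length + max_length + 3    # [start] story [delim] continuation [clf]
print(max_length, longest, longest <= n_positions)   # 254 511 True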
@@ -176,7 +180,7 @@ def main():
         {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
-    num_train_optimization_steps = len(train_data) // args.train_batch_size
+    num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size
     optimizer = OpenAIAdam(optimizer_grouped_parameters,
                            lr=args.learning_rate,
                            warmup=args.warmup_proportion,
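The old step count ignored the number of epochs, so the t_total passed to OpenAIAdam covered only one epoch's worth of updates and the warmup/decay schedule ran out early. With toy numbers:

# Toy numbers: 1000 training examples, batch size 8, 3 epochs.
len_train_data, train_batch_size, num_train_epochs = 1000, 8, 3

old_steps = len_train_data // train_batch_size                      # 125, one epoch only
new_steps = len_train_data * num_train_epochs // train_batch_size   # 375, all epochs
print(old_steps, new_steps)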
@@ -185,12 +189,11 @@ def main():
                            t_total=num_train_optimization_steps)

     if args.do_train:
-        nb_tr_steps = 0
-        tr_loss = 0
+        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
             tr_loss = 0
-            nb_tr_examples, nb_tr_steps = 0, 0
+            nb_tr_steps = 0
             tqdm_bar = tqdm(train_dataloader, desc="Training")
             for step, batch in enumerate(tqdm_bar):
                 batch = tuple(t.to(device) for t in batch)
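exp_average_loss starts as None here and, in the next hunk, becomes an exponential moving average of the per-step loss (0.7 of the running value, 0.3 of the new one). A standalone sketch of that smoothing with made-up loss values:

losses = [4.0, 3.0, 2.5, 2.0]            # toy per-step losses
exp_average_loss = None
for loss in losses:
    exp_average_loss = loss if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss
    print(round(exp_average_loss, 3))    # 4.0, 3.7, 3.34, 2.938 (one value per line)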
@@ -200,21 +203,22 @@ def main():
                 loss.backward()
                 optimizer.step()
                 tr_loss += loss.item()
-                nb_tr_examples += input_ids.size(0)
+                exp_average_loss = loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
                 nb_tr_steps += 1
-                tqdm_bar.desc = "Training loss: {:.2e}".format(tr_loss / nb_tr_steps)
+                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, optimizer.get_lr()[0])

     # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+    if args.do_train:
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+        config = model.config
         torch.save(model_to_save.state_dict(), output_model_file)

-    # Load a trained model that you have fine-tuned
-    model_state_dict = torch.load(output_model_file)
-    model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name,
-                                                      state_dict=model_state_dict,
-                                                      num_special_tokens=len(special_tokens))
-    model.to(device)
+        # Load a trained model that you have fine-tuned
+        model_state_dict = torch.load(output_model_file)
+        model = OpenAIGPTDoubleHeadsModel(config)
+        model.load_state_dict(model_state_dict)
+        model.to(device)

     if args.do_eval:
         model.eval()
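The rewritten block saves the model's config alongside its weights, then rebuilds the model from that config and loads the saved state dict, instead of calling from_pretrained again with the downloaded checkpoint. A self-contained sketch of the same round trip, using a tiny stand-in module so it runs without fetching GPT weights:

import os
import tempfile
import torch
import torch.nn as nn

# Stand-in for OpenAIGPTDoubleHeadsModel: any nn.Module follows the same
# state_dict() / load_state_dict() round trip used in the hunk above.
model = nn.Linear(4, 2)

output_dir = tempfile.mkdtemp()
output_model_file = os.path.join(output_dir, "pytorch_model.bin")
torch.save(model.state_dict(), output_model_file)         # save after training

reloaded = nn.Linear(4, 2)                                 # rebuild with the same "config" (here, just the shape)
reloaded.load_state_dict(torch.load(output_model_file))    # then load the fine-tuned weights
print(torch.equal(model.weight, reloaded.weight))          # True

In the example itself the rebuild step is OpenAIGPTDoubleHeadsModel(config) followed by load_state_dict, which keeps the reload independent of the original pretrained download.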