Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b80684b2
Commit
b80684b2
authored
Feb 08, 2019
by
thomwolf
Browse files
fixing run openai gpt example
parent
80607874
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
22 deletions
+26
-22
examples/run_openai_gpt.py
examples/run_openai_gpt.py
+26
-22
No files found.
examples/run_openai_gpt.py
View file @
b80684b2
...
@@ -31,7 +31,9 @@ import torch
...
@@ -31,7 +31,9 @@ import torch
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
)
TensorDataset
)
from
pytorch_pretrained_bert
import
OpenAIGPTDoubleHeadsModel
,
OpenAIGPTTokenizer
,
OpenAIAdam
from
pytorch_pretrained_bert
import
OpenAIGPTDoubleHeadsModel
,
OpenAIGPTTokenizer
,
OpenAIAdam
,
cached_path
ROCSTORIES_URL
=
"https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
...
@@ -63,7 +65,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
...
@@ -63,7 +65,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
n_batch
=
len
(
dataset
)
n_batch
=
len
(
dataset
)
input_ids
=
np
.
zeros
((
n_batch
,
2
,
input_len
),
dtype
=
np
.
int64
)
input_ids
=
np
.
zeros
((
n_batch
,
2
,
input_len
),
dtype
=
np
.
int64
)
mc_token_mask
=
np
.
zeros
((
n_batch
,
2
,
input_len
),
dtype
=
np
.
int64
)
mc_token_mask
=
np
.
zeros
((
n_batch
,
2
,
input_len
),
dtype
=
np
.
int64
)
lm_labels
=
np
.
full
((
n_batch
,
2
,
input_len
),
-
1
,
dtype
=
np
.
int64
)
lm_labels
=
np
.
full
((
n_batch
,
2
,
input_len
),
fill_value
=
-
1
,
dtype
=
np
.
int64
)
mc_labels
=
np
.
zeros
((
n_batch
,),
dtype
=
np
.
int64
)
mc_labels
=
np
.
zeros
((
n_batch
,),
dtype
=
np
.
int64
)
for
i
,
(
story
,
cont1
,
cont2
,
mc_label
),
in
enumerate
(
dataset
):
for
i
,
(
story
,
cont1
,
cont2
,
mc_label
),
in
enumerate
(
dataset
):
with_cont1
=
[
start_token
]
+
story
[:
cap_length
]
+
[
delimiter_token
]
+
cont1
[:
cap_length
]
+
[
clf_token
]
with_cont1
=
[
start_token
]
+
story
[:
cap_length
]
+
[
delimiter_token
]
+
cont1
[:
cap_length
]
+
[
clf_token
]
...
@@ -71,6 +73,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
...
@@ -71,6 +73,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
input_ids
[
i
,
0
,
:
len
(
with_cont1
)]
=
with_cont1
input_ids
[
i
,
0
,
:
len
(
with_cont1
)]
=
with_cont1
input_ids
[
i
,
1
,
:
len
(
with_cont2
)]
=
with_cont2
input_ids
[
i
,
1
,
:
len
(
with_cont2
)]
=
with_cont2
mc_token_mask
[
i
,
0
,
len
(
with_cont1
)
-
1
]
=
1
mc_token_mask
[
i
,
0
,
len
(
with_cont1
)
-
1
]
=
1
mc_token_mask
[
i
,
1
,
len
(
with_cont2
)
-
1
]
=
1
lm_labels
[
i
,
0
,
:
len
(
with_cont1
)
-
1
]
=
with_cont1
[
1
:]
lm_labels
[
i
,
0
,
:
len
(
with_cont1
)
-
1
]
=
with_cont1
[
1
:]
lm_labels
[
i
,
1
,
:
len
(
with_cont2
)
-
1
]
=
with_cont2
[
1
:]
lm_labels
[
i
,
1
,
:
len
(
with_cont2
)
-
1
]
=
with_cont2
[
1
:]
mc_labels
[
i
]
=
mc_label
mc_labels
[
i
]
=
mc_label
...
@@ -86,8 +89,8 @@ def main():
...
@@ -86,8 +89,8 @@ def main():
parser
.
add_argument
(
"--do_eval"
,
action
=
'store_true'
,
help
=
"Whether to run eval on the dev set."
)
parser
.
add_argument
(
"--do_eval"
,
action
=
'store_true'
,
help
=
"Whether to run eval on the dev set."
)
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
parser
.
add_argument
(
"--output_dir"
,
default
=
None
,
type
=
str
,
required
=
True
,
help
=
"The output directory where the model predictions and checkpoints will be written."
)
help
=
"The output directory where the model predictions and checkpoints will be written."
)
parser
.
add_argument
(
'--train_dataset'
,
type
=
str
,
default
=
'
cloze_test_val__spring2016 - cloze_test_ALL_val.tsv
'
)
parser
.
add_argument
(
'--train_dataset'
,
type
=
str
,
default
=
''
)
parser
.
add_argument
(
'--eval_dataset'
,
type
=
str
,
default
=
'
test_spring2016.tsv
'
)
parser
.
add_argument
(
'--eval_dataset'
,
type
=
str
,
default
=
''
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
42
)
parser
.
add_argument
(
'--seed'
,
type
=
int
,
default
=
42
)
parser
.
add_argument
(
'--num_train_epochs'
,
type
=
int
,
default
=
3
)
parser
.
add_argument
(
'--num_train_epochs'
,
type
=
int
,
default
=
3
)
parser
.
add_argument
(
'--train_batch_size'
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
'--train_batch_size'
,
type
=
int
,
default
=
8
)
...
@@ -97,7 +100,7 @@ def main():
...
@@ -97,7 +100,7 @@ def main():
parser
.
add_argument
(
'--warmup_proportion'
,
type
=
float
,
default
=
0.002
)
parser
.
add_argument
(
'--warmup_proportion'
,
type
=
float
,
default
=
0.002
)
parser
.
add_argument
(
'--lr_schedule'
,
type
=
str
,
default
=
'warmup_linear'
)
parser
.
add_argument
(
'--lr_schedule'
,
type
=
str
,
default
=
'warmup_linear'
)
parser
.
add_argument
(
'--weight_decay'
,
type
=
float
,
default
=
0.01
)
parser
.
add_argument
(
'--weight_decay'
,
type
=
float
,
default
=
0.01
)
parser
.
add_argument
(
'--lm_coef'
,
type
=
float
,
default
=
0.
5
)
parser
.
add_argument
(
'--lm_coef'
,
type
=
float
,
default
=
0.
9
)
parser
.
add_argument
(
'--n_valid'
,
type
=
int
,
default
=
374
)
parser
.
add_argument
(
'--n_valid'
,
type
=
int
,
default
=
374
)
parser
.
add_argument
(
'--server_ip'
,
type
=
str
,
default
=
''
,
help
=
"Can be used for distant debugging."
)
parser
.
add_argument
(
'--server_ip'
,
type
=
str
,
default
=
''
,
help
=
"Can be used for distant debugging."
)
...
@@ -137,6 +140,8 @@ def main():
...
@@ -137,6 +140,8 @@ def main():
model
.
to
(
device
)
model
.
to
(
device
)
# Load and encode the datasets
# Load and encode the datasets
if
not
args
.
train_dataset
and
not
args
.
eval_dataset
:
roc_stories
=
cached_path
(
ROCSTORIES_URL
)
def
tokenize_and_encode
(
obj
):
def
tokenize_and_encode
(
obj
):
""" Tokenize and encode a nested object """
""" Tokenize and encode a nested object """
if
isinstance
(
obj
,
str
):
if
isinstance
(
obj
,
str
):
...
@@ -144,7 +149,6 @@ def main():
...
@@ -144,7 +149,6 @@ def main():
elif
isinstance
(
obj
,
int
):
elif
isinstance
(
obj
,
int
):
return
obj
return
obj
return
list
(
tokenize_and_encode
(
o
)
for
o
in
obj
)
return
list
(
tokenize_and_encode
(
o
)
for
o
in
obj
)
logger
.
info
(
"Encoding dataset..."
)
logger
.
info
(
"Encoding dataset..."
)
train_dataset
=
load_rocstories_dataset
(
args
.
train_dataset
)
train_dataset
=
load_rocstories_dataset
(
args
.
train_dataset
)
eval_dataset
=
load_rocstories_dataset
(
args
.
eval_dataset
)
eval_dataset
=
load_rocstories_dataset
(
args
.
eval_dataset
)
...
@@ -152,13 +156,13 @@ def main():
...
@@ -152,13 +156,13 @@ def main():
encoded_datasets
=
tokenize_and_encode
(
datasets
)
encoded_datasets
=
tokenize_and_encode
(
datasets
)
# Compute the mex input length for the Transformer
# Compute the mex input length for the Transformer
input_length
=
max
(
len
(
story
)
+
max
(
len
(
cont1
),
len
(
cont2
))
+
3
\
max_length
=
model
.
config
.
n_positions
//
2
-
2
input_length
=
max
(
len
(
story
[:
max_length
])
+
max
(
len
(
cont1
[:
max_length
]),
len
(
cont2
[:
max_length
]))
+
3
\
for
dataset
in
encoded_datasets
for
story
,
cont1
,
cont2
,
_
in
dataset
)
for
dataset
in
encoded_datasets
for
story
,
cont1
,
cont2
,
_
in
dataset
)
input_length
=
min
(
input_length
,
model
.
config
.
n_positions
)
# Max size of input for the pre-trained model
input_length
=
min
(
input_length
,
model
.
config
.
n_positions
)
# Max size of input for the pre-trained model
max_sub_part_length
=
input_length
//
2
-
2
# Prepare inputs tensors and dataloaders
# Prepare inputs tensors and dataloaders
tensor_datasets
=
pre_process_datasets
(
encoded_datasets
,
input_length
,
max_
sub_part_
length
,
*
special_tokens_ids
)
tensor_datasets
=
pre_process_datasets
(
encoded_datasets
,
input_length
,
max_length
,
*
special_tokens_ids
)
train_tensor_dataset
,
eval_tensor_dataset
=
tensor_datasets
[
0
],
tensor_datasets
[
1
]
train_tensor_dataset
,
eval_tensor_dataset
=
tensor_datasets
[
0
],
tensor_datasets
[
1
]
train_data
=
TensorDataset
(
*
train_tensor_dataset
)
train_data
=
TensorDataset
(
*
train_tensor_dataset
)
...
@@ -176,7 +180,7 @@ def main():
...
@@ -176,7 +180,7 @@ def main():
{
'params'
:
[
p
for
n
,
p
in
param_optimizer
if
not
any
(
nd
in
n
for
nd
in
no_decay
)],
'weight_decay'
:
0.01
},
{
'params'
:
[
p
for
n
,
p
in
param_optimizer
if
not
any
(
nd
in
n
for
nd
in
no_decay
)],
'weight_decay'
:
0.01
},
{
'params'
:
[
p
for
n
,
p
in
param_optimizer
if
any
(
nd
in
n
for
nd
in
no_decay
)],
'weight_decay'
:
0.0
}
{
'params'
:
[
p
for
n
,
p
in
param_optimizer
if
any
(
nd
in
n
for
nd
in
no_decay
)],
'weight_decay'
:
0.0
}
]
]
num_train_optimization_steps
=
len
(
train_data
)
//
args
.
train_batch_size
num_train_optimization_steps
=
len
(
train_data
)
*
args
.
num_train_epochs
//
args
.
train_batch_size
optimizer
=
OpenAIAdam
(
optimizer_grouped_parameters
,
optimizer
=
OpenAIAdam
(
optimizer_grouped_parameters
,
lr
=
args
.
learning_rate
,
lr
=
args
.
learning_rate
,
warmup
=
args
.
warmup_proportion
,
warmup
=
args
.
warmup_proportion
,
...
@@ -185,12 +189,11 @@ def main():
...
@@ -185,12 +189,11 @@ def main():
t_total
=
num_train_optimization_steps
)
t_total
=
num_train_optimization_steps
)
if
args
.
do_train
:
if
args
.
do_train
:
nb_tr_steps
=
0
nb_tr_steps
,
tr_loss
,
exp_average_loss
=
0
,
0
,
None
tr_loss
=
0
model
.
train
()
model
.
train
()
for
_
in
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
):
for
_
in
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
):
tr_loss
=
0
tr_loss
=
0
nb_tr_examples
,
nb_tr_steps
=
0
,
0
nb_tr_steps
=
0
tqdm_bar
=
tqdm
(
train_dataloader
,
desc
=
"Training"
)
tqdm_bar
=
tqdm
(
train_dataloader
,
desc
=
"Training"
)
for
step
,
batch
in
enumerate
(
tqdm_bar
):
for
step
,
batch
in
enumerate
(
tqdm_bar
):
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
...
@@ -200,21 +203,22 @@ def main():
...
@@ -200,21 +203,22 @@ def main():
loss
.
backward
()
loss
.
backward
()
optimizer
.
step
()
optimizer
.
step
()
tr_loss
+=
loss
.
item
()
tr_loss
+=
loss
.
item
()
nb_tr_examples
+=
input_ids
.
size
(
0
)
exp_average_loss
=
loss
.
item
()
if
exp_average_loss
is
None
else
0.7
*
exp_average_loss
+
0.3
*
loss
.
item
(
)
nb_tr_steps
+=
1
nb_tr_steps
+=
1
tqdm_bar
.
desc
=
"Training loss: {:.2e}
"
.
format
(
tr_loss
/
nb_tr_steps
)
tqdm_bar
.
desc
=
"Training loss: {:.2e}
lr: {:.2e}"
.
format
(
exp_average_loss
,
optimizer
.
get_lr
()[
0
]
)
# Save a trained model
# Save a trained model
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
if
args
.
do_train
:
if
args
.
do_train
:
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
os
.
path
.
join
(
args
.
output_dir
,
"pytorch_model.bin"
)
config
=
model
.
config
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
torch
.
save
(
model_to_save
.
state_dict
(),
output_model_file
)
# Load a trained model that you have fine-tuned
# Load a trained model that you have fine-tuned
model_state_dict
=
torch
.
load
(
output_model_file
)
model_state_dict
=
torch
.
load
(
output_model_file
)
model
=
OpenAIGPTDoubleHeadsModel
.
from_pretrained
(
args
.
model_name
,
state_dict
=
model_state_dict
,
model
=
OpenAIGPTDoubleHeadsModel
(
config
)
num_special_tokens
=
len
(
special_tokens
)
)
model
.
load_state_dict
(
model_state_dict
)
model
.
to
(
device
)
model
.
to
(
device
)
if
args
.
do_eval
:
if
args
.
do_eval
:
model
.
eval
()
model
.
eval
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment