chenpangpang / transformers / Commits

Commit 0aaedcc0
authored Nov 27, 2018 by Li Li
parent 32167cdf

Bug fix in examples; correct t_total for distributed training; run prediction for full dataset
Showing 2 changed files with 42 additions and 22 deletions:

    examples/run_classifier.py    +11 -9
    examples/run_squad.py         +31 -13
examples/run_classifier.py
@@ -33,6 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForSequenceClassification
 from pytorch_pretrained_bert.optimization import BertAdam
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',
@@ -155,8 +156,8 @@ class MnliProcessor(DataProcessor):
             if i == 0:
                 continue
             guid = "%s-%s" % (set_type, line[0])
-            text_a = line[8])
-            text_b = line[9])
+            text_a = line[8]
+            text_b = line[9]
             label = line[-1]
             examples.append(
                 InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
@@ -482,7 +483,7 @@ def main():
             len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
-    model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
+    model = BertForSequenceClassification.from_pretrained(args.bert_model,
               cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
     if args.fp16:
         model.half()
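Note on the hunk above: the cache_dir expression only works if PYTORCH_PRETRAINED_BERT_CACHE supports the / operator, i.e. behaves like a pathlib.Path, and the newly added file_utils import is what puts that name in scope (its previous absence is presumably the NameError this "bug fix" targets). A minimal sketch of the per-rank cache layout under those assumptions (the root path below is illustrative, not the library's actual default):

    from pathlib import Path

    cache_root = Path.home() / ".pytorch_pretrained_bert"  # illustrative root
    local_rank = 0
    cache_dir = cache_root / 'distributed_{}'.format(local_rank)
    # -> ~/.pytorch_pretrained_bert/distributed_0
    # One cache directory per rank, so concurrent processes in a distributed
    # job do not clobber each other's downloaded weights.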
@@ -507,10 +508,13 @@ def main():
         {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
         ]
+    t_total = num_train_steps
+    if args.local_rank != -1:
+        t_total = t_total // torch.distributed.get_world_size()
     optimizer = BertAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          warmup=args.warmup_proportion,
-                         t_total=num_train_steps)
+                         t_total=t_total)
     global_step = 0
     if args.do_train:
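Why the t_total correction matters: under DistributedSampler each process sees only 1/world_size of the training batches, so the number of optimizer steps a single rank performs is the single-process count divided by the world size, and BertAdam's warmup/decay schedule (t_total) has to match the per-rank count or the schedule never completes. An illustrative sketch, not part of the commit (all names hypothetical):

    def per_process_steps(num_examples, batch_size, grad_accum_steps, epochs, world_size=1):
        # steps one rank actually performs when data is sharded across ranks
        steps = num_examples // batch_size // grad_accum_steps * epochs
        return steps // world_size

    # 100000 examples, batch 32, no accumulation, 3 epochs:
    #   1 process -> 9375 steps; 4 GPUs -> 2343 steps per rank.
    # Passing the single-process 9375 as t_total on 4 GPUs would leave the
    # learning-rate schedule only about a quarter finished at the end of training.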
@@ -571,7 +575,7 @@ def main():
         model.zero_grad()
         global_step += 1

-    if args.do_eval:
+    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         eval_examples = processor.get_dev_examples(args.data_dir)
         eval_features = convert_examples_to_features(
             eval_examples, label_list, args.max_seq_length, tokenizer)
@@ -583,10 +587,8 @@ def main():
         all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
         all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
         eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
-        if args.local_rank == -1:
-            eval_sampler = SequentialSampler(eval_data)
-        else:
-            eval_sampler = DistributedSampler(eval_data)
+        # Run prediction for full data
+        eval_sampler = SequentialSampler(eval_data)
         eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

         model.eval()
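The two evaluation changes in this file work together: dropping the DistributedSampler branch means the dev set is scored in full, in order, rather than in per-rank shards (hence "run prediction for full dataset" in the commit message), and the new rank guard keeps an N-process job from repeating that full pass N times. A hedged sketch of the guard, assuming torch.distributed is initialized whenever local_rank != -1:

    import torch

    def is_eval_rank(local_rank):
        # single-process mode, or rank 0 of the distributed group
        return local_rank == -1 or torch.distributed.get_rank() == 0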
examples/run_squad.py
@@ -25,6 +25,7 @@ import json
 import math
 import os
 import random
+import pickle
 from tqdm import tqdm, trange

 import numpy as np
@@ -35,6 +36,7 @@ from torch.utils.data.distributed import DistributedSampler
 from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
 from pytorch_pretrained_bert.optimization import BertAdam
+from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt='%m/%d/%Y %H:%M:%S',
@@ -749,6 +751,10 @@ def main():
                         type=int,
                         default=1,
                         help="Number of updates steps to accumulate before performing a backward/update pass.")
+    parser.add_argument("--do_lower_case",
+                        default=True,
+                        action='store_true',
+                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
     parser.add_argument("--local_rank",
                         type=int,
                         default=-1,
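A side note on the new flag (an observation, not something the commit addresses): combining default=True with action='store_true' means args.do_lower_case is True whether or not --do_lower_case is passed, so cased models cannot actually switch it off from the command line. A standalone sketch of the behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--do_lower_case", default=True, action='store_true')
    print(parser.parse_args([]).do_lower_case)                   # True (the default wins)
    print(parser.parse_args(["--do_lower_case"]).do_lower_case)  # True
    # Dropping default=True restores a real toggle: absent -> False, present -> True.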
@@ -845,13 +851,23 @@ def main():
         {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
         ]
+    t_total = num_train_steps
+    if args.local_rank != -1:
+        t_total = t_total // torch.distributed.get_world_size()
     optimizer = BertAdam(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          warmup=args.warmup_proportion,
-                         t_total=num_train_steps)
+                         t_total=t_total)
     global_step = 0
     if args.do_train:
-        train_features = convert_examples_to_features(
-            examples=train_examples,
-            tokenizer=tokenizer,
+        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
+            args.bert_model, str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
+        train_features = None
+        try:
+            with open(cached_train_features_file, "rb") as reader:
+                train_features = pickle.load(reader)
+        except:
+            train_features = convert_examples_to_features(
+                examples=train_examples,
+                tokenizer=tokenizer,
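The try/except above is a load-or-rebuild cache: the expensive SQuAD feature conversion is skipped whenever a pickle keyed on the model name, max sequence length, doc stride, and max query length already exists. A hedged sketch of the same pattern (function and argument names are illustrative; it tests for the file explicitly instead of relying on a bare except):

    import os
    import pickle

    def load_or_build(cache_path, build_fn):
        # Reuse cached features when present; otherwise rebuild them.
        if os.path.exists(cache_path):
            with open(cache_path, "rb") as reader:
                return pickle.load(reader)
        return build_fn()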
@@ -859,6 +875,10 @@ def main():
...
@@ -859,6 +875,10 @@ def main():
doc_stride
=
args
.
doc_stride
,
doc_stride
=
args
.
doc_stride
,
max_query_length
=
args
.
max_query_length
,
max_query_length
=
args
.
max_query_length
,
is_training
=
True
)
is_training
=
True
)
if
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
:
logger
.
info
(
" Saving train features into cached file %s"
,
cached_train_features_file
)
with
open
(
cached_train_features_file
,
"wb"
)
as
writer
:
train_features
=
pickle
.
dump
(
train_features
,
writer
)
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
" Num orig examples = %d"
,
len
(
train_examples
))
logger
.
info
(
" Num orig examples = %d"
,
len
(
train_examples
))
logger
.
info
(
" Num split examples = %d"
,
len
(
train_features
))
logger
.
info
(
" Num split examples = %d"
,
len
(
train_features
))
...
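One caveat about the save block above, offered as an observation rather than as part of the commit: guarding on rank 0 avoids N processes racing to write the same cache file, but pickle.dump returns None, so assigning its result rebinds train_features right before the "Num split examples" log line calls len() on it. A sketch of the presumable intent:

    with open(cached_train_features_file, "wb") as writer:
        pickle.dump(train_features, writer)  # save without rebinding train_features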
@@ -913,7 +933,7 @@ def main():
...
@@ -913,7 +933,7 @@ def main():
model
.
zero_grad
()
model
.
zero_grad
()
global_step
+=
1
global_step
+=
1
if
args
.
do_predict
:
if
args
.
do_predict
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
)
:
eval_examples
=
read_squad_examples
(
eval_examples
=
read_squad_examples
(
input_file
=
args
.
predict_file
,
is_training
=
False
)
input_file
=
args
.
predict_file
,
is_training
=
False
)
eval_features
=
convert_examples_to_features
(
eval_features
=
convert_examples_to_features
(
...
@@ -934,10 +954,8 @@ def main():
...
@@ -934,10 +954,8 @@ def main():
all_segment_ids
=
torch
.
tensor
([
f
.
segment_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_segment_ids
=
torch
.
tensor
([
f
.
segment_ids
for
f
in
eval_features
],
dtype
=
torch
.
long
)
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_example_index
)
eval_data
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_example_index
)
if
args
.
local_rank
==
-
1
:
# Run prediction for full data
eval_sampler
=
SequentialSampler
(
eval_data
)
eval_sampler
=
SequentialSampler
(
eval_data
)
else
:
eval_sampler
=
DistributedSampler
(
eval_data
)
eval_dataloader
=
DataLoader
(
eval_data
,
sampler
=
eval_sampler
,
batch_size
=
args
.
predict_batch_size
)
eval_dataloader
=
DataLoader
(
eval_data
,
sampler
=
eval_sampler
,
batch_size
=
args
.
predict_batch_size
)
model
.
eval
()
model
.
eval
()
...
...