chenpangpang / transformers · Commit 90d360a7
Authored Nov 01, 2018 by VictorSanh
WIP
Parent: f8e7c95d

Showing 1 changed file with 150 additions and 20 deletions

run_classifier_pytorch.py (+150, -20)
@@ -115,7 +115,11 @@ parser.add_argument("--iterations_per_loop",
                     default=1000,
                     type=int,
                     help="How many steps to make in each estimator call.")
+parser.add_argument("--use_gpu",
+                    default=True,
+                    type=bool,
+                    help="Whether to use GPU")
 ### BEGIN - TO DELETE EVENTUALLY --> NO SENSE IN PYTORCH ###
 parser.add_argument("--use_tpu",
                     default=False,
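Side note, not part of the commit: argparse does not convert the strings "True"/"False" when type=bool is used; any non-empty string, including "False", is truthy, so passing --use_gpu False on the command line would still leave the flag enabled. A minimal sketch of one common workaround, using a hypothetical str2bool converter:

import argparse

def str2bool(v):
    # argparse passes the raw command-line string; map common truthy spellings to True.
    return str(v).lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", default=True, type=str2bool,
                    help="Whether to use GPU")

print(parser.parse_args(["--use_gpu", "False"]).use_gpu)  # -> False
print(parser.parse_args([]).use_gpu)                      # -> True (the non-string default is left untouched)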
@@ -416,25 +420,18 @@ def input_fn_builder(features, seq_length, is_training, drop_remainder):
         batch_size = params["batch_size"]

         num_examples = len(features)

-        # This is for demo purposes and does NOT scale to large data sets. We do
-        # not use Dataset.from_generator() because that uses tf.py_func which is
-        # not TPU compatible. The right way to load data is with TFRecordReader.
-        d = tf.data.Dataset.from_tensor_slices({
-            "input_ids":
-                torch.Tensor(all_input_ids, size=[num_examples, seq_length],
-                             dtype=torch.int32, requires_grad=False),
-            "input_mask":
-                torch.Tensor(all_input_mask, size=[num_examples, seq_length],
-                             dtype=torch.int32, requires_grad=False),
-            "segment_ids":
-                torch.Tensor(all_segment_ids, size=[num_examples, seq_length],
-                             dtype=torch.int32, requires_grad=False),
-            "label_ids":
-                torch.Tensor(all_label_ids, size=[num_examples],
-                             dtype=torch.int32, requires_grad=False)
-        })
+        device = torch.device("cuda") if args.use_gpu else torch.device("cpu")
+        d = {"input_ids":
+                torch.IntTensor(all_input_ids, device=device), #Requires_grad=False by default
+             "input_mask":
+                torch.IntTensor(all_input_mask, device=device),
+             "segment_ids":
+                torch.IntTensor(all_segment_ids, device=device),
+             "label_ids":
+                torch.IntTensor(all_label_ids, device=device)
+            }

         if is_training:
             d = d.repeat()
             d = d.shuffle(buffer_size=100)
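Not part of the commit, added only for comparison: a minimal sketch of how the same feature tensors are usually wired up in plain PyTorch, with torch.tensor plus TensorDataset/DataLoader standing in for the tf.data pipeline. The all_* lists below are toy placeholders for the values produced by convert_examples_to_features.

import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy placeholders for the per-example features built elsewhere in the script.
all_input_ids = [[101, 2023, 102], [101, 2008, 102]]
all_input_mask = [[1, 1, 1], [1, 1, 1]]
all_segment_ids = [[0, 0, 0], [0, 0, 0]]
all_label_ids = [0, 1]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.tensor infers the shape from the nested lists; long is the dtype
# embedding layers expect for ids. Tensors stay on CPU and are moved per batch.
dataset = TensorDataset(
    torch.tensor(all_input_ids, dtype=torch.long),
    torch.tensor(all_input_mask, dtype=torch.long),
    torch.tensor(all_segment_ids, dtype=torch.long),
    torch.tensor(all_label_ids, dtype=torch.long),
)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for input_ids, input_mask, segment_ids, label_ids in loader:
    input_ids = input_ids.to(device)
    # ... forward pass would go here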
@@ -443,3 +440,136 @@ def input_fn_builder(features, seq_length, is_training, drop_remainder):
         return d

     return input_fn
+
+
+def main(_):
+    processors = {
+        "cola": ColaProcessor,
+        "mnli": MnliProcessor,
+        "mrpc": MrpcProcessor,
+    }
+
+    if not args.do_train and not args.do_eval:
+        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
+
+    bert_config = modeling.BertConfig.from_json_file(args.bert_config_file)
+
+    if args.max_seq_length > bert_config.max_position_embeddings:
+        raise ValueError(
+            "Cannot use sequence length %d because the BERT model "
+            "was only trained up to sequence length %d" %
+            (args.max_seq_length, bert_config.max_position_embeddings))
+
+    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
+        raise ConfigurationError(f"Output directory ({args.output_dir}) already exists and is "
+                                 f"not empty.")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    task_name = args.task_name.lower()
+
+    if task_name not in processors:
+        raise ValueError("Task not found: %s" % (task_name))
+
+    processor = processors[task_name]()
+
+    label_list = processor.get_labels()
+
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
+
+    # tpu_cluster_resolver = None
+    # if FLAGS.use_tpu and FLAGS.tpu_name:
+    #     tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
+    #         FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
+    # is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
+    # run_config = tf.contrib.tpu.RunConfig(
+    #     cluster=tpu_cluster_resolver,
+    #     master=FLAGS.master,
+    #     model_dir=FLAGS.output_dir,
+    #     save_checkpoints_steps=FLAGS.save_checkpoints_steps,
+    #     tpu_config=tf.contrib.tpu.TPUConfig(
+    #         iterations_per_loop=FLAGS.iterations_per_loop,
+    #         num_shards=FLAGS.num_tpu_cores,
+    #         per_host_input_for_training=is_per_host))
+
+    train_examples = None
+    num_train_steps = None
+    num_warmup_steps = None
+    if args.do_train:
+        train_examples = processor.get_train_examples(args.data_dir)
+        num_train_steps = int(
+            len(train_examples) / args.train_batch_size * args.num_train_epochs)
+        num_warmup_steps = int(num_train_steps * args.warmup_proportion)
+
+    model_fn = model_fn_builder(
+        bert_config=bert_config,
+        num_labels=len(label_list),
+        init_checkpoint=args.init_checkpoint,
+        learning_rate=args.learning_rate,
+        num_train_steps=num_train_steps,
+        num_warmup_steps=num_warmup_steps,
+        use_gpu=args.use_gpu,
+        use_one_hot_embeddings=args.use_gpu)
+    ### TO DO - to check when model_fn is written)
+
+    # If TPU is not available, this will fall back to normal Estimator on CPU
+    # or GPU. - TO DO
+    for batch in
+    estimator = tf.contrib.tpu.TPUEstimator(
+        use_tpu=args.use_tpu,
+        model_fn=model_fn,
+        config=run_config,
+        train_batch_size=args.train_batch_size,
+        eval_batch_size=args.eval_batch_size)
+
+    if args.do_train:
+        train_features = convert_examples_to_features(
+            train_examples, label_list, args.max_seq_length, tokenizer)
+        logger.info("***** Running training *****")
+        logger.info("  Num examples = %d", len(train_examples))
+        logger.info("  Batch size = %d", args.train_batch_size)
+        logger.info("  Num steps = %d", num_train_steps)
+        train_input_fn = input_fn_builder(
+            features=train_features,
+            seq_length=args.max_seq_length,
+            is_training=True,
+            drop_remainder=True)
+        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
+
+    if args.do_eval:
+        eval_examples = processor.get_dev_examples(args.data_dir)
+        eval_features = convert_examples_to_features(
+            eval_examples, label_list, args.max_seq_length, tokenizer)
+
+        tf.logging.info("***** Running evaluation *****")
+        tf.logging.info("  Num examples = %d", len(eval_examples))
+        tf.logging.info("  Batch size = %d", args.eval_batch_size)
+
+        # This tells the estimator to run through the entire set.
+        eval_steps = None
+        # However, if running eval on the TPU, you will need to specify the
+        # number of steps.
+        if args.use_tpu:
+            # Eval will be slightly WRONG on the TPU because it will truncate
+            # the last batch.
+            eval_steps = int(len(eval_examples) / args.eval_batch_size)
+
+        eval_drop_remainder = True if args.use_tpu else False
+        eval_input_fn = input_fn_builder(
+            features=eval_features,
+            seq_length=args.max_seq_length,
+            is_training=False,
+            drop_remainder=eval_drop_remainder)
+
+        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
+
+        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
+        with tf.gfile.GFile(output_eval_file, "w") as writer:
+            tf.logging.info("***** Eval results *****")
+            for key in sorted(result.keys()):
+                tf.logging.info("  %s = %s", key, str(result[key]))
+                writer.write("%s = %s\n" % (key, str(result[key])))
+
+
+if __name__ == "__main__":
+    main()
+    return None
\ No newline at end of file
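Not part of the commit: the dangling "for batch in" fragment and the remaining tf.contrib.tpu.TPUEstimator / tf.logging / tf.gfile calls are leftovers of the TensorFlow version that this WIP has not ported yet. Purely as an illustration of where that part seems headed, and not the author's code, a plain PyTorch stand-in for the estimator.train(...) call could look like the sketch below; model, loader, and the optimizer choice are hypothetical placeholders.

import torch

def train(model, loader, device, num_train_steps, learning_rate=5e-5):
    # Plain PyTorch training loop standing in for estimator.train(...).
    # Assumes a hypothetical model whose forward pass returns the loss
    # when label ids are passed in.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    model.train()
    step = 0
    while step < num_train_steps:
        for input_ids, input_mask, segment_ids, label_ids in loader:
            optimizer.zero_grad()
            loss = model(input_ids.to(device), input_mask.to(device),
                         segment_ids.to(device), label_ids.to(device))
            loss.backward()
            optimizer.step()
            step += 1
            if step >= num_train_steps:
                break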