ModelZoo / ResNet50_tensorflow
"vscode:/vscode.git/clone" did not exist on "8a18e73e18cb6bb846b5d5a11e9a7ff91caedda8"
Commit 9b219f04, authored Jun 04, 2020 by Tianqi Liu, committed by A. Unique TensorFlower on Jun 04, 2020.

Internal cleanup.

PiperOrigin-RevId: 314822016

Parent: 22f76623
Changes: 3 changed files with 30 additions and 15 deletions (+30, -15).
official/nlp/bert/run_classifier.py (+24, -10)
official/nlp/data/classifier_data_lib.py (+4, -4)
official/nlp/data/create_finetuning_data.py (+2, -1)
official/nlp/bert/run_classifier.py
```diff
@@ -54,6 +54,12 @@ flags.DEFINE_string(
     'to be used for training and evaluation.')
 flags.DEFINE_string('predict_checkpoint_path', None,
                     'Path to the checkpoint for predictions.')
+flags.DEFINE_integer(
+    'num_eval_per_epoch', 1,
+    'Number of evaluations per epoch. The purpose of this flag is to provide '
+    'more granular evaluation scores and checkpoints. For example, if original '
+    'data has N samples and num_eval_per_epoch is n, then each epoch will be '
+    'evaluated every N/n samples.')
 flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
```
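The new `num_eval_per_epoch` flag is an ordinary absl flag like its neighbors. A minimal sketch of how such a flag is defined and read; run_classifier.py itself parses flags via absl's `app.run(main)`, so the manual `FLAGS(argv)` call below is only for illustration:

```python
# Minimal sketch of defining and reading an absl flag like the one added
# above; the manual FLAGS(argv) parse is for illustration only, since the
# real script parses flags via absl's app.run(main).
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer('num_eval_per_epoch', 1,
                     'Number of evaluations per epoch.')

FLAGS(['prog', '--num_eval_per_epoch=4'])  # parse a fake argv
print(FLAGS.num_eval_per_epoch)  # -> 4
```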
```diff
@@ -92,8 +98,11 @@ def get_regression_loss_fn():
   return regression_loss_fn
 
 
-def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
-                   is_training, label_type=tf.int64):
+def get_dataset_fn(input_file_pattern,
+                   max_seq_length,
+                   global_batch_size,
+                   is_training,
+                   label_type=tf.int64):
   """Gets a closure to create a dataset."""
 
   def _dataset_fn(ctx=None):
```
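This hunk only reflows the signature, but the shape of `get_dataset_fn` is worth noting: it returns a closure that takes an optional `tf.distribute.InputContext`, the form `tf.distribute` expects when it builds one input pipeline per worker. A minimal sketch of the pattern, with illustrative names rather than the model-garden implementation:

```python
# A minimal sketch (illustrative, not the model-garden implementation) of
# the closure pattern: the outer function captures configuration and
# returns a callable that tf.distribute can invoke once per worker.
import tensorflow as tf

def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
                   is_training):
  """Gets a closure to create a dataset (label parsing omitted)."""

  def _dataset_fn(ctx=None):
    # ctx is an optional tf.distribute.InputContext; a full pipeline would
    # use it to shard files across workers. max_seq_length would be used
    # when decoding the serialized examples.
    ds = tf.data.TFRecordDataset(tf.io.gfile.glob(input_file_pattern))
    if is_training:
      ds = ds.shuffle(1000).repeat()
    return ds.batch(global_batch_size, drop_remainder=is_training)

  return _dataset_fn

# Typical consumption (TF 2.2-era API):
# strategy.experimental_distribute_datasets_from_function(
#     get_dataset_fn('train-*.tfrecord', 128, 32, is_training=True))
```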
```diff
@@ -151,17 +160,21 @@ def run_bert_classifier(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model
 
-  loss_fn = (get_regression_loss_fn() if is_regression
-             else get_loss_fn(num_classes))
+  loss_fn = (
+      get_regression_loss_fn() if is_regression else get_loss_fn(num_classes))
 
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
   if is_regression:
-    metric_fn = functools.partial(tf.keras.metrics.MeanSquaredError,
-                                  'mean_squared_error', dtype=tf.float32)
+    metric_fn = functools.partial(
+        tf.keras.metrics.MeanSquaredError,
+        'mean_squared_error',
+        dtype=tf.float32)
   else:
-    metric_fn = functools.partial(tf.keras.metrics.SparseCategoricalAccuracy,
-                                  'accuracy', dtype=tf.float32)
+    metric_fn = functools.partial(
+        tf.keras.metrics.SparseCategoricalAccuracy,
+        'accuracy',
+        dtype=tf.float32)
 
   # Start training using Keras compile/fit API.
   logging.info('Training using TF 2.x Keras compile/fit API with '
```
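The reformatted `metric_fn` lines also show why `functools.partial` is used at all: as the in-diff comment says, Keras metrics must be instantiated in the correct device and strategy scope, so the code passes a zero-argument factory rather than a metric instance. A small self-contained check of that pattern:

```python
import functools
import tensorflow as tf

# Bind the constructor arguments now; create the metric object later,
# e.g. inside a strategy scope.
metric_fn = functools.partial(
    tf.keras.metrics.SparseCategoricalAccuracy,
    'accuracy',
    dtype=tf.float32)

m = metric_fn()  # each call yields a fresh, independent metric
m.update_state([1, 2], [[0.1, 0.9, 0.0], [0.0, 0.2, 0.8]])
print(m.name, float(m.result()))  # accuracy 1.0
```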
```diff
@@ -349,8 +362,9 @@ def run_bert(strategy,
   keras_utils.set_session_config(FLAGS.enable_xla)
   performance.set_mixed_precision_policy(common_flags.dtype())
 
-  epochs = FLAGS.num_train_epochs
-  train_data_size = input_meta_data['train_data_size']
+  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
+  train_data_size = (
+      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
   steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
   warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
   eval_steps = int(
```
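This rescaling is what makes `num_eval_per_epoch` work: each "epoch" shrinks to N/n samples while the epoch count grows by the same factor n, so total optimization steps stay essentially unchanged and `fit` simply evaluates and checkpoints n times as often. A quick check with assumed numbers:

```python
# Worked check of the rescaling above, with assumed values: N = 100,000
# training samples, num_train_epochs = 3, train_batch_size = 32, and
# num_eval_per_epoch n = 4. Mirrors the arithmetic in run_bert.
N, n, num_train_epochs, batch_size = 100_000, 4, 3, 32

epochs = num_train_epochs * n                        # 12 "mini-epochs"
train_data_size = N // n                             # 25,000 samples each
steps_per_epoch = int(train_data_size / batch_size)  # 781

total_steps = epochs * steps_per_epoch                   # 9,372
baseline_steps = num_train_epochs * int(N / batch_size)  # 9,375 without the flag

# warmup_steps is unaffected, since epochs * train_data_size still equals
# num_train_epochs * N (up to integer division):
warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)  # 937
```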
official/nlp/data/classifier_data_lib.py
```diff
@@ -127,15 +127,14 @@ class XnliProcessor(DataProcessor):
     """See base class."""
     lines = []
     for language in self.languages:
+      # Skips the header.
       lines.extend(
           self._read_tsv(
               os.path.join(data_dir, "multinli",
-                           "multinli.train.%s.tsv" % language)))
+                           "multinli.train.%s.tsv" % language))[1:])
 
     examples = []
     for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[0])
       text_b = self.process_text_fn(line[1])
```
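The XnliProcessor change is more than cosmetic: the old code concatenated every language's rows and then skipped only index 0, so only the first file's header was dropped. Slicing each file with `[1:]` drops every header. A toy illustration with fake rows, not the real TSV reader:

```python
# Toy data standing in for per-language TSV files (first row is a header).
files = {
    'en': [['header'], ['en-row-1'], ['en-row-2']],
    'de': [['header'], ['de-row-1']],
}

# Old behavior: one global skip of index 0 only.
lines_old = []
for language in files:
  lines_old.extend(files[language])
rows_old = [line for i, line in enumerate(lines_old) if i != 0]
# -> [['en-row-1'], ['en-row-2'], ['header'], ['de-row-1']]  (stray header!)

# New behavior: skip the header of each file before concatenating.
lines_new = []
for language in files:
  lines_new.extend(files[language][1:])
# -> [['en-row-1'], ['en-row-2'], ['de-row-1']]
```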
```diff
@@ -825,7 +824,8 @@ def generate_tf_record_from_data_file(processor,
     eval_data_output_path: Output to which processed tf record for evaluation
       will be saved.
     test_data_output_path: Output to which processed tf record for testing
-      will be saved. Must be a pattern template with {} if processor is XNLI.
+      will be saved. Must be a pattern template with {} if processor has
+      language specific test data.
     max_seq_length: Maximum sequence length of the to be generated
       training/eval data.
```
official/nlp/data/create_finetuning_data.py
```diff
@@ -99,7 +99,8 @@ flags.DEFINE_string(
 flags.DEFINE_string(
     "test_data_output_path", None,
     "The path in which generated test input data will be written as tf"
-    " records. If None, do not generate test data.")
+    " records. If None, do not generate test data. Must be a pattern template"
+    " as test_{}.tfrecords if processor has language specific test data.")
 
 flags.DEFINE_string("meta_data_file_path", None,
                     "The path in which input meta data will be written.")
```
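Both docstring edits describe the same contract: when a processor ships per-language test sets, `test_data_output_path` must contain a `{}` placeholder that the writer fills in per language with `str.format`. A sketch with made-up paths and languages:

```python
# Illustrative only: the path and language list are assumptions, but the
# '{}' placeholder mechanics are plain str.format.
test_data_output_path = '/tmp/xnli/test_{}.tfrecords'

for language in ('ar', 'de', 'en'):
  print(test_data_output_path.format(language))
# /tmp/xnli/test_ar.tfrecords
# /tmp/xnli/test_de.tfrecords
# /tmp/xnli/test_en.tfrecords
```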