Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a76237da
Commit
a76237da
authored
Nov 12, 2019
by
Rajagopal Ananthanarayanan
Committed by
A. Unique TensorFlower
Nov 12, 2019
Browse files
Internal change
PiperOrigin-RevId: 280019807
parent
95dc9045
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
23 deletions
+51
-23
official/benchmark/bert_benchmark.py
official/benchmark/bert_benchmark.py
+17
-0
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+34
-23
No files found.
official/benchmark/bert_benchmark.py
View file @
a76237da
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
functools
import
json
import
json
import
math
import
math
import
os
import
os
...
@@ -31,6 +32,7 @@ import tensorflow as tf
...
@@ -31,6 +32,7 @@ import tensorflow as tf
from
official.benchmark
import
bert_benchmark_utils
as
benchmark_utils
from
official.benchmark
import
bert_benchmark_utils
as
benchmark_utils
from
official.nlp
import
bert_modeling
as
modeling
from
official.nlp
import
bert_modeling
as
modeling
from
official.nlp.bert
import
input_pipeline
from
official.nlp.bert
import
run_classifier
from
official.nlp.bert
import
run_classifier
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
...
@@ -76,6 +78,19 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
...
@@ -76,6 +78,19 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
steps_per_loop
=
1
steps_per_loop
=
1
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
train_input_fn
=
functools
.
partial
(
input_pipeline
.
create_classifier_dataset
,
FLAGS
.
train_data_path
,
seq_length
=
max_seq_length
,
batch_size
=
FLAGS
.
train_batch_size
)
eval_input_fn
=
functools
.
partial
(
input_pipeline
.
create_classifier_dataset
,
FLAGS
.
eval_data_path
,
seq_length
=
max_seq_length
,
batch_size
=
FLAGS
.
eval_batch_size
,
is_training
=
False
,
drop_remainder
=
False
)
run_classifier
.
run_bert_classifier
(
run_classifier
.
run_bert_classifier
(
strategy
,
strategy
,
bert_config
,
bert_config
,
...
@@ -88,6 +103,8 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
...
@@ -88,6 +103,8 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
warmup_steps
,
warmup_steps
,
FLAGS
.
learning_rate
,
FLAGS
.
learning_rate
,
FLAGS
.
init_checkpoint
,
FLAGS
.
init_checkpoint
,
train_input_fn
,
eval_input_fn
,
custom_callbacks
=
callbacks
)
custom_callbacks
=
callbacks
)
...
...
official/nlp/bert/run_classifier.py
View file @
a76237da
...
@@ -91,6 +91,8 @@ def run_bert_classifier(strategy,
...
@@ -91,6 +91,8 @@ def run_bert_classifier(strategy,
warmup_steps
,
warmup_steps
,
initial_lr
,
initial_lr
,
init_checkpoint
,
init_checkpoint
,
train_input_fn
,
eval_input_fn
,
custom_callbacks
=
None
,
custom_callbacks
=
None
,
run_eagerly
=
False
,
run_eagerly
=
False
,
use_keras_compile_fit
=
False
):
use_keras_compile_fit
=
False
):
...
@@ -98,19 +100,6 @@ def run_bert_classifier(strategy,
...
@@ -98,19 +100,6 @@ def run_bert_classifier(strategy,
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
num_classes
=
input_meta_data
[
'num_labels'
]
num_classes
=
input_meta_data
[
'num_labels'
]
train_input_fn
=
functools
.
partial
(
input_pipeline
.
create_classifier_dataset
,
FLAGS
.
train_data_path
,
seq_length
=
max_seq_length
,
batch_size
=
FLAGS
.
train_batch_size
)
eval_input_fn
=
functools
.
partial
(
input_pipeline
.
create_classifier_dataset
,
FLAGS
.
eval_data_path
,
seq_length
=
max_seq_length
,
batch_size
=
FLAGS
.
eval_batch_size
,
is_training
=
False
,
drop_remainder
=
False
)
def
_get_classifier_model
():
def
_get_classifier_model
():
"""Gets a classifier model."""
"""Gets a classifier model."""
classifier_model
,
core_model
=
(
classifier_model
,
core_model
=
(
...
@@ -153,7 +142,7 @@ def run_bert_classifier(strategy,
...
@@ -153,7 +142,7 @@ def run_bert_classifier(strategy,
if
use_keras_compile_fit
:
if
use_keras_compile_fit
:
# Start training using Keras compile/fit API.
# Start training using Keras compile/fit API.
logging
.
info
(
'Training using TF 2.0 Keras compile/fit API with '
logging
.
info
(
'Training using TF 2.0 Keras compile/fit API with '
'distr
u
but
ed
strategy.'
)
'distr
i
but
ion
strategy.'
)
return
run_keras_compile_fit
(
return
run_keras_compile_fit
(
model_dir
,
model_dir
,
strategy
,
strategy
,
...
@@ -170,7 +159,7 @@ def run_bert_classifier(strategy,
...
@@ -170,7 +159,7 @@ def run_bert_classifier(strategy,
# Use user-defined loop to start training.
# Use user-defined loop to start training.
logging
.
info
(
'Training using customized training loop TF 2.0 with '
logging
.
info
(
'Training using customized training loop TF 2.0 with '
'distr
u
but
ed
strategy.'
)
'distr
i
but
ion
strategy.'
)
return
model_training_utils
.
run_customized_training_loop
(
return
model_training_utils
.
run_customized_training_loop
(
strategy
=
strategy
,
strategy
=
strategy
,
model_fn
=
_get_classifier_model
,
model_fn
=
_get_classifier_model
,
...
@@ -237,7 +226,8 @@ def run_keras_compile_fit(model_dir,
...
@@ -237,7 +226,8 @@ def run_keras_compile_fit(model_dir,
def
export_classifier
(
model_export_path
,
input_meta_data
,
def
export_classifier
(
model_export_path
,
input_meta_data
,
restore_model_using_load_weights
):
restore_model_using_load_weights
,
bert_config
,
model_dir
):
"""Exports a trained model as a `SavedModel` for inference.
"""Exports a trained model as a `SavedModel` for inference.
Args:
Args:
...
@@ -249,15 +239,19 @@ def export_classifier(model_export_path, input_meta_data,
...
@@ -249,15 +239,19 @@ def export_classifier(model_export_path, input_meta_data,
tf.train.Checkpoint and another is using Keras model.save_weights().
tf.train.Checkpoint and another is using Keras model.save_weights().
Custom training loop implementation uses tf.train.Checkpoint API
Custom training loop implementation uses tf.train.Checkpoint API
and Keras ModelCheckpoint callback internally uses model.save_weights()
and Keras ModelCheckpoint callback internally uses model.save_weights()
API. Since these two API's cannot be used tog
h
ether, model loading logic
API. Since these two API's cannot be used together, model loading logic
must be take into account how model checkpoint was saved.
must be take into account how model checkpoint was saved.
bert_config: Bert configuration file to define core bert layers.
model_dir: The directory where the model weights and training/evaluation
summaries are stored.
Raises:
Raises:
Export path is not specified, got an empty string or None.
Export path is not specified, got an empty string or None.
"""
"""
if
not
model_export_path
:
if
not
model_export_path
:
raise
ValueError
(
'Export path is not specified: %s'
%
model_export_path
)
raise
ValueError
(
'Export path is not specified: %s'
%
model_export_path
)
bert_config
=
modeling
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
if
not
model_dir
:
raise
ValueError
(
'Export path is not specified: %s'
%
model_dir
)
classifier_model
=
bert_models
.
classifier_model
(
classifier_model
=
bert_models
.
classifier_model
(
bert_config
,
tf
.
float32
,
input_meta_data
[
'num_labels'
],
bert_config
,
tf
.
float32
,
input_meta_data
[
'num_labels'
],
...
@@ -266,18 +260,20 @@ def export_classifier(model_export_path, input_meta_data,
...
@@ -266,18 +260,20 @@ def export_classifier(model_export_path, input_meta_data,
model_saving_utils
.
export_bert_model
(
model_saving_utils
.
export_bert_model
(
model_export_path
,
model_export_path
,
model
=
classifier_model
,
model
=
classifier_model
,
checkpoint_dir
=
FLAGS
.
model_dir
,
checkpoint_dir
=
model_dir
,
restore_model_using_load_weights
=
restore_model_using_load_weights
)
restore_model_using_load_weights
=
restore_model_using_load_weights
)
def
run_bert
(
strategy
,
input_meta_data
):
def
run_bert
(
strategy
,
input_meta_data
,
train_input_fn
,
eval_input_fn
):
"""Run BERT training."""
"""Run BERT training."""
bert_config
=
modeling
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
if
FLAGS
.
mode
==
'export_only'
:
if
FLAGS
.
mode
==
'export_only'
:
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# As Keras ModelCheckpoint callback used with Keras compile/fit() API
# internally uses model.save_weights() to save checkpoints, we must
# internally uses model.save_weights() to save checkpoints, we must
# use model.load_weights() when Keras compile/fit() is used.
# use model.load_weights() when Keras compile/fit() is used.
export_classifier
(
FLAGS
.
model_export_path
,
input_meta_data
,
export_classifier
(
FLAGS
.
model_export_path
,
input_meta_data
,
FLAGS
.
use_keras_compile_fit
)
FLAGS
.
use_keras_compile_fit
,
bert_config
,
FLAGS
.
model_dir
)
return
return
if
FLAGS
.
mode
!=
'train_and_eval'
:
if
FLAGS
.
mode
!=
'train_and_eval'
:
...
@@ -285,7 +281,6 @@ def run_bert(strategy, input_meta_data):
...
@@ -285,7 +281,6 @@ def run_bert(strategy, input_meta_data):
# Enables XLA in Session Config. Should not be set for TPU.
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_config_v2
(
FLAGS
.
enable_xla
)
keras_utils
.
set_config_v2
(
FLAGS
.
enable_xla
)
bert_config
=
modeling
.
BertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
epochs
=
FLAGS
.
num_train_epochs
epochs
=
FLAGS
.
num_train_epochs
train_data_size
=
input_meta_data
[
'train_data_size'
]
train_data_size
=
input_meta_data
[
'train_data_size'
]
steps_per_epoch
=
int
(
train_data_size
/
FLAGS
.
train_batch_size
)
steps_per_epoch
=
int
(
train_data_size
/
FLAGS
.
train_batch_size
)
...
@@ -308,6 +303,8 @@ def run_bert(strategy, input_meta_data):
...
@@ -308,6 +303,8 @@ def run_bert(strategy, input_meta_data):
warmup_steps
,
warmup_steps
,
FLAGS
.
learning_rate
,
FLAGS
.
learning_rate
,
FLAGS
.
init_checkpoint
,
FLAGS
.
init_checkpoint
,
train_input_fn
,
eval_input_fn
,
run_eagerly
=
FLAGS
.
run_eagerly
,
run_eagerly
=
FLAGS
.
run_eagerly
,
use_keras_compile_fit
=
FLAGS
.
use_keras_compile_fit
)
use_keras_compile_fit
=
FLAGS
.
use_keras_compile_fit
)
...
@@ -341,7 +338,21 @@ def main(_):
...
@@ -341,7 +338,21 @@ def main(_):
else
:
else
:
raise
ValueError
(
'The distribution strategy type is not supported: %s'
%
raise
ValueError
(
'The distribution strategy type is not supported: %s'
%
FLAGS
.
strategy_type
)
FLAGS
.
strategy_type
)
run_bert
(
strategy
,
input_meta_data
)
max_seq_length
=
input_meta_data
[
'max_seq_length'
]
train_input_fn
=
functools
.
partial
(
input_pipeline
.
create_classifier_dataset
,
FLAGS
.
train_data_path
,
seq_length
=
max_seq_length
,
batch_size
=
FLAGS
.
train_batch_size
)
eval_input_fn
=
functools
.
partial
(
input_pipeline
.
create_classifier_dataset
,
FLAGS
.
eval_data_path
,
seq_length
=
max_seq_length
,
batch_size
=
FLAGS
.
eval_batch_size
,
is_training
=
False
,
drop_remainder
=
False
)
run_bert
(
strategy
,
input_meta_data
,
train_input_fn
,
eval_input_fn
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment