Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f409e4d0
Commit
f409e4d0
authored
Nov 13, 2020
by
Reed Wanderman-Milne
Committed by
A. Unique TensorFlower
Nov 13, 2020
Browse files
Internal change
PiperOrigin-RevId: 342390260
parent
ec4b78f3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
14 deletions
+18
-14
official/nlp/bert/model_training_utils_test.py
official/nlp/bert/model_training_utils_test.py
+3
-5
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+5
-3
official/nlp/bert/run_pretraining.py
official/nlp/bert/run_pretraining.py
+4
-2
official/nlp/bert/run_squad_helper.py
official/nlp/bert/run_squad_helper.py
+6
-4
No files found.
official/nlp/bert/model_training_utils_test.py
View file @
f409e4d0
...
@@ -107,9 +107,8 @@ def create_model_fn(input_shape, num_classes, use_float16=False):
...
@@ -107,9 +107,8 @@ def create_model_fn(input_shape, num_classes, use_float16=False):
tf
.
reduce_mean
(
input_layer
),
name
=
'mean_input'
,
aggregation
=
'mean'
)
tf
.
reduce_mean
(
input_layer
),
name
=
'mean_input'
,
aggregation
=
'mean'
)
model
.
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
0.1
,
momentum
=
0.9
)
model
.
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
0.1
,
momentum
=
0.9
)
if
use_float16
:
if
use_float16
:
model
.
optimizer
=
(
model
.
optimizer
=
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
(
tf
.
keras
.
mixed_precision
.
experimental
.
LossScaleOptimizer
(
model
.
optimizer
)
model
.
optimizer
,
loss_scale
=
'dynamic'
))
return
model
,
sub_model
return
model
,
sub_model
return
_model_fn
return
_model_fn
...
@@ -198,8 +197,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -198,8 +197,7 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
@
combinations
.
generate
(
eager_gpu_strategy_combinations
())
@
combinations
.
generate
(
eager_gpu_strategy_combinations
())
def
test_train_eager_mixed_precision
(
self
,
distribution
):
def
test_train_eager_mixed_precision
(
self
,
distribution
):
model_dir
=
self
.
create_tempdir
().
full_path
model_dir
=
self
.
create_tempdir
().
full_path
policy
=
tf
.
keras
.
mixed_precision
.
experimental
.
Policy
(
'mixed_float16'
)
tf
.
keras
.
mixed_precision
.
set_global_policy
(
'mixed_float16'
)
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
self
.
_model_fn
=
create_model_fn
(
self
.
_model_fn
=
create_model_fn
(
input_shape
=
[
128
],
num_classes
=
3
,
use_float16
=
True
)
input_shape
=
[
128
],
num_classes
=
3
,
use_float16
=
True
)
self
.
run_training
(
self
.
run_training
(
...
...
official/nlp/bert/run_classifier.py
View file @
f409e4d0
...
@@ -151,7 +151,8 @@ def run_bert_classifier(strategy,
...
@@ -151,7 +151,8 @@ def run_bert_classifier(strategy,
classifier_model
.
optimizer
=
performance
.
configure_optimizer
(
classifier_model
.
optimizer
=
performance
.
configure_optimizer
(
optimizer
,
optimizer
,
use_float16
=
common_flags
.
use_float16
(),
use_float16
=
common_flags
.
use_float16
(),
use_graph_rewrite
=
common_flags
.
use_graph_rewrite
())
use_graph_rewrite
=
common_flags
.
use_graph_rewrite
(),
use_experimental_api
=
False
)
return
classifier_model
,
core_model
return
classifier_model
,
core_model
# tf.keras.losses objects accept optional sample_weight arguments (eg. coming
# tf.keras.losses objects accept optional sample_weight arguments (eg. coming
...
@@ -348,7 +349,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
...
@@ -348,7 +349,7 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
raise
ValueError
(
'Export path is not specified: %s'
%
model_dir
)
raise
ValueError
(
'Export path is not specified: %s'
%
model_dir
)
# Export uses float32 for now, even if training uses mixed precision.
# Export uses float32 for now, even if training uses mixed precision.
tf
.
keras
.
mixed_precision
.
experimental
.
set
_policy
(
'float32'
)
tf
.
keras
.
mixed_precision
.
set_global
_policy
(
'float32'
)
classifier_model
=
bert_models
.
classifier_model
(
classifier_model
=
bert_models
.
classifier_model
(
bert_config
,
bert_config
,
input_meta_data
.
get
(
'num_labels'
,
1
),
input_meta_data
.
get
(
'num_labels'
,
1
),
...
@@ -370,7 +371,8 @@ def run_bert(strategy,
...
@@ -370,7 +371,8 @@ def run_bert(strategy,
"""Run BERT training."""
"""Run BERT training."""
# Enables XLA in Session Config. Should not be set for TPU.
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
(),
use_experimental_api
=
False
)
epochs
=
FLAGS
.
num_train_epochs
*
FLAGS
.
num_eval_per_epoch
epochs
=
FLAGS
.
num_train_epochs
*
FLAGS
.
num_eval_per_epoch
train_data_size
=
(
train_data_size
=
(
...
...
official/nlp/bert/run_pretraining.py
View file @
f409e4d0
...
@@ -126,7 +126,8 @@ def run_customized_training(strategy,
...
@@ -126,7 +126,8 @@ def run_customized_training(strategy,
pretrain_model
.
optimizer
=
performance
.
configure_optimizer
(
pretrain_model
.
optimizer
=
performance
.
configure_optimizer
(
optimizer
,
optimizer
,
use_float16
=
common_flags
.
use_float16
(),
use_float16
=
common_flags
.
use_float16
(),
use_graph_rewrite
=
common_flags
.
use_graph_rewrite
())
use_graph_rewrite
=
common_flags
.
use_graph_rewrite
(),
use_experimental_api
=
False
)
return
pretrain_model
,
core_model
return
pretrain_model
,
core_model
trained_model
=
model_training_utils
.
run_customized_training_loop
(
trained_model
=
model_training_utils
.
run_customized_training_loop
(
...
@@ -162,7 +163,8 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
...
@@ -162,7 +163,8 @@ def run_bert_pretrain(strategy, custom_callbacks=None):
logging
.
info
(
'Training using customized training loop TF 2.0 with distributed'
logging
.
info
(
'Training using customized training loop TF 2.0 with distributed'
'strategy.'
)
'strategy.'
)
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
(),
use_experimental_api
=
False
)
# Only when explicit_allreduce = True, post_allreduce_callbacks and
# Only when explicit_allreduce = True, post_allreduce_callbacks and
# allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
# allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
...
...
official/nlp/bert/run_squad_helper.py
View file @
f409e4d0
...
@@ -160,7 +160,7 @@ def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
...
@@ -160,7 +160,7 @@ def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
"""Gets a squad model to make predictions."""
"""Gets a squad model to make predictions."""
with
strategy
.
scope
():
with
strategy
.
scope
():
# Prediction always uses float32, even if training uses mixed precision.
# Prediction always uses float32, even if training uses mixed precision.
tf
.
keras
.
mixed_precision
.
experimental
.
set
_policy
(
'float32'
)
tf
.
keras
.
mixed_precision
.
set_global
_policy
(
'float32'
)
squad_model
,
_
=
bert_models
.
squad_model
(
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
bert_config
,
input_meta_data
[
'max_seq_length'
],
input_meta_data
[
'max_seq_length'
],
...
@@ -225,7 +225,8 @@ def train_squad(strategy,
...
@@ -225,7 +225,8 @@ def train_squad(strategy,
' strategy.'
)
' strategy.'
)
# Enables XLA in Session Config. Should not be set for TPU.
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
keras_utils
.
set_session_config
(
FLAGS
.
enable_xla
)
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
(),
use_experimental_api
=
False
)
epochs
=
FLAGS
.
num_train_epochs
epochs
=
FLAGS
.
num_train_epochs
num_train_examples
=
input_meta_data
[
'train_data_size'
]
num_train_examples
=
input_meta_data
[
'train_data_size'
]
...
@@ -253,7 +254,8 @@ def train_squad(strategy,
...
@@ -253,7 +254,8 @@ def train_squad(strategy,
squad_model
.
optimizer
=
performance
.
configure_optimizer
(
squad_model
.
optimizer
=
performance
.
configure_optimizer
(
optimizer
,
optimizer
,
use_float16
=
common_flags
.
use_float16
(),
use_float16
=
common_flags
.
use_float16
(),
use_graph_rewrite
=
common_flags
.
use_graph_rewrite
())
use_graph_rewrite
=
common_flags
.
use_graph_rewrite
(),
use_experimental_api
=
False
)
return
squad_model
,
core_model
return
squad_model
,
core_model
# Only when explicit_allreduce = True, post_allreduce_callbacks and
# Only when explicit_allreduce = True, post_allreduce_callbacks and
...
@@ -465,7 +467,7 @@ def export_squad(model_export_path, input_meta_data, bert_config):
...
@@ -465,7 +467,7 @@ def export_squad(model_export_path, input_meta_data, bert_config):
if
not
model_export_path
:
if
not
model_export_path
:
raise
ValueError
(
'Export path is not specified: %s'
%
model_export_path
)
raise
ValueError
(
'Export path is not specified: %s'
%
model_export_path
)
# Export uses float32 for now, even if training uses mixed precision.
# Export uses float32 for now, even if training uses mixed precision.
tf
.
keras
.
mixed_precision
.
experimental
.
set
_policy
(
'float32'
)
tf
.
keras
.
mixed_precision
.
set_global
_policy
(
'float32'
)
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
squad_model
,
_
=
bert_models
.
squad_model
(
bert_config
,
input_meta_data
[
'max_seq_length'
])
input_meta_data
[
'max_seq_length'
])
model_saving_utils
.
export_bert_model
(
model_saving_utils
.
export_bert_model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment