ModelZoo / ResNet50_tensorflow · Commit fb35d6be

Authored Feb 24, 2020 by Hongkun Yu; committed by A. Unique TensorFlower, Feb 24, 2020.

Creates modeling/performance.py to include mixed precision related stuff

PiperOrigin-RevId: 297002741
Parent: 02af9bb5

Showing 10 changed files with 112 additions and 89 deletions.
official/modeling/performance.py                                   +56   -0
official/nlp/bert/bert_models.py                                    +2   -0
official/nlp/bert/common_flags.py                                   +8   -0
official/nlp/bert/run_classifier.py                                 +7   -9
official/nlp/bert/run_pretraining.py                                +8   -9
official/nlp/transformer/transformer_main.py                       +12  -33
official/utils/flags/_performance.py                                +3   -2
official/vision/image_classification/resnet_ctl_imagenet_main.py    +2  -10
official/vision/image_classification/resnet_imagenet_main.py        +4  -11
official/vision/image_classification/resnet_runnable.py            +10  -15
official/modeling/performance.py (new file, mode 100644)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to training performance."""
import tensorflow as tf


def configure_optimizer(optimizer,
                        use_float16=False,
                        use_graph_rewrite=False,
                        loss_scale="dynamic"):
  """Configures optimizer object with performance options."""
  if use_float16:
    # Wraps optimizer with a LossScaleOptimizer. This is done automatically
    # in compile() with the "mixed_float16" policy, but since we do not call
    # compile(), we must wrap the optimizer manually.
    optimizer = (
        tf.keras.mixed_precision.experimental.LossScaleOptimizer(
            optimizer, loss_scale=loss_scale))
  if use_graph_rewrite:
    # Note: the model dtype must be 'float32', which will ensure
    # tf.keras.mixed_precision and
    # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
    # up.
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer)
  return optimizer


def set_mixed_precision_policy(dtype, loss_scale=None):
  """Sets the global mixed precision policy."""
  if dtype == tf.float16:
    policy = tf.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=loss_scale)
    tf.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.float32:
    tf.keras.mixed_precision.experimental.set_policy('float32')
  else:
    raise ValueError("Unexpected dtype: %s" % dtype)
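For orientation, here is a minimal sketch of how a custom training loop might consume this new module. The SGD optimizer and dtype are arbitrary placeholders, and the `experimental` mixed precision API is the TF 2.1-era one this commit targets:

import tensorflow as tf

from official.modeling import performance

# Set the global Keras policy first, so layers keep float32 variables but
# compute in float16.
performance.set_mixed_precision_policy(tf.float16, loss_scale='dynamic')

# Wrap an arbitrary optimizer; with use_float16=True this returns a
# LossScaleOptimizer, since compile() is never called to do the wrapping.
optimizer = performance.configure_optimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1),
    use_float16=True,
    loss_scale='dynamic')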
official/nlp/bert/bert_models.py

@@ -69,6 +69,8 @@ class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
                sentence_labels):
     """Implements call() for the layer."""
     lm_label_weights = tf.cast(lm_label_weights, tf.float32)
+    lm_output = tf.cast(lm_output, tf.float32)
+    sentence_output = tf.cast(sentence_output, tf.float32)
     mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
         labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
official/nlp/bert/common_flags.py

@@ -88,9 +88,17 @@ def define_common_bert_flags():
       )


+def dtype():
+  return flags_core.get_tf_dtype(flags.FLAGS)
+
+
 def use_float16():
   return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16


+def use_graph_rewrite():
+  return flags.FLAGS.fp16_implementation == 'graph_rewrite'
+
+
 def get_loss_scale():
   return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
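Taken together, these helpers translate the command-line flags into exactly the arguments performance.configure_optimizer expects. A sketch of the pure-Keras fp16 path, assuming the command line shown is hypothetical and that fp16_implementation keeps its 'keras' default from flags_core:

import tensorflow as tf
from absl import flags
from official.nlp.bert import common_flags

common_flags.define_common_bert_flags()
flags.FLAGS(['prog', '--dtype=fp16'])  # hypothetical command line

# With --dtype=fp16 and fp16_implementation left at its default:
assert common_flags.dtype() == tf.float16
assert common_flags.use_float16()
assert not common_flags.use_graph_rewrite()
assert common_flags.get_loss_scale() == 'dynamic'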
official/nlp/bert/run_classifier.py

@@ -27,6 +27,7 @@ from absl import logging
 import tensorflow as tf

 from official.modeling import model_training_utils
+from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags

@@ -126,16 +127,12 @@ def run_bert_classifier(strategy,
         max_seq_length,
         hub_module_url=FLAGS.hub_module_url,
         hub_module_trainable=FLAGS.hub_module_trainable))
-    classifier_model.optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      classifier_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          classifier_model.optimizer)
+    optimizer = optimization.create_optimizer(
+        initial_lr, steps_per_epoch * epochs, warmup_steps)
+    classifier_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model

   # During distributed training, loss used for gradient computation is

@@ -302,6 +299,7 @@ def run_bert(strategy,
     raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)

   # Enables XLA in Session Config. Should not be set for TPU.
   keras_utils.set_config_v2(FLAGS.enable_xla)
+  performance.set_mixed_precision_policy(common_flags.dtype())

   epochs = FLAGS.num_train_epochs
   train_data_size = input_meta_data['train_data_size']
official/nlp/bert/run_pretraining.py

@@ -23,6 +23,7 @@ from absl import logging
 import tensorflow as tf

 from official.modeling import model_training_utils
+from official.modeling import performance
 from official.nlp import optimization
 from official.nlp.bert import bert_models
 from official.nlp.bert import common_flags

@@ -102,16 +103,12 @@ def run_customized_training(strategy,
     """Gets a pretraining model."""
     pretrain_model, core_model = bert_models.pretrain_model(
         bert_config, max_seq_length, max_predictions_per_seq)
-    pretrain_model.optimizer = optimization.create_optimizer(
-        initial_lr, steps_per_epoch * epochs, warmup_steps)
-    if FLAGS.fp16_implementation == 'graph_rewrite':
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      pretrain_model.optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
-          pretrain_model.optimizer)
+    optimizer = optimization.create_optimizer(
+        initial_lr, steps_per_epoch * epochs, warmup_steps)
+    pretrain_model.optimizer = performance.configure_optimizer(
+        optimizer,
+        use_float16=common_flags.use_float16(),
+        use_graph_rewrite=common_flags.use_graph_rewrite())
     return pretrain_model, core_model

   trained_model = model_training_utils.run_customized_training_loop(

@@ -141,6 +138,8 @@ def run_bert_pretrain(strategy):
   logging.info('Training using customized training loop TF 2.0 with distrubuted'
                'strategy.')

+  performance.set_mixed_precision_policy(common_flags.dtype())
+
   return run_customized_training(
       strategy,
       bert_config,
official/nlp/transformer/transformer_main.py

@@ -17,7 +17,6 @@
 See README for description of setting the training schedule and evaluating the
 BLEU score.
 """
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

@@ -30,19 +29,19 @@ from absl import flags
 from absl import logging
 import tensorflow as tf

+# pylint: disable=g-bad-import-order
+from official.modeling import performance
 from official.nlp.transformer import compute_bleu
-from official.nlp.transformer.utils import tokenizer
 from official.nlp.transformer import data_pipeline
 from official.nlp.transformer import metrics
 from official.nlp.transformer import misc
 from official.nlp.transformer import optimizer
 from official.nlp.transformer import transformer
 from official.nlp.transformer import translate
+from official.nlp.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
-from official.utils.misc import keras_utils
 from official.utils.misc import distribution_utils
+from official.utils.misc import keras_utils

 INF = int(1e9)
 BLEU_DIR = "bleu"

@@ -180,21 +179,9 @@ class TransformerTask(object):
     else:
       logging.info("Not using any distribution strategy.")

-    if params["dtype"] == tf.float16:
-      # TODO(reedwm): It's pretty ugly to set the global policy in a constructor
-      # like this. What if multiple instances of TransformerTask are created?
-      # We should have a better way in the tf.keras.mixed_precision API of doing
-      # this.
-      loss_scale = flags_core.get_loss_scale(flags_obj,
-                                             default_for_fp16="dynamic")
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_float16", loss_scale=loss_scale)
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-    elif params["dtype"] == tf.bfloat16:
-      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-          "mixed_bfloat16")
-      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+    performance.set_mixed_precision_policy(
+        params["dtype"],
+        flags_core.get_loss_scale(flags_obj, default_for_fp16="dynamic"))

   @property
   def use_tpu(self):

@@ -434,8 +421,6 @@ class TransformerTask(object):
   def _create_optimizer(self):
     """Creates optimizer."""
     params = self.params
-    # TODO(b/139414679): Explore the difference between using
-    # LearningRateSchedule and callback for GPU runs, and try to merge them.
     lr_schedule = optimizer.LearningRateSchedule(
         params["learning_rate"], params["hidden_size"],
         params["learning_rate_warmup_steps"])

@@ -445,18 +430,12 @@ class TransformerTask(object):
         params["optimizer_adam_beta2"],
         epsilon=params["optimizer_adam_epsilon"])

-    if params["dtype"] == tf.float16:
-      opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-          opt,
-          loss_scale=flags_core.get_loss_scale(
-              self.flags_obj, default_for_fp16="dynamic"))
-    if self.flags_obj.fp16_implementation == "graph_rewrite":
-      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
-      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
-      # which will ensure tf.compat.v2.keras.mixed_precision and
-      # tf.train.experimental.enable_mixed_precision_graph_rewrite do not double
-      # up.
-      opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
+    opt = performance.configure_optimizer(
+        opt,
+        use_float16=params["dtype"] == tf.float16,
+        use_graph_rewrite=self.flags_obj.fp16_implementation == "graph_rewrite",
+        loss_scale=flags_core.get_loss_scale(
+            self.flags_obj, default_for_fp16="dynamic"))

     return opt
official/utils/flags/_performance.py

@@ -43,14 +43,15 @@ def get_tf_dtype(flags_obj):
 def get_loss_scale(flags_obj, default_for_fp16):
+  dtype = get_tf_dtype(flags_obj)
   if flags_obj.loss_scale == "dynamic":
     return flags_obj.loss_scale
   elif flags_obj.loss_scale is not None:
     return float(flags_obj.loss_scale)
-  elif flags_obj.dtype == "fp32":
+  elif dtype == tf.float32 or dtype == tf.bfloat16:
     return 1  # No loss scaling is needed for fp32
   else:
-    assert flags_obj.dtype == "fp16"
+    assert dtype == tf.float16
    return default_for_fp16
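The substantive change here: loss-scale selection now keys off the resolved TensorFlow dtype rather than the raw flag string, so bfloat16 lands in the "no scaling" branch instead of tripping the fp16 assertion (bfloat16 has the same exponent range as float32, so loss scaling is unnecessary). A standalone sketch that mirrors the new decision logic for illustration only:

import tensorflow as tf

def resolve_loss_scale(loss_scale_flag, dtype, default_for_fp16='dynamic'):
  # Mirrors the updated flags_core.get_loss_scale; not the library function.
  if loss_scale_flag == 'dynamic':
    return loss_scale_flag
  elif loss_scale_flag is not None:
    return float(loss_scale_flag)
  elif dtype == tf.float32 or dtype == tf.bfloat16:
    return 1  # no loss scaling needed: exponent range matches float32
  else:
    assert dtype == tf.float16
    return default_for_fp16

assert resolve_loss_scale(None, tf.bfloat16) == 1         # no longer asserts
assert resolve_loss_scale(None, tf.float16) == 'dynamic'
assert resolve_loss_scale('128', tf.float16) == 128.0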
official/vision/image_classification/resnet_ctl_imagenet_main.py

@@ -23,6 +23,7 @@ from absl import flags
 from absl import logging
 import tensorflow as tf

+from official.modeling import performance
 from official.staging.training import controller
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger

@@ -110,16 +111,7 @@ def run(flags_obj):
   keras_utils.set_session_config(
       enable_eager=flags_obj.enable_eager,
       enable_xla=flags_obj.enable_xla)
-
-  dtype = flags_core.get_tf_dtype(flags_obj)
-  if dtype == tf.float16:
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_float16')
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-  elif dtype == tf.bfloat16:
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_bfloat16')
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+  performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))

   # This only affects GPU.
   common.set_cudnn_batchnorm_mode()
official/vision/image_classification/resnet_imagenet_main.py

@@ -28,6 +28,7 @@ import tensorflow as tf
 import tensorflow_model_optimization as tfmot

 from official.benchmark.models import trivial_model
+from official.modeling import performance
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils

@@ -65,17 +66,9 @@ def run(flags_obj):
   common.set_cudnn_batchnorm_mode()

   dtype = flags_core.get_tf_dtype(flags_obj)
-  if dtype == tf.float16:
-    loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_float16', loss_scale=loss_scale)
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
-    if not keras_utils.is_v2_0():
-      raise ValueError('--dtype=fp16 is not supported in TensorFlow 1.')
-  elif dtype == tf.bfloat16:
-    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
-        'mixed_bfloat16')
-    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
+  performance.set_mixed_precision_policy(
+      flags_core.get_tf_dtype(flags_obj),
+      flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

   data_format = flags_obj.data_format
   if data_format is None:
official/vision/image_classification/resnet_runnable.py

@@ -20,6 +20,7 @@ from __future__ import print_function
 import tensorflow.compat.v2 as tf

+from official.modeling import performance
 from official.staging.training import standard_runnable
 from official.staging.training import utils
 from official.utils.flags import core as flags_core

@@ -85,21 +86,15 @@ class ResnetRunnable(standard_runnable.StandardTrainable,
     # Make sure iterations variable is created inside scope.
     self.global_step = self.optimizer.iterations

-    if self.dtype == tf.float16:
-      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-      self.optimizer = (
-          tf.keras.mixed_precision.experimental.LossScaleOptimizer(
-              self.optimizer, loss_scale))
-    elif flags_obj.fp16_implementation == 'graph_rewrite':
-      # `dtype` is still float32 in this case. We built the graph in float32
-      # and let the graph rewrite change parts of it float16.
-      if not flags_obj.use_tf_function:
-        raise ValueError('--fp16_implementation=graph_rewrite requires '
-                         '--use_tf_function to be true')
-      loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
-      self.optimizer = (
-          tf.train.experimental.enable_mixed_precision_graph_rewrite(
-              self.optimizer, loss_scale))
+    use_graph_rewrite = flags_obj.fp16_implementation == 'graph_rewrite'
+    if use_graph_rewrite and not flags_obj.use_tf_function:
+      raise ValueError('--fp16_implementation=graph_rewrite requires '
+                       '--use_tf_function to be true')
+    self.optimizer = performance.configure_optimizer(
+        self.optimizer,
+        use_float16=self.dtype == tf.float16,
+        use_graph_rewrite=use_graph_rewrite,
+        loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))

     self.train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
     self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
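Because this runnable drives its own training loop rather than going through Model.fit(), the LossScaleOptimizer produced by configure_optimizer must be applied by hand. A minimal sketch of that pattern, with a placeholder model and data; get_scaled_loss and get_unscaled_gradients are the TF 2.1-era LossScaleOptimizer methods:

import tensorflow as tf

@tf.function
def train_step(model, optimizer, images, labels):
  """One step with manual loss scaling, as a custom loop must do."""
  # `optimizer` is assumed to be a LossScaleOptimizer, e.g. the result of
  # performance.configure_optimizer(..., use_float16=True).
  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, logits, from_logits=True))
    # Scale the loss up before differentiating so small float16 gradients
    # do not underflow to zero.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Undo the scaling before applying the update; with a dynamic loss scale
  # the optimizer also adjusts the scale when gradients overflow.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss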