"...resnet50_tensorflow.git" did not exist on "95d1b29870d7e2ae0455eebf34674f1afd10c507"
Commit db80d57a authored by Shining Sun's avatar Shining Sun
Browse files

Added some alternative code. About to do dataset

parent 51fc02ae
...@@ -105,7 +105,8 @@ def get_optimizer(): ...@@ -105,7 +105,8 @@ def get_optimizer():
learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256 learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9) optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
else: else:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9) # optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
optimizer = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
return optimizer return optimizer
......
...@@ -23,6 +23,7 @@ from absl import flags ...@@ -23,6 +23,7 @@ from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import resnet50 from official.resnet.keras import resnet50
...@@ -30,6 +31,8 @@ from official.utils.flags import core as flags_core ...@@ -30,6 +31,8 @@ from official.utils.flags import core as flags_core
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
# import os
# os.environ['TF2_BEHAVIOR'] = 'enabled'
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80) (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
...@@ -69,11 +72,25 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batc ...@@ -69,11 +72,25 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batc
def parse_record_keras(raw_record, is_training, dtype): def parse_record_keras(raw_record, is_training, dtype):
"""Adjust the shape of label.""" """Adjust the shape of label."""
image_buffer, label, bbox = imagenet_main._parse_example_proto(raw_record)
image = imagenet_preprocessing.preprocess_image(
image_buffer=image_buffer,
bbox=bbox,
output_height=imagenet_main.DEFAULT_IMAGE_SIZE,
output_width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS,
is_training=is_training)
image = tf.cast(image, dtype)
label = tf.sparse_to_dense(label, (imagenet_main.NUM_CLASSES,), 1)
"""
image, label = imagenet_main.parse_record(raw_record, is_training, dtype) image, label = imagenet_main.parse_record(raw_record, is_training, dtype)
# Subtract one so that labels are in [0, 1000), and cast to float32 for # Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model. # Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1, label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
dtype=tf.float32) dtype=tf.float32)
"""
return image, label return image, label
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment