Commit 73a2818d authored by vishnubanna's avatar vishnubanna
Browse files

PR1 darknet

parent 7beddae1
import tensorflow as tf
import tensorflow.keras as ks
import numpy as np
from absl.testing import parameterized
from official.vision.beta.projects.yolo.modeling.building_blocks import DarkTiny
class DarkTinyTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(("middle", 224, 224, 64, 2),
("last", 224, 224, 1024, 1))
def test_pass_through(self, width, height, filters, strides):
x = ks.Input(shape=(width, height, filters))
test_layer = DarkTiny(filters=filters, strides=strides)
outx = test_layer(x)
print(outx)
print(outx.shape.as_list())
self.assertEqual(width % strides, 0, msg="width % strides != 0")
self.assertEqual(height % strides, 0, msg="height % strides != 0")
self.assertAllEqual(
outx.shape.as_list(),
[None, width // strides, height // strides, filters])
return
@parameterized.named_parameters(("middle", 224, 224, 64, 2),
("last", 224, 224, 1024, 1))
def test_gradient_pass_though(self, width, height, filters, strides):
loss = ks.losses.MeanSquaredError()
optimizer = ks.optimizers.SGD()
test_layer = DarkTiny(filters=filters, strides=strides)
init = tf.random_normal_initializer()
x = tf.Variable(initial_value=init(shape=(1, width, height, filters),
dtype=tf.float32))
y = tf.Variable(initial_value=init(shape=(1, width // strides,
height // strides, filters),
dtype=tf.float32))
with tf.GradientTape() as tape:
x_hat = test_layer(x)
grad_loss = loss(x_hat, y)
grad = tape.gradient(grad_loss, test_layer.trainable_variables)
optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))
self.assertNotIn(None, grad)
return
if __name__ == "__main__":
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin
from official.core import train_utils
# pylint: disable=unused-import
from official.vision.beta.projects.yolo.common import registry_imports
# pylint: enable=unused-import
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.modeling import performance
FLAGS = flags.FLAGS
def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
params.runtime.loss_scale)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment