Commit eb275559 authored by vishnubanna

Darknet review additions

parent 0bfc786c
@@ -13,7 +13,7 @@ class DarkNet(hyperparams.Config):
   model_id: str = "darknet53"
 
-# # we could not get this to work
-# @dataclasses.dataclass
-# class Backbone(backbones.Backbone):
-#   darknet: DarkNet = DarkNet()
+# we could not get this to work
+@dataclasses.dataclass
+class Backbone(backbones.Backbone):
+  darknet: DarkNet = DarkNet()
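With the dataclass now enabled, a darknet backbone becomes selectable through the config system. A minimal usage sketch, assuming `backbones.Backbone` follows the model garden's one-of config pattern with a `type` discriminator (the `type` field is an assumption from that pattern, not something this diff shows):

# hypothetical usage of the now-enabled config classes
backbone_cfg = Backbone(type="darknet", darknet=DarkNet(model_id="cspdarknet53"))
assert backbone_cfg.darknet.model_id == "cspdarknet53"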
@@ -9,7 +9,7 @@ from official.vision.beta.projects.yolo.modeling import building_blocks as nn_blocks
 
 # builder required classes
 class BlockConfig(object):
-  def __init__(self, layer, stack, reps, bottleneck, filters, kernel_size,
+  def __init__(self, layer, stack, reps, bottleneck, filters, pool_size, kernel_size,
                strides, padding, activation, route, output_name, is_output):
     '''
     get layer config to make code more readable
@@ -28,13 +28,13 @@ class BlockConfig(object):
     self.bottleneck = bottleneck
     self.filters = filters
     self.kernel_size = kernel_size
+    self.pool_size = pool_size
     self.strides = strides
     self.padding = padding
     self.activation = activation
     self.route = route
     self.output_name = output_name
     self.is_output = is_output
-    return
 
 
 def build_block_specs(config):
@@ -44,47 +44,50 @@ def build_block_specs(config):
   return specs
 
 
-def darkconv_config_todict(config, kwargs):
-  dictvals = {
-      "filters": config.filters,
-      "kernel_size": config.kernel_size,
-      "strides": config.strides,
-      "padding": config.padding
-  }
-  dictvals.update(kwargs)
-  return dictvals
-
-
-def darktiny_config_todict(config, kwargs):
-  dictvals = {"filters": config.filters, "strides": config.strides}
-  dictvals.update(kwargs)
-  return dictvals
-
-
-def maxpool_config_todict(config, kwargs):
-  return {
-      "pool_size": config.kernel_size,
-      "strides": config.strides,
-      "padding": config.padding,
-      "name": kwargs["name"]
-  }
-
-
-class layer_registry(object):
+class layer_factory(object):
+  """
+  class for quick look up of default layers used by darknet to
+  connect, introduce or exit a level. Used in place of an if condition
+  or switch to make adding new layers easier and to reduce redundant code
+  """
   def __init__(self):
     self._layer_dict = {
-        "DarkTiny": (nn_blocks.DarkTiny, darktiny_config_todict),
-        "DarkConv": (nn_blocks.DarkConv, darkconv_config_todict),
-        "MaxPool": (tf.keras.layers.MaxPool2D, maxpool_config_todict)
+        "DarkTiny": (nn_blocks.DarkTiny, self.darktiny_config_todict),
+        "DarkConv": (nn_blocks.DarkConv, self.darkconv_config_todict),
+        "MaxPool": (tf.keras.layers.MaxPool2D, self.maxpool_config_todict)
     }
     return
 
-  def _get_layer(self, key):
-    return self._layer_dict[key]
+  def darkconv_config_todict(self, config, kwargs):
+    dictvals = {
+        "filters": config.filters,
+        "kernel_size": config.kernel_size,
+        "strides": config.strides,
+        "padding": config.padding
+    }
+    dictvals.update(kwargs)
+    return dictvals
+
+  def darktiny_config_todict(self, config, kwargs):
+    dictvals = {"filters": config.filters, "strides": config.strides}
+    dictvals.update(kwargs)
+    return dictvals
+
+  def maxpool_config_todict(self, config, kwargs):
+    return {
+        "pool_size": config.pool_size,
+        "strides": config.strides,
+        "padding": config.padding,
+        "name": kwargs["name"]
+    }
 
   def __call__(self, config, kwargs):
-    layer, get_param_dict = self._get_layer(config.layer)
+    layer, get_param_dict = self._layer_dict[config.layer]
     param_dict = get_param_dict(config, kwargs)
     return layer(**param_dict)
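The renamed `layer_factory` replaces an if/elif chain with a dictionary dispatch: each entry maps a layer name to a `(layer class, config-translator)` pair, so registering a new block type is one new dict entry. A self-contained sketch of the same pattern, where `MiniFactory` and `FakeBlockConfig` are illustrative stand-ins, not project code:

import tensorflow as tf

class FakeBlockConfig(object):
  # stand-in for BlockConfig; only the fields MaxPool needs
  def __init__(self, layer, pool_size, strides, padding):
    self.layer, self.pool_size = layer, pool_size
    self.strides, self.padding = strides, padding

class MiniFactory(object):
  """Dictionary-dispatch factory, mirroring layer_factory above."""
  def __init__(self):
    self._layer_dict = {
        "MaxPool": (tf.keras.layers.MaxPool2D, self.maxpool_config_todict),
    }

  def maxpool_config_todict(self, config, kwargs):
    return {"pool_size": config.pool_size, "strides": config.strides,
            "padding": config.padding, "name": kwargs["name"]}

  def __call__(self, config, kwargs):
    layer_cls, to_dict = self._layer_dict[config.layer]  # lookup, not if/elif
    return layer_cls(**to_dict(config, kwargs))

# usage: build a MaxPool2D from a spec object
cfg = FakeBlockConfig("MaxPool", pool_size=2, strides=2, padding="same")
pool = MiniFactory()(cfg, {"name": "pool_0"})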
@@ -92,7 +95,7 @@ class layer_registry(object):
 
 # model configs
 LISTNAMES = [
     "default_layer_name", "level_type", "number_of_layers_in_level",
-    "bottleneck", "filters", "kernal_size", "strides", "padding",
+    "bottleneck", "filters", "kernal_size", "pool_size", "strides", "padding",
     "default_activation", "route", "level/name", "is_output"
 ]
 
@@ -101,12 +104,12 @@ CSPDARKNET53 = {
     "splits": {"backbone_split": 106,
                "neck_split": 138},
     "backbone": [
-        ["DarkConv", None, 1, False, 32, 3, 1, "same", "mish", -1, 0, False],  # 1
-        ["DarkRes", "csp", 1, True, 64, None, None, None, "mish", -1, 1, False],  # 3
-        ["DarkRes", "csp", 2, False, 128, None, None, None, "mish", -1, 2, False],  # 2
-        ["DarkRes", "csp", 8, False, 256, None, None, None, "mish", -1, 3, True],
-        ["DarkRes", "csp", 8, False, 512, None, None, None, "mish", -1, 4, True],  # 3
-        ["DarkRes", "csp", 4, False, 1024, None, None, None, "mish", -1, 5, True],  # 6 #route
+        ["DarkConv", None, 1, False, 32, None, 3, 1, "same", "mish", -1, 0, False],  # 1
+        ["DarkRes", "csp", 1, True, 64, None, None, None, None, "mish", -1, 1, False],  # 3
+        ["DarkRes", "csp", 2, False, 128, None, None, None, None, "mish", -1, 2, False],  # 2
+        ["DarkRes", "csp", 8, False, 256, None, None, None, None, "mish", -1, 3, True],
+        ["DarkRes", "csp", 8, False, 512, None, None, None, None, "mish", -1, 4, True],  # 3
+        ["DarkRes", "csp", 4, False, 1024, None, None, None, None, "mish", -1, 5, True],  # 6 #route
     ]
 }
 
@@ -114,12 +117,12 @@ DARKNET53 = {
     "list_names": LISTNAMES,
     "splits": {"backbone_split": 76},
     "backbone": [
-        ["DarkConv", None, 1, False, 32, 3, 1, "same", "leaky", -1, 0, False],  # 1
-        ["DarkRes", "residual", 1, True, 64, None, None, None, "leaky", -1, 1, False],  # 3
-        ["DarkRes", "residual", 2, False, 128, None, None, None, "leaky", -1, 2, False],  # 2
-        ["DarkRes", "residual", 8, False, 256, None, None, None, "leaky", -1, 3, True],
-        ["DarkRes", "residual", 8, False, 512, None, None, None, "leaky", -1, 4, True],  # 3
-        ["DarkRes", "residual", 4, False, 1024, None, None, None, "leaky", -1, 5, True],  # 6
+        ["DarkConv", None, 1, False, 32, None, 3, 1, "same", "leaky", -1, 0, False],  # 1
+        ["DarkRes", "residual", 1, True, 64, None, None, None, None, "leaky", -1, 1, False],  # 3
+        ["DarkRes", "residual", 2, False, 128, None, None, None, None, "leaky", -1, 2, False],  # 2
+        ["DarkRes", "residual", 8, False, 256, None, None, None, None, "leaky", -1, 3, True],
+        ["DarkRes", "residual", 8, False, 512, None, None, None, None, "leaky", -1, 4, True],  # 3
+        ["DarkRes", "residual", 4, False, 1024, None, None, None, None, "leaky", -1, 5, True],  # 6
     ]
 }
 
@@ -127,12 +130,12 @@ CSPDARKNETTINY = {
     "list_names": LISTNAMES,
     "splits": {"backbone_split": 28},
     "backbone": [
-        ["DarkConv", None, 1, False, 32, 3, 2, "same", "leaky", -1, 0, False],  # 1
-        ["DarkConv", None, 1, False, 64, 3, 2, "same", "leaky", -1, 1, False],  # 1
-        ["CSPTiny", "csp_tiny", 1, False, 64, 3, 2, "same", "leaky", -1, 2, False],  # 3
-        ["CSPTiny", "csp_tiny", 1, False, 128, 3, 2, "same", "leaky", -1, 3, False],  # 3
-        ["CSPTiny", "csp_tiny", 1, False, 256, 3, 2, "same", "leaky", -1, 4, True],  # 3
-        ["DarkConv", None, 1, False, 512, 3, 1, "same", "leaky", -1, 5, True],  # 1
+        ["DarkConv", None, 1, False, 32, None, 3, 2, "same", "leaky", -1, 0, False],  # 1
+        ["DarkConv", None, 1, False, 64, None, 3, 2, "same", "leaky", -1, 1, False],  # 1
+        ["CSPTiny", "csp_tiny", 1, False, 64, None, 3, 2, "same", "leaky", -1, 2, False],  # 3
+        ["CSPTiny", "csp_tiny", 1, False, 128, None, 3, 2, "same", "leaky", -1, 3, False],  # 3
+        ["CSPTiny", "csp_tiny", 1, False, 256, None, 3, 2, "same", "leaky", -1, 4, True],  # 3
+        ["DarkConv", None, 1, False, 512, None, 3, 1, "same", "leaky", -1, 5, True],  # 1
     ]
 }
 
@@ -140,13 +143,13 @@ DARKNETTINY = {
     "list_names": LISTNAMES,
     "splits": {"backbone_split": 14},
     "backbone": [
-        ["DarkConv", None, 1, False, 16, 3, 1, "same", "leaky", -1, 0, False],  # 1
-        ["DarkTiny", None, 1, True, 32, 3, 2, "same", "leaky", -1, 1, False],  # 3
-        ["DarkTiny", None, 1, True, 64, 3, 2, "same", "leaky", -1, 2, False],  # 3
-        ["DarkTiny", None, 1, False, 128, 3, 2, "same", "leaky", -1, 3, False],  # 2
-        ["DarkTiny", None, 1, False, 256, 3, 2, "same", "leaky", -1, 4, True],
-        ["DarkTiny", None, 1, False, 512, 3, 2, "same", "leaky", -1, 5, False],  # 3
-        ["DarkTiny", None, 1, False, 1024, 3, 1, "same", "leaky", -1, 5, True],  # 6 #route
+        ["DarkConv", None, 1, False, 16, None, 3, 1, "same", "leaky", -1, 0, False],  # 1
+        ["DarkTiny", None, 1, True, 32, None, 3, 2, "same", "leaky", -1, 1, False],  # 3
+        ["DarkTiny", None, 1, True, 64, None, 3, 2, "same", "leaky", -1, 2, False],  # 3
+        ["DarkTiny", None, 1, False, 128, None, 3, 2, "same", "leaky", -1, 3, False],  # 2
+        ["DarkTiny", None, 1, False, 256, None, 3, 2, "same", "leaky", -1, 4, True],
+        ["DarkTiny", None, 1, False, 512, None, 3, 2, "same", "leaky", -1, 5, False],  # 3
+        ["DarkTiny", None, 1, False, 1024, None, 3, 1, "same", "leaky", -1, 5, True],  # 6 #route
     ]
 }
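Every backbone row above is positional against LISTNAMES, which is why each row gains a `None` in the new pool_size slot even when the layer never pools: dropping the slot would shift every later field (strides, padding, activation, ...) by one. Note, though, that LISTNAMES places `kernal_size` before `pool_size` while `BlockConfig.__init__` takes `pool_size` before `kernel_size`; `build_block_specs` (not shown in this diff) has to reconcile that ordering. A small sketch of the row/schema correspondence:

LISTNAMES = [
    "default_layer_name", "level_type", "number_of_layers_in_level",
    "bottleneck", "filters", "kernal_size", "pool_size", "strides", "padding",
    "default_activation", "route", "level/name", "is_output"
]
# one row from CSPDARKNET53 above: a csp residual level with 64 filters
row = ["DarkRes", "csp", 1, True, 64, None, None, None, None, "mish", -1, 1, False]
assert len(LISTNAMES) == len(row)  # schema and rows must stay in lockstep
spec = dict(zip(LISTNAMES, row))
assert spec["filters"] == 64 and spec["pool_size"] is None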
@@ -164,7 +167,7 @@ class Darknet(ks.Model):
   def __init__(
       self,
       model_id="darknet53",
-      input_shape=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
+      input_specs=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
       min_size=None,
       max_size=5,
       activation=None,
@@ -181,8 +184,8 @@ class Darknet(ks.Model):
 
     self._model_name = model_id
     self._splits = splits
-    self._input_shape = input_shape
-    self._registry = layer_registry()
+    self._input_shape = input_specs
+    self._registry = layer_factory()
 
     # default layer look up
     self._min_size = min_size
@@ -195,11 +198,11 @@ class Darknet(ks.Model):
     self._norm_epislon = norm_epsilon
     self._use_sync_bn = use_sync_bn
     self._activation = activation
-    self._weight_decay = kernel_regularizer
+    self._kernel_regularizer = kernel_regularizer
 
     self._default_dict = {
         "kernel_initializer": self._kernel_initializer,
-        "weight_decay": self._weight_decay,
+        "weight_decay": self._kernel_regularizer,
         "bias_regularizer": self._bias_regularizer,
         "norm_momentum": self._norm_momentum,
         "norm_epsilon": self._norm_epislon,
@@ -211,7 +214,6 @@ class Darknet(ks.Model):
     inputs = ks.layers.Input(shape=self._input_shape.shape[1:])
     output = self._build_struct(layer_specs, inputs)
     super().__init__(inputs=inputs, outputs=output, name=self._model_name)
-    return
 
   @property
   def input_specs(self):
@@ -251,9 +253,9 @@ class Darknet(ks.Model):
         stack_outputs.append(x_pass)
       if (config.is_output and
           self._min_size == None):  # or isinstance(config.output_name, str):
-        endpoints[config.output_name] = x
+        endpoints[str(config.output_name)] = x
       elif self._min_size != None and config.output_name >= self._min_size and config.output_name <= self._max_size:
-        endpoints[config.output_name] = x
+        endpoints[str(config.output_name)] = x
 
     self._output_specs = {l: endpoints[l].get_shape() for l in endpoints.keys()}
     return endpoints
...
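Stringifying the endpoint keys matters because the test file below indexes the outputs as endpoints['3'] through endpoints['5']. A minimal sketch of the selection rule, using a hypothetical helper that mirrors the logic above (integer output_name levels assumed):

def select_endpoint(endpoints, x, output_name, is_output,
                    min_size=None, max_size=5):
  # mirror of the filtering above: keys are stringified level ids
  if is_output and min_size is None:
    endpoints[str(output_name)] = x
  elif min_size is not None and min_size <= output_name <= max_size:
    endpoints[str(output_name)] = x
  return endpoints

endpoints = {}
for level, is_out in [(2, False), (3, True), (4, True), (5, True)]:
  select_endpoint(endpoints, object(), level, is_out, min_size=3, max_size=5)
assert sorted(endpoints) == ['3', '4', '5']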
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for resnet."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.vision.beta.projects.yolo.modeling.backbones import Darknet
class ResNetTest(parameterized.TestCase, tf.test.TestCase):
  @parameterized.parameters(
      (224, "darknet53", 2),
      (224, "darknettiny", 2),
      (224, "cspdarknettiny", 1),
      (224, "cspdarknet53", 2),
  )
  def test_network_creation(self, input_size, model_id,
                            endpoint_filter_scale):
    """Test creation of Darknet family models."""
    tf.keras.backend.set_image_data_format('channels_last')

    network = Darknet.Darknet(model_id=model_id, min_size=3, max_size=5)

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    endpoints = network(inputs)

    self.assertAllEqual(
        [1, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale],
        endpoints['3'].shape.as_list())
    self.assertAllEqual(
        [1, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale],
        endpoints['4'].shape.as_list())
    self.assertAllEqual(
        [1, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale],
        endpoints['5'].shape.as_list())
  @combinations.generate(
      combinations.combine(
          strategy=[
              strategy_combinations.tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          use_sync_bn=[False, True],
      ))
  def test_sync_bn_multiple_devices(self, strategy, use_sync_bn):
    """Test for sync bn on TPU and GPU devices."""
    inputs = np.random.rand(64, 224, 224, 3)

    tf.keras.backend.set_image_data_format('channels_last')

    with strategy.scope():
      network = Darknet.Darknet(model_id="darknet53", min_size=3, max_size=5,
                                use_sync_bn=use_sync_bn)
      _ = network(inputs)
  @parameterized.parameters(1, 3, 4)
  def test_input_specs(self, input_dim):
    """Test different input feature dimensions."""
    tf.keras.backend.set_image_data_format('channels_last')

    input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, input_dim])
    network = Darknet.Darknet(model_id="darknet53", min_size=3, max_size=5,
                              input_specs=input_specs)

    inputs = tf.keras.Input(shape=(224, 224, input_dim), batch_size=1)
    _ = network(inputs)
  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id="darknet53",
        use_sync_bn=False,
        activation='relu',
        norm_momentum=0.99,
        norm_epsilon=0.001,
        kernel_initializer='VarianceScaling',
        kernel_regularizer=None,
        bias_regularizer=None,
    )
    network = Darknet.Darknet(**kwargs)

    expected_config = dict(kwargs)
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = Darknet.Darknet.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())


if __name__ == '__main__':
  tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image classification task definition."""
import tensorflow as tf
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.configs import image_classification as exp_cfg
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.modeling import factory
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(base_task.Task):
  """A task for image classification."""

  def build_model(self):
    """Builds classification model."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (tf.keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory.build_classification_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)
    return model
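The divide-by-2.0 in build_model can be checked directly: tf.keras.regularizers.l2 computes decay * sum(w**2), while tf.nn.l2_loss computes sum(w**2) / 2, so halving the Keras coefficient makes the two penalties agree. A quick sketch with arbitrary values:

import tensorflow as tf

w = tf.constant([1.0, 2.0, 3.0])
l2_weight_decay = 1e-4

# keras l2: (decay / 2) * sum(w^2); tf.nn.l2_loss: sum(w^2) / 2
reg = tf.keras.regularizers.l2(l2_weight_decay / 2.0)
keras_penalty = reg(w)
nn_penalty = l2_weight_decay * tf.nn.l2_loss(w)
tf.debugging.assert_near(keras_penalty, nn_penalty)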
  def build_inputs(self, params, input_context=None):
    """Builds classification input."""
    num_classes = self.task_config.model.num_classes
    input_size = self.task_config.model.input_size

    decoder = classification_input.Decoder()
    parser = classification_input.Parser(
        output_size=input_size[:2],
        num_classes=num_classes,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)
    return dataset
  def build_losses(self, labels, model_outputs, aux_losses=None):
    """Sparse categorical cross entropy loss.

    Args:
      labels: labels.
      model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    losses_config = self.task_config.losses
    if losses_config.one_hot:
      total_loss = tf.keras.losses.categorical_crossentropy(
          labels,
          model_outputs,
          from_logits=True,
          label_smoothing=losses_config.label_smoothing)
    else:
      total_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, model_outputs, from_logits=True)

    total_loss = tf_utils.safe_mean(total_loss)
    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return total_loss
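The two branches of build_losses compute the same quantity when label smoothing is zero: sparse categorical cross-entropy with integer labels matches categorical cross-entropy with the corresponding one-hot labels. A small sketch verifying that equivalence:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1]])
sparse_labels = tf.constant([0])
one_hot_labels = tf.one_hot(sparse_labels, depth=3)

a = tf.keras.losses.sparse_categorical_crossentropy(
    sparse_labels, logits, from_logits=True)
b = tf.keras.losses.categorical_crossentropy(
    one_hot_labels, logits, from_logits=True)  # label_smoothing defaults to 0
tf.debugging.assert_near(a, b)  # identical when smoothing is off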
  def build_metrics(self, training=True):
    """Gets streaming metrics for training/validation."""
    if self.task_config.losses.one_hot:
      metrics = [
          tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')]
    else:
      metrics = [
          tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=5, name='top_5_accuracy')]

    return metrics
  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    if self.task_config.losses.one_hot:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(
          optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(
        optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)

    # Apply gradient clipping.
    if self.task_config.gradient_clip_norm > 0:
      grads, _ = tf.clip_by_global_norm(
          grads, self.task_config.gradient_clip_norm)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})

    return logs
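The scaled_loss = loss / num_replicas step in train_step exists because the cross-replica allreduce inside the optimizer sums gradients rather than averaging them: with N replicas each contributing gradient g, the summed update is N times too large, so dividing each replica's loss by N restores the global mean. A toy check with plain numbers (no distribution strategy involved):

num_replicas = 4
per_replica_grads = [0.5, 1.0, 1.5, 2.0]   # local gradients of the unscaled loss

# what a sum-allreduce would apply if the losses were not scaled
summed = sum(per_replica_grads)            # 5.0, i.e. N times too large

# dividing each replica's loss by N divides its gradient by N,
# so the summed result equals the mean gradient
scaled_then_summed = sum(g / num_replicas for g in per_replica_grads)
assert scaled_then_summed == sum(per_replica_grads) / num_replicas  # 1.25
assert summed == num_replicas * scaled_then_summed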
  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    if self.task_config.losses.one_hot:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(model_outputs=outputs, labels=labels,
                             aux_losses=model.losses)

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs, model):
    """Performs the forward step."""
    return model(inputs, training=False)