"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "9d2a20a7638a7c4e10cb119c3c3b6bf6e470ca3e"
Commit f2ee8b52 authored by Vishnu Banna

Addressing review comments

parent dabcbc97
@@ -38,7 +38,7 @@ task:
       'all': 0.05
     cls_normalizer:
       'all': 0.3
-    obj_normalizer:
+    object_normalizer:
       '5': 0.28
       '4': 0.70
       '3': 2.80
@@ -41,7 +41,7 @@ task:
       'all': 0.07
     cls_normalizer:
       'all': 1.0
-    obj_normalizer:
+    object_normalizer:
       'all': 1.0
     objectness_smooth:
       'all': 0.0
@@ -143,7 +143,7 @@ class YoloLoss(hyperparams.Config):
       default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
   cls_normalizer: FPNConfig = dataclasses.field(
       default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
-  obj_normalizer: FPNConfig = dataclasses.field(
+  object_normalizer: FPNConfig = dataclasses.field(
       default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, 1.0))
   max_delta: FPNConfig = dataclasses.field(
       default_factory=_build_dict(MIN_LEVEL, MAX_LEVEL, np.inf))
@@ -40,7 +40,7 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
                loss_type='ciou',
                iou_normalizer=1.0,
                cls_normalizer=1.0,
-               obj_normalizer=1.0,
+               object_normalizer=1.0,
                label_smoothing=0.0,
                objectness_smooth=True,
                update_on_repeat=False,
@@ -65,7 +65,7 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
       iou_normalizer: `float` for how much to scale the loss on the IOU or the
         boxes.
       cls_normalizer: `float` for how much to scale the loss on the classes.
-      obj_normalizer: `float` for how much to scale loss on the detection map.
+      object_normalizer: `float` for how much to scale loss on the detection map.
       label_smoothing: `float` for how much to smooth the loss on the classes.
       objectness_smooth: `float` for how much to smooth the loss on the
         detection map.
@@ -90,7 +90,7 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
     self._iou_normalizer = iou_normalizer
     self._cls_normalizer = cls_normalizer
-    self._obj_normalizer = obj_normalizer
+    self._object_normalizer = object_normalizer
     self._scale_x_y = scale_x_y
     self._max_delta = max_delta
@@ -240,9 +240,9 @@ class YoloLossBase(object, metaclass=abc.ABCMeta):
     Returns:
       loss: `tf.float` scalar for the scaled loss.
+      scale: `tf.float` how much the loss was scaled by.
     """
-    del box_loss, conf_loss, class_loss, ground_truths, predictions
-    return loss
+    return loss, tf.ones_like(loss)

   @abc.abstractmethod
   def cross_replica_aggregation(self, loss, num_replicas_in_sync):
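Note: `post_path_aggregation` now returns a `(loss, scale)` pair instead of a bare loss, so every loss path reports how much it rescaled its output. A minimal sketch of the contract, with a made-up caller to show why the second value matters (not the project's code):

```python
import tensorflow as tf

def post_path_aggregation(loss):
  # Default path: no rescaling, so the reported scale is 1 and callers
  # can divide by it unconditionally.
  return loss, tf.ones_like(loss)

path_loss = tf.constant(2.5)
scaled, scale = post_path_aggregation(path_loss)
logged = scaled / scale  # identical to path_loss for the default path
```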
@@ -373,6 +373,11 @@ class DarknetLoss(YoloLossBase):
     box_loss = tf.cast(tf.reduce_sum(box_loss, axis=1), dtype=y_pred.dtype)

     if self._update_on_repeat:
+      # Converts the list of ground truths into a grid where repeated values
+      # are replaced by the most recent value. Some class identities may get
+      # lost, but the loss computation will be more stable and the results
+      # more consistent.
       # Compute the sigmoid binary cross entropy for the class maps.
       class_loss = tf.reduce_mean(
           loss_utils.sigmoid_bce(
@@ -395,6 +400,9 @@ class DarknetLoss(YoloLossBase):
       class_loss = tf.cast(
           tf.reduce_sum(class_loss, axis=(1, 2, 3)), dtype=y_pred.dtype)
     else:
+      # Computes the loss while keeping the structure as a list in order to
+      # ensure all objects are considered. In some cases this can make
+      # training less stable, but it may also yield higher APs.
       pred_class = loss_utils.apply_mask(
           ind_mask, tf.gather_nd(pred_class, inds, batch_dims=1))
       class_loss = tf.keras.losses.binary_crossentropy(
@@ -405,8 +413,9 @@ class DarknetLoss(YoloLossBase):
       class_loss = loss_utils.apply_mask(ind_mask, class_loss)
       class_loss = math_ops.divide_no_nan(
           class_loss, tf.expand_dims(reps, axis=-1))
-      class_loss = tf.cast(tf.reduce_sum(
-          class_loss, axis=(1, 2)), dtype=y_pred.dtype)
+      class_loss = tf.cast(
+          tf.reduce_sum(class_loss, axis=(1, 2)), dtype=y_pred.dtype)
+      class_loss *= self._cls_normalizer

     # Compute the sigmoid binary cross entropy for the confidence maps.
     bce = tf.reduce_mean(
@@ -421,7 +430,7 @@ class DarknetLoss(YoloLossBase):
     # Apply the weights to each loss.
     box_loss *= self._iou_normalizer
-    conf_loss *= self._obj_normalizer
+    conf_loss *= self._object_normalizer

     # Add all the losses together then take the mean over the batches.
     loss = box_loss + class_loss + conf_loss
@@ -505,7 +514,7 @@ class ScaledLoss(YoloLossBase):
     # Scale and shift and select the ground truth boxes
     # and predictions to the prediction domain.
-    if self._box_type == 'anchor_free':
+    if self._box_type == "anchor_free":
       true_box = loss_utils.apply_mask(ind_mask,
                                        (scale * self._path_stride * true_box))
     else:
@@ -562,7 +571,7 @@ class ScaledLoss(YoloLossBase):
     # Apply the weights to each loss.
     box_loss *= self._iou_normalizer
     class_loss *= self._cls_normalizer
-    conf_loss *= self._obj_normalizer
+    conf_loss *= self._object_normalizer

     # Add all the losses together then take the sum over the batches.
     mean_loss = box_loss + class_loss + conf_loss
@@ -590,15 +599,15 @@ class ScaledLoss(YoloLossBase):
       predictions: `Dict` holding all the predicted values.

     Returns:
       loss: `tf.float` scalar for the scaled loss.
+      scale: `tf.float` how much the loss was scaled by.
     """
     scale = tf.stop_gradient(3 / len(list(predictions.keys())))
-    return loss * scale
+    return loss * scale, 1 / scale

   def cross_replica_aggregation(self, loss, num_replicas_in_sync):
-    """this method is not specific to each loss path, but each loss type."""
+    """This method is not specific to each loss path, but each loss type."""
     return loss


 class YoloLoss:
   """This class implements the aggregated loss across YOLO model FPN levels."""
@@ -612,7 +621,7 @@ class YoloLoss:
                loss_types=None,
                iou_normalizers=None,
                cls_normalizers=None,
-               obj_normalizers=None,
+               object_normalizers=None,
                objectness_smooths=None,
                box_types=None,
                scale_xys=None,
@@ -642,7 +651,7 @@ class YoloLoss:
         or the boxes for each FPN path.
       cls_normalizers: `Dict[float]` for how much to scale the loss on the
         classes for each FPN path.
-      obj_normalizers: `Dict[float]` for how much to scale loss on the detection
+      object_normalizers: `Dict[float]` for how much to scale loss on the detection
         map for each FPN path.
       objectness_smooths: `Dict[float]` for how much to smooth the loss on the
         detection map for each FPN path.
@@ -670,7 +679,7 @@ class YoloLoss:
       loss_type = 'scaled'
     else:
       loss_type = 'darknet'

     self._loss_dict = {}
     for key in keys:
       self._loss_dict[key] = losses[loss_type](
@@ -681,7 +690,7 @@ class YoloLoss:
           loss_type=loss_types[key],
           iou_normalizer=iou_normalizers[key],
           cls_normalizer=cls_normalizers[key],
-          obj_normalizer=obj_normalizers[key],
+          object_normalizer=object_normalizers[key],
           box_type=box_types[key],
           objectness_smooth=objectness_smooths[key],
           max_delta=max_deltas[key],
@@ -710,10 +719,11 @@ class YoloLoss:
       # after computing the loss, scale loss as needed for aggregation
       # across FPN levels
-      loss = self._loss_dict[key].post_path_aggregation(loss, loss_box,
-                                                        loss_conf, loss_class,
-                                                        ground_truth,
-                                                        predictions)
+      loss, scale = self._loss_dict[key].post_path_aggregation(loss, loss_box,
+                                                               loss_conf,
+                                                               loss_class,
+                                                               ground_truth,
+                                                               predictions)

       # after completing the scaling of the loss on each replica, handle
       # scaling the loss for merging the loss across replicas
@@ -723,12 +733,13 @@ class YoloLoss:
       # detach all the below gradients: none of them should make a
       # contribution to the gradient from this point forward
-      metric_loss += tf.stop_gradient(mean_loss)
-      metric_dict[key]['loss'] = tf.stop_gradient(mean_loss)
+      metric_loss += tf.stop_gradient(mean_loss / scale)
+      metric_dict[key]['loss'] = tf.stop_gradient(mean_loss / scale)
       metric_dict[key]['avg_iou'] = tf.stop_gradient(avg_iou)
       metric_dict[key]['avg_obj'] = tf.stop_gradient(avg_obj)
-      metric_dict['net']['box'] += tf.stop_gradient(loss_box)
-      metric_dict['net']['class'] += tf.stop_gradient(loss_class)
-      metric_dict['net']['conf'] += tf.stop_gradient(loss_conf)
+      metric_dict['net']['box'] += tf.stop_gradient(loss_box / scale)
+      metric_dict['net']['class'] += tf.stop_gradient(loss_class / scale)
+      metric_dict['net']['conf'] += tf.stop_gradient(loss_conf / scale)

     return loss_val, metric_loss, metric_dict
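Note: because the second returned value is the reciprocal of the applied factor, dividing the detached metrics by it multiplies them by the same factor the training loss received, so the logged numbers track the loss the optimizer actually sees. A toy trace with hypothetical values, following my reading of the two hunks above:

```python
import tensorflow as tf

factor = 3 / 5                 # scaled-loss factor for five FPN paths
mean_loss = tf.constant(10.0)  # per-path loss before scaling (hypothetical)

loss = mean_loss * factor      # first return value: 6.0
scale = 1 / factor             # second return value: 5/3
logged = tf.stop_gradient(mean_loss / scale)  # 6.0, matches the scaled loss
```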
@@ -60,7 +60,7 @@ class YoloDecoderTest(parameterized.TestCase, tf.test.TestCase):
         loss_types={key: 'ciou' for key in keys},
         iou_normalizers={key: 0.05 for key in keys},
         cls_normalizers={key: 0.5 for key in keys},
-        obj_normalizers={key: 1.0 for key in keys},
+        object_normalizers={key: 1.0 for key in keys},
         objectness_smooths={key: 1.0 for key in keys},
         box_types={key: 'scaled' for key in keys},
         scale_xys={key: 2.0 for key in keys},
@@ -43,7 +43,7 @@ def build_yolo_detection_generator(model_config: yolo.Yolo, anchor_boxes):
       max_delta=model_config.loss.max_delta.get(),
       iou_normalizer=model_config.loss.iou_normalizer.get(),
       cls_normalizer=model_config.loss.cls_normalizer.get(),
-      obj_normalizer=model_config.loss.obj_normalizer.get(),
+      object_normalizer=model_config.loss.object_normalizer.get(),
       ignore_thresh=model_config.loss.ignore_thresh.get(),
       objectness_smooth=model_config.loss.objectness_smooth.get())
   return model
@@ -36,7 +36,7 @@ class YoloLayer(tf.keras.Model):
                loss_type='ciou',
                iou_normalizer=1.0,
                cls_normalizer=1.0,
-               obj_normalizer=1.0,
+               object_normalizer=1.0,
                use_scaled_loss=False,
                update_on_repeat=False,
                pre_nms_points=5000,
@@ -67,7 +67,7 @@ class YoloLayer(tf.keras.Model):
       iou_normalizer: `float` for how much to scale the loss on the IOU or the
         boxes.
       cls_normalizer: `float` for how much to scale the loss on the classes.
-      obj_normalizer: `float` for how much to scale loss on the detection map.
+      object_normalizer: `float` for how much to scale loss on the detection map.
       use_scaled_loss: `bool` for whether to use the scaled loss
         or the traditional loss.
       update_on_repeat: `bool` indicating how you would like to handle repeated
@@ -110,7 +110,7 @@ class YoloLayer(tf.keras.Model):
     self._truth_thresh = truth_thresh
     self._iou_normalizer = iou_normalizer
     self._cls_normalizer = cls_normalizer
-    self._obj_normalizer = obj_normalizer
+    self._object_normalizer = object_normalizer
     self._objectness_smooth = objectness_smooth
     self._nms_thresh = nms_thresh
     self._max_boxes = max_boxes
@@ -289,7 +289,7 @@ class YoloLayer(tf.keras.Model):
         loss_types=self._loss_type,
         iou_normalizers=self._iou_normalizer,
         cls_normalizers=self._cls_normalizer,
-        obj_normalizers=self._obj_normalizer,
+        object_normalizers=self._object_normalizer,
         objectness_smooths=self._objectness_smooth,
         box_types=self._box_type,
         max_deltas=self._max_delta,
@@ -325,6 +325,7 @@ class Mosaic:
     return self._add_param(noop)

   def _beta(self, alpha, beta):
+    """Generates a random number using the beta distribution."""
     a = tf.random.gamma([], alpha)
     b = tf.random.gamma([], beta)
     return b / (a + b)
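Note: the new docstring is accurate up to argument order. The classic construction says that if `A ~ Gamma(alpha, 1)` and `B ~ Gamma(beta, 1)`, then `A / (A + B) ~ Beta(alpha, beta)`; returning `b / (a + b)` therefore samples `Beta(beta, alpha)`. The two agree in the symmetric `alpha == beta` case typically used for mosaic mixing. A quick empirical check (illustration only):

```python
import tensorflow as tf

alpha, beta = 2.0, 5.0
a = tf.random.gamma([100000], alpha)
b = tf.random.gamma([100000], beta)
samples = b / (a + b)

# Beta(alpha, beta) has mean alpha / (alpha + beta) ~= 0.286; the flipped
# construction above instead centers on beta / (alpha + beta) ~= 0.714.
print(float(tf.reduce_mean(samples)))
```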
@@ -24,7 +24,7 @@ from official.core import config_definitions
 from official.modeling import performance
 from official.vision.beta.ops import box_ops
 from official.vision.beta.evaluation import coco_evaluator
-from official.vision.beta.dataloaders import tfds_detection_decoders
+from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.dataloaders import tf_example_label_map_decoder
 from official.vision.beta.projects.yolo import optimization
@@ -54,6 +54,7 @@ class YoloTask(base_task.Task):
     self.coco_metric = None
     self._loss_fn = None
     self._model = None
+    self._coco_91_to_80 = False
     self._metrics = []

     # globally set the random seed
@@ -79,17 +80,14 @@ class YoloTask(base_task.Task):
     self._model = model
     return model

-  def get_decoder(self, params):
+  def _get_data_decoder(self, params):
     """Get a decoder object to decode the dataset."""
     if params.tfds_name:
-      if params.tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
-        decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[
-            params.tfds_name]()
-      else:
-        raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
+      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
     else:
       decoder_cfg = params.decoder.get()
       if params.decoder.type == 'simple_decoder':
+        self._coco_91_to_80 = decoder_cfg.coco91_to_80
         decoder = tf_example_decoder.TfExampleDecoder(
             coco91_to_80=decoder_cfg.coco91_to_80,
             regenerate_source_id=decoder_cfg.regenerate_source_id)
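Note: the hand-rolled registry lookup is replaced by a factory call; judging from the removed branch, the factory now owns the "unsupported dataset" error. A usage sketch (the dataset name is illustrative; supported names are whatever the factory's registry defines):

```python
from official.vision.beta.dataloaders import tfds_factory

# Raises inside the factory if the TFDS name has no registered decoder.
decoder = tfds_factory.get_detection_decoder('coco/2017')
```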
@@ -123,7 +121,7 @@ class YoloTask(base_task.Task):
     )

     # get the decoder
-    decoder = self.get_decoder(params)
+    decoder = self._get_data_decoder(params)

     # init Mosaic
     sample_fn = mosaic.Mosaic(
metric_names['net'].append('conf') metric_names['net'].append('conf')
for i, key in enumerate(metric_names.keys()): for i, key in enumerate(metric_names.keys()):
metrics.append(ListMetrics(metric_names[key], name=key)) metrics.append(_ListMetrics(metric_names[key], name=key))
self._metrics = metrics self._metrics = metrics
if not training: if not training:
annotation_file = self.task_config.annotation_file
if self._coco_91_to_80:
annotation_file = None
self.coco_metric = coco_evaluator.COCOEvaluator( self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self.task_config.annotation_file, annotation_file=annotation_file,
include_mask=False, include_mask=False,
need_rescale_bboxes=False, need_rescale_bboxes=False,
per_category_metrics=self._task_config.per_category_metrics) per_category_metrics=self._task_config.per_category_metrics)
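Note: the reasoning as I read it is that a prebuilt COCO annotation file is indexed by the original 91 raw class ids, so once the decoder remaps labels to the contiguous 80-class set those ids no longer line up with the predictions. Passing `annotation_file=None` makes the evaluator build its ground truth from the decoded data pipeline instead (an assumption about `COCOEvaluator`'s fallback behavior):

```python
# Condensed from the hunk above; _coco_91_to_80 was recorded while
# building the decoder.
annotation_file = task_config.annotation_file
if coco_91_to_80:
  annotation_file = None  # evaluator derives ground truth from the dataset
```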
@@ -239,11 +240,6 @@ class YoloTask(base_task.Task):
       if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         gradients = optimizer.get_unscaled_gradients(gradients)

-      # Clip the gradients
-      if self.task_config.gradient_clip_norm > 0.0:
-        gradients, _ = tf.clip_by_global_norm(gradients,
-                                              self.task_config.gradient_clip_norm)
-
       # Apply gradients to the model
       optimizer.apply_gradients(zip(gradients, train_vars))
       logs = {self.loss: metric_loss}
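Note: with the manual `tf.clip_by_global_norm` call removed from the train step, clipping presumably moves to the optimizer configuration. If so, the built-in Keras option would give equivalent behavior (a sketch under that assumption, not this project's actual config):

```python
import tensorflow as tf

# global_clipnorm clips by the global norm across all gradients, matching
# the semantics of the removed tf.clip_by_global_norm call.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, global_clipnorm=10.0)
```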
@@ -407,7 +403,8 @@ class YoloTask(base_task.Task):
     return optimizer


-class ListMetrics:
+class _ListMetrics:
+  """Private class used to cleanly store the metric values for each level."""

   def __init__(self, metric_names, name="ListMetrics", **kwargs):
     self.name = name
@@ -27,6 +27,8 @@ class YoloTaskTest(tf.test.TestCase, parameterized.TestCase):
     config.trainer.optimizer_config.ema = None
     config.task.train_data.global_batch_size = 1
     config.task.validation_data.global_batch_size = 1
+    config.task.train_data.shuffle_buffer_size = 1
+    config.task.validation_data.shuffle_buffer_size = 1
     task = yolo.YoloTask(config.task)
     model = task.build_model()