Commit dfa1a569 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

model task

parent cac293d6
import tensorflow as tf
class ListMetrics:
"""Private class used to cleanly place the matric values for each level."""
def __init__(self, metric_names, name="ListMetrics", **kwargs):
self.name = name
self._metric_names = metric_names
self._metrics = self.build_metric()
return
def build_metric(self):
metric_names = self._metric_names
metrics = []
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
return metrics
def update_state(self, loss_metrics):
metrics = self._metrics
for m in metrics:
m.update_state(loss_metrics[m.name])
return
def result(self):
logs = dict()
metrics = self._metrics
for m in metrics:
logs.update({m.name: m.result()})
return logs
def reset_states(self):
metrics = self._metrics
for m in metrics:
m.reset_states()
return
\ No newline at end of file
...@@ -33,6 +33,7 @@ from official.vision.beta.projects.yolo.ops import preprocessing_ops ...@@ -33,6 +33,7 @@ from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.projects.yolo.dataloaders import yolo_input from official.vision.beta.projects.yolo.dataloaders import yolo_input
from official.vision.beta.projects.yolo.dataloaders import tf_example_decoder from official.vision.beta.projects.yolo.dataloaders import tf_example_decoder
from official.vision.beta.projects.yolo.configs import yolo as exp_cfg from official.vision.beta.projects.yolo.configs import yolo as exp_cfg
from official.vision.beta.projects.yolo.tasks import task_utils
import tensorflow as tf import tensorflow as tf
from typing import Optional from typing import Optional
...@@ -185,7 +186,7 @@ class YoloTask(base_task.Task): ...@@ -185,7 +186,7 @@ class YoloTask(base_task.Task):
metric_names['net'].append('conf') metric_names['net'].append('conf')
for i, key in enumerate(metric_names.keys()): for i, key in enumerate(metric_names.keys()):
metrics.append(_ListMetrics(metric_names[key], name=key)) metrics.append(task_utils.ListMetrics(metric_names[key], name=key))
self._metrics = metrics self._metrics = metrics
if not training: if not training:
...@@ -396,39 +397,3 @@ class YoloTask(base_task.Task): ...@@ -396,39 +397,3 @@ class YoloTask(base_task.Task):
loss_scale=runtime_config.loss_scale) loss_scale=runtime_config.loss_scale)
return optimizer return optimizer
\ No newline at end of file
class _ListMetrics:
"""Private class used to cleanly place the matric values for each level."""
def __init__(self, metric_names, name="ListMetrics", **kwargs):
self.name = name
self._metric_names = metric_names
self._metrics = self.build_metric()
return
def build_metric(self):
metric_names = self._metric_names
metrics = []
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
return metrics
def update_state(self, loss_metrics):
metrics = self._metrics
for m in metrics:
m.update_state(loss_metrics[m.name])
return
def result(self):
logs = dict()
metrics = self._metrics
for m in metrics:
logs.update({m.name: m.result()})
return logs
def reset_states(self):
metrics = self._metrics
for m in metrics:
m.reset_states()
return
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment