Commit 4ae6c23b authored by vishnubanna's avatar vishnubanna
Browse files

tfds decoder test

parent c96825a4
......@@ -19,7 +19,8 @@ from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.projects.yolo.configs import darknet_classification as exp_cfg
from official.vision.beta.configs import image_classification as exp_cfg
from official.vision.beta.projects.yolo.dataloaders import classification_input as cli
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.modeling import factory
......@@ -52,7 +53,13 @@ class ImageClassificationTask(base_task.Task):
num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size
if params.tfds_name != None:
tf.print("i am here for training using tfds")
decoder = cli.Decoder()
else:
tf.print("i am here for regular input")
decoder = classification_input.Decoder()
parser = classification_input.Parser(
output_size=input_size[:2],
num_classes=num_classes,
......@@ -201,3 +208,5 @@ class ImageClassificationTask(base_task.Task):
def inference_step(self, inputs, model):
"""Performs the forward step."""
return model(inputs, training=False)
......@@ -20,7 +20,6 @@ from official.core import input_reader
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.configs import image_classification as exp_cfg
from official.vision.beta.projects.yolo.dataloaders import classification_input as cli
from official.vision.beta.dataloaders import classification_input
from official.vision.beta.modeling import factory
......@@ -53,11 +52,6 @@ class ImageClassificationTask(base_task.Task):
num_classes = self.task_config.model.num_classes
input_size = self.task_config.model.input_size
if params.tfds_name != None:
tf.print("i am here for training using tfds")
decoder = cli.Decoder()
else:
tf.print("i am here for regular input")
decoder = classification_input.Decoder()
parser = classification_input.Parser(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment