Commit 6fb46d26 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405520991
parent d09e811f
...@@ -26,10 +26,10 @@ from official.core import exp_factory ...@@ -26,10 +26,10 @@ from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common from official.vision.beta.configs import common
from official.vision.beta.configs import decoders from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation as base_cfg from official.vision.beta.configs import semantic_segmentation as base_cfg
from official.vision.beta.configs.google import backbones
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -18,15 +18,16 @@ ...@@ -18,15 +18,16 @@
import dataclasses import dataclasses
import os import os
from typing import Any, List, Optional, Mapping from typing import Any, List, Mapping, Optional
# Import libraries # Import libraries
from official.core import config_definitions as cfg from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import semantic_segmentation as base_cfg from official.vision.beta.configs import semantic_segmentation as base_cfg
from official.vision.beta.configs.google import backbones
# ADE 20K Dataset # ADE 20K Dataset
ADE20K_TRAIN_EXAMPLES = 20210 ADE20K_TRAIN_EXAMPLES = 20210
......
...@@ -28,7 +28,6 @@ from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model ...@@ -28,7 +28,6 @@ from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model
from official.vision.beta.configs import image_classification as base_cfg from official.vision.beta.configs import image_classification as base_cfg
from official.vision.beta.dataloaders import input_reader_factory from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders.google import tfds_classification_decoders
def get_models() -> Mapping[str, tf.keras.Model]: def get_models() -> Mapping[str, tf.keras.Model]:
...@@ -141,11 +140,7 @@ class EdgeTPUTask(base_task.Task): ...@@ -141,11 +140,7 @@ class EdgeTPUTask(base_task.Task):
is_multilabel = self.task_config.train_data.is_multilabel is_multilabel = self.task_config.train_data.is_multilabel
if params.tfds_name: if params.tfds_name:
if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP: raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else: else:
decoder = classification_input.Decoder( decoder = classification_input.Decoder(
image_field_key=image_field_key, label_field_key=label_field_key, image_field_key=image_field_key, label_field_key=label_field_key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment