"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "2c63b5cd2eeaf66c3a45e7c65da41d16fb8838ca"
Commit 6fb46d26 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405520991
parent d09e811f
......@@ -26,10 +26,10 @@ from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.configs import decoders
from official.vision.beta.configs import semantic_segmentation as base_cfg
from official.vision.beta.configs.google import backbones
@dataclasses.dataclass
......
......@@ -18,15 +18,16 @@
import dataclasses
import os
from typing import Any, List, Optional, Mapping
from typing import Any, List, Mapping, Optional
# Import libraries
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import semantic_segmentation as base_cfg
from official.vision.beta.configs.google import backbones
# ADE 20K Dataset
ADE20K_TRAIN_EXAMPLES = 20210
......
......@@ -28,7 +28,6 @@ from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v1_model
from official.projects.edgetpu.vision.modeling import mobilenet_edgetpu_v2_model
from official.vision.beta.configs import image_classification as base_cfg
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders.google import tfds_classification_decoders
def get_models() -> Mapping[str, tf.keras.Model]:
......@@ -141,11 +140,7 @@ class EdgeTPUTask(base_task.Task):
is_multilabel = self.task_config.train_data.is_multilabel
if params.tfds_name:
if params.tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[
params.tfds_name]()
else:
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
raise ValueError('TFDS {} is not supported'.format(params.tfds_name))
else:
decoder = classification_input.Decoder(
image_field_key=image_field_key, label_field_key=label_field_key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment