"csrc/git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "feca30d12428ac88bd5cdaaa91eaad312ebc9e45"
Commit d4149bca authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 350255525
parent 23280eaa
...@@ -33,6 +33,8 @@ from official.vision.beta.configs import decoders ...@@ -33,6 +33,8 @@ from official.vision.beta.configs import decoders
class DataConfig(cfg.DataConfig): class DataConfig(cfg.DataConfig):
"""Input config for training.""" """Input config for training."""
output_size: List[int] = dataclasses.field(default_factory=list) output_size: List[int] = dataclasses.field(default_factory=list)
# If train_on_crops is set to True, a patch of size output_size is cropped
# from the input image.
train_on_crops: bool = False train_on_crops: bool = False
input_path: str = '' input_path: str = ''
global_batch_size: int = 0 global_batch_size: int = 0
...@@ -40,12 +42,16 @@ class DataConfig(cfg.DataConfig): ...@@ -40,12 +42,16 @@ class DataConfig(cfg.DataConfig):
dtype: str = 'float32' dtype: str = 'float32'
shuffle_buffer_size: int = 1000 shuffle_buffer_size: int = 1000
cycle_length: int = 10 cycle_length: int = 10
# If resize_eval_groundtruth is set to False, original image sizes are used
# for eval. In that case, groundtruth_padded_size has to be specified too to
# allow for batching the variable input sizes of images.
resize_eval_groundtruth: bool = True resize_eval_groundtruth: bool = True
groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list) groundtruth_padded_size: List[int] = dataclasses.field(default_factory=list)
aug_scale_min: float = 1.0 aug_scale_min: float = 1.0
aug_scale_max: float = 1.0 aug_scale_max: float = 1.0
aug_rand_hflip: bool = True aug_rand_hflip: bool = True
drop_remainder: bool = True drop_remainder: bool = True
file_type: str = 'tfrecord'  # tfrecord, or sstable
@dataclasses.dataclass @dataclasses.dataclass
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility library for picking an appropriate dataset function."""
from typing import Any, Callable, Union, Type
import tensorflow as tf
PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
  """Returns a dataset-constructing function for the given file type.

  Args:
    file_type: String identifying the input file format. Both `'tf_record'`
      and `'tfrecord'` are accepted, since config code elsewhere uses the
      `'tfrecord'` spelling.

  Returns:
    `tf.data.TFRecordDataset`, suitable for passing as the `dataset_fn`
    argument of an input reader.

  Raises:
    ValueError: If `file_type` is not a recognized file format.
  """
  # Normalize nothing else: only the two known TFRecord spellings are valid.
  if file_type in ('tf_record', 'tfrecord'):
    return tf.data.TFRecordDataset
  raise ValueError('Unrecognized file_type: {}'.format(file_type))
...@@ -83,7 +83,7 @@ class SegmentationLoss: ...@@ -83,7 +83,7 @@ class SegmentationLoss:
top_k_losses, _ = tf.math.top_k( top_k_losses, _ = tf.math.top_k(
cross_entropy_loss, k=top_k_pixels, sorted=True) cross_entropy_loss, k=top_k_pixels, sorted=True)
normalizer = tf.reduce_sum( normalizer = tf.reduce_sum(
tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32) + EPSILON) tf.cast(tf.not_equal(top_k_losses, 0.0), tf.float32)) + EPSILON
loss = tf.reduce_sum(top_k_losses) / normalizer loss = tf.reduce_sum(top_k_losses) / normalizer
return loss return loss
...@@ -23,6 +23,7 @@ from official.core import input_reader ...@@ -23,6 +23,7 @@ from official.core import input_reader
from official.core import task_factory from official.core import task_factory
from official.vision.beta.configs import semantic_segmentation as exp_cfg from official.vision.beta.configs import semantic_segmentation as exp_cfg
from official.vision.beta.dataloaders import segmentation_input from official.vision.beta.dataloaders import segmentation_input
from official.vision.beta.dataloaders import dataset_fn
from official.vision.beta.evaluation import segmentation_metrics from official.vision.beta.evaluation import segmentation_metrics
from official.vision.beta.losses import segmentation_losses from official.vision.beta.losses import segmentation_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -97,7 +98,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -97,7 +98,7 @@ class SemanticSegmentationTask(base_task.Task):
reader = input_reader.InputReader( reader = input_reader.InputReader(
params, params,
dataset_fn=tf.data.TFRecordDataset, dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode, decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training)) parser_fn=parser.parse_fn(params.is_training))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment