"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "79470486acb053936495e0d61fd74a534417933d"
Commit 406c65ab authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 358864651
parent e5eca36a
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset reader for vision model garden."""
from typing import Any, Callable, Optional, Tuple
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import input_reader
def calculate_batch_sizes(total_batch_size: int,
pseudo_label_ratio: float) -> Tuple[int, int]:
"""Calculates labeled and pseudo-labeled dataset batch sizes.
Returns (labeled_batch_size, pseudo_labeled_batch_size) given a
total batch size and pseudo-label data ratio.
Args:
total_batch_size: The total batch size for all data.
pseudo_label_ratio: A non-negative float ratio of pseudo-labeled
to labeled data in a batch.
Returns:
(labeled_batch_size, pseudo_labeled_batch_size) as ints.
Raises:
ValueError: If total_batch_size is negative.
ValueError: If pseudo_label_ratio is negative.
"""
if total_batch_size < 0:
raise ValueError('Invalid total_batch_size: {}'.format(total_batch_size))
if pseudo_label_ratio < 0.0:
raise ValueError(
'Invalid pseudo_label_ratio: {}'.format(pseudo_label_ratio))
ratio_factor = pseudo_label_ratio / (1.0 + pseudo_label_ratio)
pseudo_labeled_batch_size = int(round(total_batch_size * ratio_factor))
labeled_batch_size = total_batch_size - pseudo_labeled_batch_size
return labeled_batch_size, pseudo_labeled_batch_size
class CombinationDatasetInputReader(input_reader.InputReader):
"""Combination dataset input reader."""
def __init__(self,
params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset,
pseudo_label_dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
sample_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
transform_and_batch_fn: Optional[Callable[
[tf.data.Dataset, Optional[tf.distribute.InputContext]],
tf.data.Dataset]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None):
"""Initializes an CombinationDatasetInputReader instance.
This class mixes a labeled and pseudo-labeled dataset. The params
must contain "pseudo_label_data.input_path" to specify the
pseudo-label dataset files and "pseudo_label_data.data_ratio"
to specify a per-batch mixing ratio of pseudo-label examples to
labeled dataset examples.
Args:
params: A config_definitions.DataConfig object.
dataset_fn: A `tf.data.Dataset` that consumes the input files. For
example, it can be `tf.data.TFRecordDataset`.
pseudo_label_dataset_fn: A `tf.data.Dataset` that consumes the input
files. For example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
and decodes them into the raw tensor dictionary.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parse them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
transform_and_batch_fn: An optional `callable` that takes a
`tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
input, and returns a `tf.data.Dataset` object. It will be executed after
`parser_fn` to transform and batch the dataset; if None, after
`parser_fn` is executed, the dataset will be batched into per-replica
batch size.
postprocess_fn: A optional `callable` that processes batched tensors. It
will be executed after batching.
Raises:
ValueError: If drop_remainder is False.
"""
super().__init__(params=params,
dataset_fn=dataset_fn,
decoder_fn=decoder_fn,
sample_fn=sample_fn,
parser_fn=parser_fn,
transform_and_batch_fn=transform_and_batch_fn,
postprocess_fn=postprocess_fn)
self._pseudo_label_file_pattern = params.pseudo_label_data.input_path
self._pseudo_label_dataset_fn = pseudo_label_dataset_fn
self._pseudo_label_data_ratio = params.pseudo_label_data.data_ratio
self._pseudo_label_matched_files = self._match_files(
self._pseudo_label_file_pattern)
if not self._drop_remainder:
raise ValueError(
'Must use drop_remainder=True with CombinationDatasetInputReader')
def read(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
labeled_batch_size, pl_batch_size = calculate_batch_sizes(
self._global_batch_size, self._pseudo_label_data_ratio)
if not labeled_batch_size and pl_batch_size:
raise ValueError(
'Invalid batch_size: {} and pseudo_label_data_ratio: {}, '
'resulting in a 0 batch size for one of the datasets.'.format(
self._global_batch_size, self._pseudo_label_data_ratio))
labeled_dataset = self._read_decode_and_parse_dataset(
matched_files=self._matched_files,
dataset_fn=self._dataset_fn,
batch_size=labeled_batch_size,
input_context=input_context,
tfds_builder=self._tfds_builder)
pseudo_labeled_dataset = self._read_decode_and_parse_dataset(
matched_files=self._pseudo_label_matched_files,
dataset_fn=self._pseudo_label_dataset_fn,
batch_size=pl_batch_size,
input_context=input_context,
tfds_builder=False)
def concat_fn(d1, d2):
return tf.nest.map_structure(
lambda x1, x2: tf.concat([x1, x2], axis=0), d1, d2)
dataset_concat = tf.data.Dataset.zip(
(labeled_dataset, pseudo_labeled_dataset))
dataset_concat = dataset_concat.map(
concat_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def maybe_map_fn(dataset, fn):
return dataset if fn is None else dataset.map(
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset_concat = maybe_map_fn(dataset_concat, self._postprocess_fn)
dataset_concat = self._maybe_apply_data_service(dataset_concat,
input_context)
if self._deterministic is not None:
options = tf.data.Options()
options.experimental_deterministic = self._deterministic
dataset_concat = dataset_concat.with_options(options)
return dataset_concat.prefetch(tf.data.experimental.AUTOTUNE)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment