Commit c5ae4110 authored by A. Unique TensorFlower, committed by saberkun

Internal change

PiperOrigin-RevId: 398593113
parent 6ca5ac92
@@ -14,7 +14,7 @@
"""A common dataset reader."""
import random
from typing import Any, Callable, List, Optional, Union, Dict, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
from absl import logging
import tensorflow as tf
@@ -34,6 +34,154 @@ def _maybe_map_fn(dataset: tf.data.Dataset,
      fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
'"a,b,c", or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(input_path, str):
input_path_list = [input_path]
elif isinstance(input_path, (list, tuple)):
if any(not isinstance(x, str) for x in input_path):
raise ValueError(usage % input_path)
input_path_list = input_path
else:
raise ValueError(usage % input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
matched_files.extend(tmp_matched_files)
else:
matched_files.append(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_path)
return matched_files
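The new module-level `match_files` accepts all three documented input forms. A minimal usage sketch (the paths below are hypothetical):

  files = match_files('/data/train-*.tfrecord')                        # (1) single glob pattern
  files = match_files('/data/part-0.tfrecord, /data/part-1.tfrecord')  # (2) comma-separated paths
  files = match_files(['/data/a-*.tfrecord', '/data/b-*.tfrecord'])    # (3) list of patterns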
def _read_files_then_shard(matched_files: List[str],
dataset_fn,
input_context: Optional[
tf.distribute.InputContext] = None,
sharding: bool = False,
repeat: bool = False) -> tf.data.Dataset:
"""Sends all data files to every worker and then shard by data."""
dataset = dataset_fn(matched_files)
# When `input_file` is a path to a single file or the number of files is
# less than the number of input pipelines, disable auto sharding
# so that same input file is sent to all workers.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if sharding and input_context and (input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if repeat:
dataset = dataset.repeat()
return dataset
def _shard_files_then_read(matched_files: List[str],
dataset_fn,
input_context: Optional[
tf.distribute.InputContext] = None,
seed: Optional[Union[int, tf.Tensor]] = None,
is_training: bool = False,
sharding: bool = False,
cache: bool = False,
cycle_length: Optional[int] = None,
block_length: Optional[int] = None,
deterministic: bool = False) -> tf.data.Dataset:
"""Shards the data files and then sent a split to every worker to read."""
dataset = tf.data.Dataset.from_tensor_slices(matched_files)
# Shuffle and repeat at file level.
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway.
if is_training:
# We need a seed to shuffle the files so that when each TPU workers gets
# its own shard the files do not overlap.
if sharding and seed is None:
seed = _get_random_integer()
dataset = dataset.shuffle(
len(matched_files),
seed=seed,
reshuffle_each_iteration=True if not cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if sharding and input_context and (input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if is_training and not cache:
dataset = dataset.repeat()
dataset = dataset.interleave(
map_func=dataset_fn,
cycle_length=cycle_length,
block_length=block_length,
num_parallel_calls=(cycle_length
if cycle_length else tf.data.experimental.AUTOTUNE),
deterministic=deterministic)
return dataset
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
tfds_split: Text,
tfds_skip_decoding_feature: Text,
tfds_as_supervised: bool,
input_context: Optional[tf.distribute.InputContext] = None,
seed: Optional[Union[int, tf.Tensor]] = None,
is_training: bool = False,
cache: bool = False,
cycle_length: Optional[int] = None,
block_length: Optional[int] = None) -> tf.data.Dataset:
"""Reads a dataset from tfds."""
# No op if exist.
tfds_builder.download_and_prepare()
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=input_context,
shuffle_seed=seed)
decoders = {}
if tfds_skip_decoding_feature:
for skip_feature in tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = tfds_builder.as_dataset(
split=tfds_split,
shuffle_files=is_training,
as_supervised=tfds_as_supervised,
decoders=decoders,
read_config=read_config)
if is_training and not cache:
dataset = dataset.repeat()
return dataset
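Taken together, the helpers above compose roughly as sketched below; paths and parameter values are hypothetical, and `tf.data.TFRecordDataset` stands in for any `dataset_fn`:

  files = match_files('/data/train-*.tfrecord')
  if len(files) == 1:
    # A single file is sent to every worker and sharded by data.
    ds = _read_files_then_shard(files, tf.data.TFRecordDataset, repeat=True)
  else:
    # Many files are sharded by file, then interleaved while reading.
    ds = _shard_files_then_read(
        files, tf.data.TFRecordDataset, is_training=True, cycle_length=8)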
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""
@@ -90,16 +238,7 @@ class InputReader:
    self._tfds_builder = None
    self._matched_files = None
    if params.input_path:
    if not params.input_path:
# we want to combine / mix datasets
if isinstance(params.input_path, cfg.base_config.Config):
self._matched_files = {}
for k, v in params.input_path.as_dict().items():
self._matched_files[k] = self._match_files(v)
# single dataset
else:
self._matched_files = self._match_files(params.input_path)
else:
      # Read dataset from TFDS.
      if not params.tfds_split:
        raise ValueError(
@@ -107,6 +246,8 @@ class InputReader:
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)
else:
self._matched_files = self.get_files(params.input_path)
    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
@@ -149,145 +290,6 @@ class InputReader:
    self._enable_round_robin_tf_data_service = params.get(
        'enable_round_robin_tf_data_service', False)
def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
'"a,b,c", or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(input_path, str):
input_path_list = [input_path]
elif isinstance(input_path, (list, tuple)):
if any(not isinstance(x, str) for x in input_path):
raise ValueError(usage % input_path)
input_path_list = input_path
else:
raise ValueError(usage % input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
matched_files.extend(tmp_matched_files)
else:
matched_files.append(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_path)
return matched_files
def _shard_files_then_read(
self,
matched_files: List[str],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Shards the data files and then sent a split to every worker to read."""
dataset = tf.data.Dataset.from_tensor_slices(matched_files)
# Shuffle and repeat at file level.
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway.
if self._is_training:
# We need a seed to shuffle the files so that when each TPU workers gets
# its own shard the files do not overlap.
if self._sharding and self._seed is None:
seed = _get_random_integer()
else:
seed = self._seed
dataset = dataset.shuffle(
len(matched_files),
seed=seed,
reshuffle_each_iteration=True if not self._cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (input_context.num_input_pipelines >
1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.repeat()
dataset = dataset.interleave(
map_func=dataset_fn,
cycle_length=self._cycle_length,
block_length=self._block_length,
num_parallel_calls=(self._cycle_length if self._cycle_length else
tf.data.experimental.AUTOTUNE),
deterministic=self._deterministic)
return dataset
def _read_files_then_shard(
self,
matched_files: List[str],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Sends all data files to every worker and then shard by data."""
dataset = dataset_fn(matched_files)
# When `input_file` is a path to a single file or the number of files is
# less than the number of input pipelines, disable auto sharding
# so that same input file is sent to all workers.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (input_context.num_input_pipelines >
1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.repeat()
return dataset
def _read_tfds(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Reads a dataset from tfds."""
# No op if exist.
self._tfds_builder.download_and_prepare()
read_config = tfds.ReadConfig(
interleave_cycle_length=self._cycle_length,
interleave_block_length=self._block_length,
input_context=input_context,
shuffle_seed=self._seed)
decoders = {}
if self._tfds_skip_decoding_feature:
for skip_feature in self._tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = self._tfds_builder.as_dataset(
split=self._tfds_split,
shuffle_files=self._is_training,
as_supervised=self._tfds_as_supervised,
decoders=decoders,
read_config=read_config)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.repeat()
return dataset
  @property
  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available."""
@@ -297,14 +299,27 @@ class InputReader:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')
  def get_files(self, input_path):
    """Gets matched files. Can be overridden by subclasses."""
    if not input_path:
      return None
    # we want to combine / mix datasets
    if isinstance(input_path, cfg.base_config.Config):
      matched_files = {}
      for k, v in input_path.as_dict().items():
        matched_files[k] = match_files(v)
    # single dataset
    else:
      matched_files = match_files(input_path)
    return matched_files
  def _read_decode_and_parse_dataset(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None,
      tfds_builder: bool = False) -> tf.data.Dataset:
    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
  def _read_data_source(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      input_context: Optional[tf.distribute.InputContext] = None,
      tfds_builder: Optional[tfds.core.DatasetBuilder] = None):
    """Reads the data source (files/tfds) to a dataset."""
    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
      if len(files) > 1:
@@ -314,15 +329,66 @@ class InputReader:
            '%d. We will send all input files to every worker. '
            'Please consider sharding your data into more files.', len(files),
            input_context.num_input_pipelines)
        return self._read_files_then_shard(files, dataset_fn, input_context)
        return _read_files_then_shard(
files,
dataset_fn,
input_context,
sharding=self._sharding,
repeat=self._is_training and not self._cache)
      else:
        return self._shard_files_then_read(files, dataset_fn, input_context)
        return _shard_files_then_read(
files,
dataset_fn,
input_context,
seed=self._seed,
is_training=self._is_training,
sharding=self._sharding,
cache=self._cache,
cycle_length=self._cycle_length,
block_length=self._block_length,
deterministic=self._deterministic)
      elif len(files) == 1:
        return self._read_files_then_shard(files, dataset_fn, input_context)
        return _read_files_then_shard(
files,
dataset_fn,
input_context,
sharding=self._sharding,
repeat=self._is_training and not self._cache)
      else:
        raise ValueError('It is unexpected that `tfds_builder` is None and '
                         'there is also no `files`.')
if tfds_builder:
dataset = _read_tfds(
tfds_builder=self._tfds_builder,
tfds_split=self._tfds_split,
tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
tfds_as_supervised=self._tfds_as_supervised,
input_context=input_context,
seed=self._seed,
is_training=self._is_training,
cache=self._cache,
cycle_length=self._cycle_length,
block_length=self._block_length)
elif isinstance(matched_files, (list, tuple)):
dataset = _files_to_dataset(matched_files)
elif isinstance(matched_files, dict):
dataset = {}
for k, fs in matched_files.items():
dataset[k] = _files_to_dataset(fs)
else:
raise ValueError('`matched_files` should be a list or dict.')
return dataset
def _decode_and_parse_dataset(
self,
dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
batch_size: int,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""
    def _shuffle_and_decode(ds):
      # If cache is enabled, we will call `shuffle()` later after `cache()`.
      if self._is_training and not self._cache:
@@ -331,20 +397,9 @@ class InputReader:
      ds = _maybe_map_fn(ds, self._decoder_fn)
      return ds
    dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
    if tf.nest.is_nested(dataset):
      dataset = self._combine_fn(dataset)
    if tfds_builder:
      dataset = self._read_tfds(input_context)
      dataset = _shuffle_and_decode(dataset)
elif isinstance(matched_files, (list, tuple)):
dataset = _files_to_dataset(matched_files)
dataset = _shuffle_and_decode(dataset)
elif isinstance(matched_files, dict):
datasets = {}
for k, fs in matched_files.items():
datasets[k] = _files_to_dataset(fs)
datasets[k] = _shuffle_and_decode(datasets[k])
dataset = self._combine_fn(datasets)
else:
raise ValueError('`matched_files` should be a list or dict.')
    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
@@ -403,16 +458,16 @@ class InputReader:
              job_name=self._tf_data_service_job_name))
    return dataset
  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None,
           dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object.""" """Generates a tf.data.Dataset object."""
dataset = self._read_decode_and_parse_dataset(self._matched_files, if dataset is None:
self._dataset_fn, dataset = self._read_data_source(
self._global_batch_size, self._matched_files, self._dataset_fn, input_context,
input_context,
self._tfds_builder) self._tfds_builder)
dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
input_context)
    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
    dataset = self._maybe_apply_data_service(dataset, input_context)
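A hedged sketch of the updated read() entry point; the params object and the decoder/parser callables are hypothetical, and the new optional `dataset` argument lets a caller pass in a pre-built dataset that skips the file/TFDS reading step:

  reader = InputReader(params, dataset_fn=tf.data.TFRecordDataset,
                       decoder_fn=my_decoder, parser_fn=my_parser)  # hypothetical callables
  ds = reader.read(input_context=None)           # reads from matched files or TFDS
  ds = reader.read(dataset=my_prebuilt_dataset)  # decodes/parses a caller-provided dataset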
......
@@ -204,6 +204,7 @@ class EncoderConfig(hyperparams.OneOfConfig):
  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
  kernel: KernelEncoderConfig = KernelEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
teams: BertEncoderConfig = BertEncoderConfig()
  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()
@@ -436,6 +437,40 @@ def build_encoder(config: EncoderConfig,
        initializer=tf.keras.initializers.RandomNormal(
            stddev=encoder_cfg.initializer_range))
if encoder_type == "teams":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
embedding_width=encoder_cfg.embedding_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate,
)
embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
embedding_cls=embedding_network,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
return networks.EncoderScaffold(**kwargs)
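A hedged sketch of exercising the new branch; the import path and the override values are assumptions, not part of the diff:

  from official.nlp.configs import encoders  # assumed import path

  config = encoders.EncoderConfig(type='teams')  # one-of selects the BertEncoderConfig schema
  config.teams.num_layers = 12                   # hypothetical overrides
  config.teams.embedding_size = 128
  encoder = encoders.build_encoder(config)       # returns a networks.EncoderScaffold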
  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return networks.BertEncoder(
......
@@ -61,7 +61,6 @@ def bert_sentence_prediction() -> cfg.ExperimentConfig:
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
config.task.model.encoder.type = 'bert'
  return config
@@ -98,7 +97,6 @@ def bert_squad() -> cfg.ExperimentConfig:
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
config.task.model.encoder.type = 'bert'
  return config
......
@@ -14,15 +14,16 @@
"""Funnel Transformer network."""
# pylint: disable=g-classes-have-attributes
from typing import Union, Collection
from typing import Union, Sequence
from absl import logging
import tensorflow as tf
from official.nlp import keras_nlp
def _pool_and_concat(data, unpool_length: int, stride: int,
                     axes: Union[Collection[int], int]):
def _pool_and_concat(data, unpool_length: int, strides: Union[Sequence[int],
                                                              int],
                     axes: Union[Sequence[int], int]):
"""Pools the data along a given axis with stride. """Pools the data along a given axis with stride.
It also skips first unpool_length elements. It also skips first unpool_length elements.
...@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int, ...@@ -30,7 +31,7 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
Args: Args:
data: Tensor to be pooled. data: Tensor to be pooled.
unpool_length: Leading elements to be skipped. unpool_length: Leading elements to be skipped.
stride: Stride for the given axis. strides: Strides for the given axes.
    axes: Axes to pool the Tensor.
  Returns:
@@ -39,8 +40,13 @@ def _pool_and_concat(data, unpool_length: int, stride: int,
  # Wraps the axes as a list.
  if isinstance(axes, int):
    axes = [axes]
if isinstance(strides, int):
strides = [strides] * len(axes)
else:
if len(strides) != len(axes):
raise ValueError('The lengths of strides and axes need to match.')
  for axis in axes:
  for axis, stride in zip(axes, strides):
    # Skips first `unpool_length` tokens.
    unpool_tensor_shape = [slice(None)] * axis + [slice(None, unpool_length)]
    unpool_tensor = data[unpool_tensor_shape]
@@ -80,7 +86,9 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
      dropout.
    attention_dropout: The dropout rate to use for the attention layers within
      the transformer layers.
    pool_stride: Pooling stride to compress the sequence length.
    pool_stride: An int or a list of ints. Pooling stride(s) to compress the
sequence length. If set to int, each layer will have the same stride size.
If set to list, the number of elements needs to match num_layers.
    unpool_length: Leading n tokens to be skipped from pooling.
    initializer: The initialzer to use for all weights in this encoder.
    output_range: The sequence output range, [0, output_range), by slicing the
@@ -185,12 +193,23 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')
    self._att_input_pool_layer = tf.keras.layers.MaxPooling1D(
        pool_size=pool_stride,
        strides=pool_stride,
        padding='same',
        name='att_input_pool_layer')
    self._pool_stride = pool_stride
    if isinstance(pool_stride, int):
      # TODO(b/197133196): Pooling layer can be shared.
      pool_strides = [pool_stride] * num_layers
    else:
      if len(pool_stride) != num_layers:
        raise ValueError('Lengths of pool_stride and num_layers are not equal.')
      pool_strides = pool_stride
    self._att_input_pool_layers = []
    for layer_pool_stride in pool_strides:
      att_input_pool_layer = tf.keras.layers.MaxPooling1D(
          pool_size=layer_pool_stride,
          strides=layer_pool_stride,
          padding='same',
          name='att_input_pool_layer')
      self._att_input_pool_layers.append(att_input_pool_layer)
    self._pool_strides = pool_strides  # This is a list here.
    self._unpool_length = unpool_length
    self._config = {
@@ -250,11 +269,12 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
    attention_mask = _pool_and_concat(
        attention_mask,
        unpool_length=self._unpool_length,
        stride=self._pool_stride,
        strides=self._pool_strides[0],
        axes=[1])
    for layer in self._transformer_layers:
    for i, layer in enumerate(self._transformer_layers):
      # Pools layer for compressing the query length.
      pooled_inputs = self._att_input_pool_layer(x[:, self._unpool_length:, :])
      pooled_inputs = self._att_input_pool_layers[i](
x[:, self._unpool_length:, :])
      query_inputs = tf.concat(
          values=(tf.cast(
              x[:, :self._unpool_length, :],
@@ -262,10 +282,11 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
          axis=1)
      x = layer([query_inputs, x, attention_mask])
      # Pools the corresponding attention_mask.
if i < len(self._transformer_layers) - 1:
        attention_mask = _pool_and_concat(
            attention_mask,
            unpool_length=self._unpool_length,
            stride=self._pool_stride,
            strides=[self._pool_strides[i+1], self._pool_strides[i]],
            axes=[1, 2])
      encoder_outputs.append(x)
......
@@ -80,8 +80,24 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
    self.assertAllEqual(tf.float32, data.dtype)
    self.assertAllEqual(pooled_dtype, pooled.dtype)
def test_invalid_stride_and_num_layers(self):
hidden_size = 32
num_layers = 3
pool_stride = [2, 2]
unpool_length = 1
with self.assertRaisesRegex(ValueError,
"pool_stride and num_layers are not equal"):
_ = funnel_transformer.FunnelTransformerEncoder(
vocab_size=100,
hidden_size=hidden_size,
num_attention_heads=2,
num_layers=num_layers,
pool_stride=pool_stride,
unpool_length=unpool_length)
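A minimal construction sketch for the new per-layer stride support; the values mirror the test above, and the import path is an assumption:

  from official.nlp.modeling.networks import funnel_transformer  # assumed path

  encoder = funnel_transformer.FunnelTransformerEncoder(
      vocab_size=100,
      hidden_size=32,
      num_attention_heads=2,
      num_layers=3,
      pool_stride=[2, 3, 4],  # one stride per layer; length must equal num_layers
      unpool_length=1)
  # A plain int still works and is broadcast to every layer:
  encoder_uniform = funnel_transformer.FunnelTransformerEncoder(
      vocab_size=100,
      hidden_size=32,
      num_attention_heads=2,
      num_layers=3,
      pool_stride=2,
      unpool_length=1)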
  @parameterized.named_parameters(
      ("no_stride_no_unpool", 1, 0),
("stride_list_with_unpool", [2, 3, 4], 1),
("large_stride_with_unpool", 3, 1), ("large_stride_with_unpool", 3, 1),
("large_stride_with_large_unpool", 5, 10), ("large_stride_with_large_unpool", 5, 10),
("no_stride_with_unpool", 1, 1), ("no_stride_with_unpool", 1, 1),
...@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -110,11 +126,12 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_data_shape = [None, sequence_length, hidden_size] expected_data_shape = [None, sequence_length, hidden_size]
expected_pooled_shape = [None, hidden_size] expected_pooled_shape = [None, hidden_size]
self.assertLen(all_encoder_outputs, num_layers) self.assertLen(all_encoder_outputs, num_layers)
    for data in all_encoder_outputs:
      expected_data_shape[1] = unpool_length + (expected_data_shape[1] +
                                                pool_stride - 1 -
                                                unpool_length) // pool_stride
      print("shapes:", expected_data_shape, data.shape.as_list())
    if isinstance(pool_stride, int):
      pool_stride = [pool_stride] * num_layers
    for layer_pool_stride, data in zip(pool_stride, all_encoder_outputs):
      expected_data_shape[1] = unpool_length + (
          expected_data_shape[1] + layer_pool_stride - 1 -
          unpool_length) // layer_pool_stride
      self.assertAllEqual(expected_data_shape, data.shape.as_list())
    self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
......
@@ -62,6 +62,8 @@ class PackedSequenceEmbedding(tf.keras.Model):
               pack_multiple_sequences=False,
               **kwargs):
    initializer = tf.keras.initializers.get(initializer)
if embedding_width is None:
embedding_width = hidden_size
    config_dict = {
        'vocab_size': vocab_size,
        'type_vocab_size': type_vocab_size,
......
task:
  model:
    encoder:
      bert:
      teams:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        embedding_size: 768
@@ -14,3 +14,4 @@ task:
        num_layers: 12
        type_vocab_size: 2
        vocab_size: 30522
type: teams
task:
  model:
    encoder:
      bert:
      teams:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        embedding_size: 128
@@ -14,3 +14,4 @@ task:
        num_layers: 12
        type_vocab_size: 2
        vocab_size: 30522
type: teams
@@ -64,9 +64,6 @@ def get_encoder(bert_config,
  Returns:
    A encoder object.
  """
# embedding_size is required for PackedSequenceEmbedding.
if bert_config.embedding_size is None:
bert_config.embedding_size = bert_config.hidden_size
  embedding_cfg = dict(
      vocab_size=bert_config.vocab_size,
      type_vocab_size=bert_config.type_vocab_size,
......
@@ -21,6 +21,7 @@ pyyaml>=5.1
opencv-python-headless
Pillow
pycocotools
waymo-open-dataset-tf-2-6-0
# NLP related dependencies
seqeval
sentencepiece
......
@@ -208,6 +208,10 @@ class MaskRCNNTask(cfg.TaskConfig):
  per_category_metrics: bool = False
  # If set, we only use masks for the specified class IDs.
  allowed_mask_class_ids: Optional[List[int]] = None
# If set, the COCO metrics will be computed.
use_coco_metrics: bool = True
# If set, the Waymo Open Dataset evaluator would be used.
use_wod_metrics: bool = False
COCO_INPUT_PATH_BASE = 'coco'
......
@@ -113,7 +113,7 @@ class CombinationDatasetInputReader(input_reader.InputReader):
    self._pseudo_label_file_pattern = params.pseudo_label_data.input_path
    self._pseudo_label_dataset_fn = pseudo_label_dataset_fn
    self._pseudo_label_data_ratio = params.pseudo_label_data.data_ratio
    self._pseudo_label_matched_files = self._match_files(
    self._pseudo_label_matched_files = input_reader.match_files(
        self._pseudo_label_file_pattern)
    if not self._drop_remainder:
      raise ValueError(
@@ -134,14 +134,20 @@ class CombinationDatasetInputReader(input_reader.InputReader):
          'resulting in a 0 batch size for one of the datasets.'.format(
              self._global_batch_size, self._pseudo_label_data_ratio))
    labeled_dataset = self._read_decode_and_parse_dataset(
    def _read_decode_and_parse_dataset(matched_files, dataset_fn, batch_size,
input_context, tfds_builder):
dataset = self._read_data_source(matched_files, dataset_fn, input_context,
tfds_builder)
return self._decode_and_parse_dataset(dataset, batch_size, input_context)
labeled_dataset = _read_decode_and_parse_dataset(
        matched_files=self._matched_files,
        dataset_fn=self._dataset_fn,
        batch_size=labeled_batch_size,
        input_context=input_context,
        tfds_builder=self._tfds_builder)
    pseudo_labeled_dataset = self._read_decode_and_parse_dataset(
    pseudo_labeled_dataset = _read_decode_and_parse_dataset(
        matched_files=self._pseudo_label_matched_files,
        dataset_fn=self._pseudo_label_dataset_fn,
        batch_size=pl_batch_size,
......
@@ -331,7 +331,7 @@ class Parser(parser.Parser):
        'source_id': data['source_id'],
        'height': data['height'],
        'width': data['width'],
        'num_detections': tf.shape(data['groundtruth_classes']),
        'num_detections': tf.shape(data['groundtruth_classes'])[0],
        'boxes': boxes,
        'classes': data['groundtruth_classes'],
        'areas': data['groundtruth_area'],
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""2D detection evaluator for the Waymo Open Dataset."""
import pprint
from absl import logging
import tensorflow as tf
from official.vision.beta.ops import box_ops
from waymo_open_dataset import label_pb2
from waymo_open_dataset.metrics.python import wod_detection_evaluator
from waymo_open_dataset.protos import breakdown_pb2
from waymo_open_dataset.protos import metrics_pb2
def get_2d_detection_default_config():
"""Returns the config proto for WOD 2D detection Evaluation."""
config = metrics_pb2.Config()
config.breakdown_generator_ids.append(breakdown_pb2.Breakdown.OBJECT_TYPE)
difficulty = config.difficulties.add()
difficulty.levels.append(label_pb2.Label.LEVEL_1)
difficulty.levels.append(label_pb2.Label.LEVEL_2)
config.breakdown_generator_ids.append(breakdown_pb2.Breakdown.ALL_BUT_SIGN)
difficulty = config.difficulties.add()
difficulty.levels.append(label_pb2.Label.LEVEL_1)
difficulty.levels.append(label_pb2.Label.LEVEL_2)
config.matcher_type = metrics_pb2.MatcherProto.TYPE_HUNGARIAN
config.iou_thresholds.append(0.0)
config.iou_thresholds.append(0.7)
config.iou_thresholds.append(0.5)
config.iou_thresholds.append(0.5)
config.iou_thresholds.append(0.5)
config.box_type = label_pb2.Label.Box.TYPE_2D
for i in range(100):
config.score_cutoffs.append(i * 0.01)
config.score_cutoffs.append(1.0)
return config
class WOD2dDetectionEvaluator(wod_detection_evaluator.WODDetectionEvaluator):
"""WOD 2D detection evaluation metric class."""
def __init__(self, config=None):
if config is None:
config = get_2d_detection_default_config()
super().__init__(config=config)
def _remove_padding(self, tensor_dict, num_valid):
"""Remove the paddings of the prediction/groundtruth data."""
result_tensor_dict = {}
gather_indices = tf.range(num_valid)
for k, v in tensor_dict.items():
if 'frame_id' in k:
result_tensor_dict[k] = tf.tile([v], [num_valid])
else:
result_tensor_dict[k] = tf.gather(v, gather_indices)
return result_tensor_dict
def update_state(self, groundtruths, predictions):
"""Update the metrics state with prediction and groundtruth data.
Args:
groundtruths: a dictionary of Tensors including the fields below.
Required fields:
- source_id: a numpy array of int or string of shape [batch_size].
- num_detections: a numpy array of int of shape [batch_size].
- boxes: a numpy array of float of shape [batch_size, K, 4].
- classes: a numpy array of int of shape [batch_size, K].
- difficulties: a numpy array of int of shape [batch_size, K].
predictions: a dictionary of tensors including the fields below.
Required fields:
- source_id: a numpy array of int or string of shape [batch_size].
- image_info: a numpy array of float of shape [batch_size, 4, 2].
- num_detections: a numpy array of int of shape [batch_size].
- detection_boxes: a numpy array of float of shape [batch_size, K, 4].
- detection_classes: a numpy array of int of shape [batch_size, K].
- detection_scores: a numpy array of float of shape [batch_size, K].
"""
# Preprocess potentially aggregated tensors.
for k, v in groundtruths.items():
if isinstance(v, tuple):
groundtruths[k] = tf.concat(v, axis=0)
for k, v in predictions.items():
if isinstance(v, tuple):
predictions[k] = tf.concat(v, axis=0)
# Change cyclists' type id from 3 to 4, where 3 is reserved for sign.
groundtruth_type = tf.cast(groundtruths['classes'], tf.uint8)
groundtruth_type = tf.where(
tf.equal(groundtruth_type, 3),
tf.ones_like(groundtruth_type) * 4, groundtruth_type)
prediction_type = tf.cast(predictions['detection_classes'], tf.uint8)
prediction_type = tf.where(
tf.equal(prediction_type, 3),
tf.ones_like(prediction_type) * 4, prediction_type)
# Rescale the detection boxes back to original scale.
image_scale = tf.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2))
prediction_bbox = predictions['detection_boxes'] / image_scale
batch_size = tf.shape(groundtruths['source_id'])[0]
for i in tf.range(batch_size):
frame_groundtruths = {
'ground_truth_frame_id':
groundtruths['source_id'][i],
'ground_truth_bbox':
box_ops.yxyx_to_cycxhw(
tf.cast(groundtruths['boxes'][i], tf.float32)),
'ground_truth_type':
groundtruth_type[i],
'ground_truth_difficulty':
tf.cast(groundtruths['difficulties'][i], tf.uint8),
}
frame_groundtruths = self._remove_padding(
frame_groundtruths, groundtruths['num_detections'][i])
frame_predictions = {
'prediction_frame_id':
groundtruths['source_id'][i],
'prediction_bbox':
box_ops.yxyx_to_cycxhw(
tf.cast(prediction_bbox[i], tf.float32)),
'prediction_type':
prediction_type[i],
'prediction_score':
tf.cast(predictions['detection_scores'][i], tf.float32),
'prediction_overlap_nlz':
tf.zeros_like(predictions['detection_scores'][i], dtype=tf.bool)
}
frame_predictions = self._remove_padding(frame_predictions,
predictions['num_detections'][i])
super().update_state(frame_groundtruths, frame_predictions)
def evaluate(self):
"""Compute the final metrics."""
ap, _, _, _, _ = super().evaluate()
metric_dict = {}
for i, name in enumerate(self._breakdown_names):
# Skip sign metrics in 2d detection task.
if 'SIGN' in name:
continue
metric_dict['WOD metrics/{}/AP'.format(name)] = ap[i]
pp = pprint.PrettyPrinter()
logging.info('WOD Detection Metrics: \n %s', pp.pformat(metric_dict))
return metric_dict
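A rough usage sketch of the new evaluator; the groundtruth and prediction dicts are hypothetical and must follow the shapes documented in update_state above:

  evaluator = WOD2dDetectionEvaluator()              # defaults to get_2d_detection_default_config()
  evaluator.update_state(groundtruths, predictions)  # one call per evaluation batch
  metrics = evaluator.evaluate()                     # dict of 'WOD metrics/<breakdown>/AP' values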
@@ -228,13 +228,17 @@ class NASFPN(tf.keras.Model):
    if input_level < target_level:
      stride = int(2 ** (target_level - input_level))
      x = tf.keras.layers.MaxPool2D(
      return tf.keras.layers.MaxPool2D(
          pool_size=stride, strides=stride, padding='same')(x)
    elif input_level > target_level:
    if input_level > target_level:
      scale = int(2 ** (input_level - target_level))
      x = spatial_transform_ops.nearest_upsampling(x, scale=scale)
      return spatial_transform_ops.nearest_upsampling(x, scale=scale)
    return x
    # Force output x to be the same dtype as mixed precision policy. This avoids
# dtype mismatch when one input (by default float32 dtype) does not meet all
# the above conditions and is output unchanged, while other inputs are
# processed to have different dtype, e.g., using bfloat16 on TPU.
return tf.cast(x, dtype=tf.keras.layers.Layer().dtype_policy.compute_dtype)
  def _global_attention(self, feat0, feat1):
    m = tf.math.reduce_max(feat0, axis=[1, 2], keepdims=True)
......
@@ -155,6 +155,8 @@ class SegmentationHead(tf.keras.layers.Layer):
              depthwise_initializer=random_initializer,
              depthwise_regularizer=self._config_dict['kernel_regularizer'],
              depth_multiplier=1))
norm_name = 'segmentation_head_depthwise_norm_{}'.format(i)
self._norms.append(bn_op(name=norm_name, **bn_kwargs))
      conv_name = 'segmentation_head_conv_{}'.format(i)
      self._convs.append(
          conv_op(
......
@@ -613,56 +613,14 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
    }
    super(MultilevelDetectionGenerator, self).__init__(**kwargs)
  def __call__(self,
  def _decode_multilevel_outputs(
      self,
      raw_boxes: Mapping[str, tf.Tensor],
      raw_scores: Mapping[str, tf.Tensor],
      anchor_boxes: tf.Tensor,
      image_shape: tf.Tensor,
      raw_attributes: Optional[Mapping[str, tf.Tensor]] = None):
    """Collects dict of multilevel boxes, scores, attributes into lists."""
    """Generates final detections.
Args:
raw_boxes: A `dict` with keys representing FPN levels and values
representing box tenors of shape `[batch, feature_h, feature_w,
num_anchors * 4]`.
raw_scores: A `dict` with keys representing FPN levels and values
representing logit tensors of shape `[batch, feature_h, feature_w,
num_anchors]`.
anchor_boxes: A `tf.Tensor` of shape of [batch_size, K, 4] representing
the corresponding anchor boxes w.r.t `box_outputs`.
image_shape: A `tf.Tensor` of shape of [batch_size, 2] storing the image
height and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`.
raw_attributes: If not None, a `dict` of (attribute_name,
attribute_prediction) pairs. `attribute_prediction` is a dict that
contains keys representing FPN levels and values representing tenors of
shape `[batch, feature_h, feature_w, num_anchors * attribute_size]`.
Returns:
If `apply_nms` = True, the return is a dictionary with keys:
`detection_boxes`: A `float` tf.Tensor of shape
[batch, max_num_detections, 4] representing top detected boxes in
[y1, x1, y2, x2].
`detection_scores`: A `float` tf.Tensor of shape
[batch, max_num_detections] representing sorted confidence scores for
detected boxes. The values are between [0, 1].
`detection_classes`: An `int` tf.Tensor of shape
[batch, max_num_detections] representing classes for detected boxes.
`num_detections`: An `int` tf.Tensor of shape [batch] only the first
`num_detections` boxes are valid detections
`detection_attributes`: A dict. Values of the dict is a `float`
tf.Tensor of shape [batch, max_num_detections, attribute_size]
representing attribute predictions for detected boxes.
If `apply_nms` = False, the return is a dictionary with keys:
`decoded_boxes`: A `float` tf.Tensor of shape [batch, num_raw_boxes, 4]
representing all the decoded boxes.
`decoded_box_scores`: A `float` tf.Tensor of shape
[batch, num_raw_boxes] representing socres of all the decoded boxes.
`decoded_box_attributes`: A dict. Values in the dict is a
`float` tf.Tensor of shape [batch, num_raw_boxes, attribute_size]
representing attribute predictions of all the decoded boxes.
"""
# Collects outputs from all levels into a list.
    boxes = []
    scores = []
    if raw_attributes:
@@ -728,6 +686,60 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
      attributes[att_name] = tf.concat(attributes[att_name], axis=1)
      attributes[att_name] = tf.expand_dims(attributes[att_name], axis=2)
return boxes, scores, attributes
def __call__(self,
raw_boxes: Mapping[str, tf.Tensor],
raw_scores: Mapping[str, tf.Tensor],
anchor_boxes: tf.Tensor,
image_shape: tf.Tensor,
raw_attributes: Optional[Mapping[str, tf.Tensor]] = None):
"""Generates final detections.
Args:
raw_boxes: A `dict` with keys representing FPN levels and values
representing box tenors of shape `[batch, feature_h, feature_w,
num_anchors * 4]`.
raw_scores: A `dict` with keys representing FPN levels and values
representing logit tensors of shape `[batch, feature_h, feature_w,
num_anchors]`.
anchor_boxes: A `tf.Tensor` of shape of [batch_size, K, 4] representing
the corresponding anchor boxes w.r.t `box_outputs`.
image_shape: A `tf.Tensor` of shape of [batch_size, 2] storing the image
height and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`.
raw_attributes: If not None, a `dict` of (attribute_name,
attribute_prediction) pairs. `attribute_prediction` is a dict that
contains keys representing FPN levels and values representing tenors of
shape `[batch, feature_h, feature_w, num_anchors * attribute_size]`.
Returns:
If `apply_nms` = True, the return is a dictionary with keys:
`detection_boxes`: A `float` tf.Tensor of shape
[batch, max_num_detections, 4] representing top detected boxes in
[y1, x1, y2, x2].
`detection_scores`: A `float` tf.Tensor of shape
[batch, max_num_detections] representing sorted confidence scores for
detected boxes. The values are between [0, 1].
`detection_classes`: An `int` tf.Tensor of shape
[batch, max_num_detections] representing classes for detected boxes.
`num_detections`: An `int` tf.Tensor of shape [batch] only the first
`num_detections` boxes are valid detections
`detection_attributes`: A dict. Values of the dict is a `float`
tf.Tensor of shape [batch, max_num_detections, attribute_size]
representing attribute predictions for detected boxes.
If `apply_nms` = False, the return is a dictionary with keys:
`decoded_boxes`: A `float` tf.Tensor of shape [batch, num_raw_boxes, 4]
representing all the decoded boxes.
`decoded_box_scores`: A `float` tf.Tensor of shape
[batch, num_raw_boxes] representing socres of all the decoded boxes.
`decoded_box_attributes`: A dict. Values in the dict is a
`float` tf.Tensor of shape [batch, num_raw_boxes, attribute_size]
representing attribute predictions of all the decoded boxes.
"""
boxes, scores, attributes = self._decode_multilevel_outputs(
raw_boxes, raw_scores, anchor_boxes, image_shape, raw_attributes)
    if not self._config_dict['apply_nms']:
      return {
          'decoded_boxes': boxes,
......
@@ -133,3 +133,4 @@ class DiceScore:
    if self._per_class_metric:
      for class_id in range(self._num_classes):
        self._dice_scores_per_class[class_id] = tf.Variable(0.0)
self._count_per_class[class_id] = tf.Variable(0.0)
@@ -342,4 +342,6 @@ class SemanticSegmentation3DTask(base_task.Task):
        metric_name = self.metrics[0].name + '/class_{0}'.format(
            i - 1) if i > 0 else self.metrics[0].name
        result.update({metric_name: metric_val})
else:
result.update({self.metrics[0].name: metric})
    return result
@@ -21,7 +21,7 @@ import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision.beta import configs
from official.vision.beta.tasks import image_classification as img_cls_task
from official.vision.beta import tasks
def create_representative_dataset(
@@ -39,7 +39,13 @@ def create_representative_dataset(
  """
  if isinstance(params.task,
                configs.image_classification.ImageClassificationTask):
task = img_cls_task.ImageClassificationTask(params.task)
task = tasks.image_classification.ImageClassificationTask(params.task)
elif isinstance(params.task, configs.retinanet.RetinaNetTask):
task = tasks.retinanet.RetinaNetTask(params.task)
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
task = tasks.semantic_segmentation.SemanticSegmentationTask(params.task)
  else:
    raise ValueError('Task {} not supported.'.format(type(params.task)))
  # Ensure batch size is 1 for TFLite model.
......