Unverified Commit 982f457a authored by Ruslan Baratov's avatar Ruslan Baratov Committed by GitHub
Browse files

DeepLab: Cityscapes splits refactoring (#8870)

* DeepLab: Cityscapes splits refactoring

* DeepLab: Process 'test_fine' split in Cityscapes
parent a8eae76f
...@@ -113,17 +113,23 @@ def _get_files(data, dataset_split): ...@@ -113,17 +113,23 @@ def _get_files(data, dataset_split):
Args: Args:
data: String, desired data ('image' or 'label'). data: String, desired data ('image' or 'label').
dataset_split: String, dataset split ('train', 'val', 'test') dataset_split: String, dataset split ('train_fine', 'val_fine', 'test_fine')
Returns: Returns:
A list of sorted file names or None when getting label for A list of sorted file names or None when getting label for
test set. test set.
""" """
if data == 'label' and dataset_split == 'test': if dataset_split == 'train_fine':
return None split_dir = 'train'
elif dataset_split == 'val_fine':
split_dir = 'val'
elif dataset_split == 'test_fine':
split_dir = 'test'
else:
raise RuntimeError("Split {} is not supported".format(dataset_split))
pattern = '*%s.%s' % (_POSTFIX_MAP[data], _DATA_FORMAT_MAP[data]) pattern = '*%s.%s' % (_POSTFIX_MAP[data], _DATA_FORMAT_MAP[data])
search_files = os.path.join( search_files = os.path.join(
FLAGS.cityscapes_root, _FOLDERS_MAP[data], dataset_split, '*', pattern) FLAGS.cityscapes_root, _FOLDERS_MAP[data], split_dir, '*', pattern)
filenames = glob.glob(search_files) filenames = glob.glob(search_files)
return sorted(filenames) return sorted(filenames)
...@@ -132,7 +138,7 @@ def _convert_dataset(dataset_split): ...@@ -132,7 +138,7 @@ def _convert_dataset(dataset_split):
"""Converts the specified dataset split to TFRecord format. """Converts the specified dataset split to TFRecord format.
Args: Args:
dataset_split: The dataset split (e.g., train, val). dataset_split: The dataset split (e.g., train_fine, val_fine).
Raises: Raises:
RuntimeError: If loaded image and label have different shape, or if the RuntimeError: If loaded image and label have different shape, or if the
...@@ -152,7 +158,7 @@ def _convert_dataset(dataset_split): ...@@ -152,7 +158,7 @@ def _convert_dataset(dataset_split):
label_reader = build_data.ImageReader('png', channels=1) label_reader = build_data.ImageReader('png', channels=1)
for shard_id in range(_NUM_SHARDS): for shard_id in range(_NUM_SHARDS):
shard_filename = '%s_fine-%05d-of-%05d.tfrecord' % ( shard_filename = '%s-%05d-of-%05d.tfrecord' % (
dataset_split, shard_id, _NUM_SHARDS) dataset_split, shard_id, _NUM_SHARDS)
output_filename = os.path.join(FLAGS.output_dir, shard_filename) output_filename = os.path.join(FLAGS.output_dir, shard_filename)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
...@@ -183,8 +189,8 @@ def _convert_dataset(dataset_split): ...@@ -183,8 +189,8 @@ def _convert_dataset(dataset_split):
def main(unused_argv): def main(unused_argv):
# Only support converting 'train' and 'val' sets for now. # Only support converting 'train_fine', 'val_fine' and 'test_fine' sets for now.
for dataset_split in ['train', 'val']: for dataset_split in ['train_fine', 'val_fine', 'test_fine']:
_convert_dataset(dataset_split) _convert_dataset(dataset_split)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment