Commit 3c22c16d authored by Abdullah Rashwan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 410610733
parent 763bee61
@@ -247,14 +247,12 @@ class Trainer(_AsyncTrainer):
     self._validation_loss = tf.keras.metrics.Mean(
         "validation_loss", dtype=tf.float32)
     model_metrics = model.metrics if hasattr(model, "metrics") else []
-    self._train_metrics = self.task.build_metrics(
-        training=True) + model_metrics
-    self._validation_metrics = self.task.build_metrics(
-        training=False) + model_metrics
     self.init_async()
 
     if train:
+      self._train_metrics = self.task.build_metrics(
+          training=True) + model_metrics
       train_dataset = train_dataset or self.distribute_dataset(
           self.task.build_inputs, self.config.task.train_data)
       orbit.StandardTrainer.__init__(
@@ -266,6 +264,8 @@ class Trainer(_AsyncTrainer):
           use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
 
     if evaluate:
+      self._validation_metrics = self.task.build_metrics(
+          training=False) + model_metrics
       validation_dataset = validation_dataset or self.distribute_dataset(
           self.task.build_inputs, self.config.task.validation_data)
       orbit.StandardEvaluator.__init__(
@@ -403,10 +403,7 @@ class Trainer(_AsyncTrainer):
     """See base class."""
 
     def step_fn(inputs):
-      if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
-        task_train_step = tf.function(self.task.train_step, jit_compile=True)
-      else:
-        task_train_step = self.task.train_step
+      task_train_step = self.task.train_step
       logs = task_train_step(
           inputs,
           model=self.model,
@@ -23,8 +23,9 @@ from official.vision.beta.dataloaders import decoder
 
 def _generate_source_id(image_bytes):
+  # Hashing using 22 bits since float32 has only 23 mantissa bits.
   return tf.strings.as_string(
-      tf.strings.to_hash_bucket_fast(image_bytes, 2 ** 63 - 1))
+      tf.strings.to_hash_bucket_fast(image_bytes, 2 ** 22 - 1))
 
 
 class TfExampleDecoder(decoder.Decoder):
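Note: the new bucket count caps regenerated source IDs at 22 bits so they survive the float32 round trips elsewhere in the pipeline. A minimal, standalone sketch of the same hashing (the byte string below is a made-up placeholder):

import tensorflow as tf

# Hash raw image bytes into at most 22 bits; float32 stores 23 mantissa bits,
# so IDs this small remain exact after a float32 round trip.
image_bytes = tf.constant(b'fake image payload')
source_id = tf.strings.as_string(
    tf.strings.to_hash_bucket_fast(image_bytes, 2**22 - 1))
print(source_id)  # a decimal string in [0, 2**22 - 2]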
@@ -31,7 +31,7 @@ def process_source_id(source_id: tf.Tensor) -> tf.Tensor:
     A formatted source ID.
   """
   if source_id.dtype == tf.string:
-    source_id = tf.cast(tf.strings.to_number(source_id), tf.int64)
+    source_id = tf.strings.to_number(source_id, tf.int64)
   with tf.control_dependencies([source_id]):
     source_id = tf.cond(
         pred=tf.equal(tf.size(input=source_id), 0),
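Note: passing the integer out_type to tf.strings.to_number parses the ID exactly, whereas the old code parsed to the default float32 first and then cast, which can round large IDs. A small self-contained illustration (the example ID is arbitrary):

import tensorflow as tf

sid = tf.constant('123456789')
# Old behaviour: parse to float32 (the default), then cast to int64; the
# float32 step can round IDs with more than about 7 significant digits.
lossy = tf.cast(tf.strings.to_number(sid), tf.int64)
# New behaviour: parse the string directly as int64.
exact = tf.strings.to_number(sid, tf.int64)
print(lossy.numpy(), exact.numpy())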
@@ -212,6 +212,8 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
   gt_annotations = []
   num_batches = len(groundtruths['source_id'])
   for i in range(num_batches):
+    logging.info(
+        'convert_groundtruths_to_coco_dataset: Processing annotation %d', i)
     max_num_instances = groundtruths['classes'][i].shape[1]
     batch_size = groundtruths['source_id'][i].shape[0]
     for j in range(batch_size):
@@ -283,11 +285,13 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
 class COCOGroundtruthGenerator:
   """Generates the groundtruth annotations from a single example."""
 
-  def __init__(self, file_pattern, file_type, num_examples, include_mask):
+  def __init__(self, file_pattern, file_type, num_examples, include_mask,
+               regenerate_source_id=False):
     self._file_pattern = file_pattern
     self._num_examples = num_examples
     self._include_mask = include_mask
     self._dataset_fn = dataset_fn.pick_dataset_fn(file_type)
+    self._regenerate_source_id = regenerate_source_id
 
   def _parse_single_example(self, example):
     """Parses a single serialized tf.Example proto.
@@ -312,16 +316,21 @@ class COCOGroundtruthGenerator:
         mask of each instance.
     """
     decoder = tf_example_decoder.TfExampleDecoder(
-        include_mask=self._include_mask)
+        include_mask=self._include_mask,
+        regenerate_source_id=self._regenerate_source_id)
     decoded_tensors = decoder.decode(example)
 
     image = decoded_tensors['image']
     image_size = tf.shape(image)[0:2]
     boxes = box_ops.denormalize_boxes(
         decoded_tensors['groundtruth_boxes'], image_size)
+
+    source_id = decoded_tensors['source_id']
+    if source_id.dtype is tf.string:
+      source_id = tf.strings.to_number(source_id, out_type=tf.int64)
+
     groundtruths = {
-        'source_id': tf.strings.to_number(
-            decoded_tensors['source_id'], out_type=tf.int64),
+        'source_id': source_id,
         'height': decoded_tensors['height'],
         'width': decoded_tensors['width'],
         'num_detections': tf.shape(decoded_tensors['groundtruth_classes'])[0],
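Note: the decoded source_id is left untouched when it is already numeric and parsed only when the decoder returned a string (for example, a regenerated hash rendered with tf.strings.as_string). A toy sketch of that dispatch (values are placeholders):

import tensorflow as tf

def normalize_source_id(source_id):
  # Convert only string IDs; numeric IDs pass through unchanged.
  if source_id.dtype is tf.string:
    source_id = tf.strings.to_number(source_id, out_type=tf.int64)
  return source_id

print(normalize_source_id(tf.constant('42')).numpy())         # 42
print(normalize_source_id(tf.constant(7, tf.int64)).numpy())  # 7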
@@ -341,9 +350,10 @@ class COCOGroundtruthGenerator:
     dataset = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)
     dataset = dataset.interleave(
         map_func=lambda filename: self._dataset_fn(filename).prefetch(1),
-        cycle_length=12,
+        cycle_length=None,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
     dataset = dataset.take(self._num_examples)
     dataset = dataset.map(self._parse_single_example,
                           num_parallel_calls=tf.data.experimental.AUTOTUNE)
     dataset = dataset.batch(1, drop_remainder=False)
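Note: with num_parallel_calls already set, cycle_length=None lets the tf.data runtime choose how many shards to read concurrently instead of pinning it to 12. A toy, self-contained sketch of the pattern (the element values stand in for per-file datasets):

import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(['shard-0', 'shard-1'])
dataset = files.interleave(
    # Each "file" expands into a tiny dataset, mimicking reading a shard.
    map_func=lambda name: tf.data.Dataset.from_tensors(name).repeat(2),
    cycle_length=None,  # let the runtime pick the concurrency
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
print(list(dataset.as_numpy_iterator()))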
@@ -351,18 +361,18 @@ class COCOGroundtruthGenerator:
     return dataset
 
   def __call__(self):
-    for groundtruth_result in self._build_pipeline():
-      yield groundtruth_result
+    return self._build_pipeline()
 
 
 def scan_and_generator_annotation_file(file_pattern: str,
                                        file_type: str,
                                        num_samples: int,
                                        include_mask: bool,
-                                       annotation_file: str):
+                                       annotation_file: str,
+                                       regenerate_source_id: bool = False):
   """Scans and generate the COCO-style annotation JSON file given a dataset."""
   groundtruth_generator = COCOGroundtruthGenerator(
-      file_pattern, file_type, num_samples, include_mask)
+      file_pattern, file_type, num_samples, include_mask, regenerate_source_id)
   generate_annotation_file(groundtruth_generator, annotation_file)
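Note: returning the tf.data.Dataset from __call__ works because datasets are already Python iterables, so the explicit yield loop added nothing; callers such as generate_annotation_file can iterate the result exactly as before. A tiny sketch of the equivalence (simplified, with a dummy pipeline):

import tensorflow as tf

class TinyGenerator:
  """Simplified stand-in for COCOGroundtruthGenerator."""

  def _build_pipeline(self):
    # A dummy three-element pipeline in place of decoded TFRecord examples.
    return tf.data.Dataset.range(3)

  def __call__(self):
    # Yields the same elements as `for x in self._build_pipeline(): yield x`.
    return self._build_pipeline()

for example in TinyGenerator()():
  print(example.numpy())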
@@ -371,7 +381,8 @@ def generate_annotation_file(groundtruth_generator,
   """Generates COCO-style annotation JSON file given a groundtruth generator."""
   groundtruths = {}
   logging.info('Loading groundtruth annotations from dataset to memory...')
-  for groundtruth in groundtruth_generator():
+  for i, groundtruth in enumerate(groundtruth_generator()):
+    logging.info('generate_annotation_file: Processing annotation %d', i)
     for k, v in six.iteritems(groundtruth):
       if k not in groundtruths:
         groundtruths[k] = [v]
@@ -13,14 +13,13 @@
 # limitations under the License.
 
 """Panoptic MaskRCNN task definition."""
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any, List, Mapping, Optional, Tuple, Dict
 
 from absl import logging
 import tensorflow as tf
 
 from official.common import dataset_fn
 from official.core import task_factory
 from official.vision.beta.dataloaders import input_reader_factory
-from official.vision.beta.evaluation import coco_evaluator
 from official.vision.beta.evaluation import panoptic_quality_evaluator
 from official.vision.beta.evaluation import segmentation_metrics
 from official.vision.beta.losses import segmentation_losses
@@ -235,10 +234,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
           dtype=tf.float32)
     else:
-      self.coco_metric = coco_evaluator.COCOEvaluator(
-          annotation_file=self.task_config.annotation_file,
-          include_mask=self.task_config.model.include_mask,
-          per_category_metrics=self.task_config.per_category_metrics)
+      self._build_coco_metrics()
 
     rescale_predictions = (not self.task_config.validation_data.parser
                            .segmentation_resize_eval_groundtruth)
@@ -430,7 +426,7 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
   def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
     result = {}
-    result[self.coco_metric.name] = super(
+    result = super(
         PanopticMaskRCNNTask, self).reduce_aggregated_logs(
             aggregated_logs=aggregated_logs,
             global_step=global_step)
@@ -275,7 +275,9 @@ class MaskRCNNTask(base_task.Task):
             self._task_config.validation_data.input_path,
             self._task_config.validation_data.file_type,
            self._task_config.validation_data.num_examples,
-            self.task_config.model.include_mask, annotation_path)
+            self.task_config.model.include_mask, annotation_path,
+            regenerate_source_id=self._task_config.validation_data.decoder
+            .simple_decoder.regenerate_source_id)
       self.coco_metric = coco_evaluator.COCOEvaluator(
           annotation_file=annotation_path,
           include_mask=self._task_config.model.include_mask,