".github/vscode:/vscode.git/clone" did not exist on "f8e3ce890c140f45c089778e0674dd66e68b1b49"
Commit c0a47fd4 authored by Frederick

Merge pull request #10696 from gunho1123:master

PiperOrigin-RevId: 459562542
parents 44b4088c 578b320d
......@@ -15,44 +15,91 @@
"""DETR configurations."""
import dataclasses
import os
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco
from official.vision.configs import backbones
from official.vision.configs import common
@dataclasses.dataclass
class DetectionConfig(cfg.TaskConfig):
"""The translation task config."""
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
class DataConfig(cfg.DataConfig):
"""Input config for training."""
input_path: str = ''
tfds_name: str = ''
tfds_split: str = 'train'
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
decoder: common.DataDecoder = common.DataDecoder()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
drop_remainder: bool = True
@dataclasses.dataclass
class Losses(hyperparams.Config):
class_offset: int = 0
lambda_cls: float = 1.0
lambda_box: float = 5.0
lambda_giou: float = 2.0
init_ckpt: str = ''
num_classes: int = 81 # 0: background
background_cls_weight: float = 0.1
l2_weight_decay: float = 1e-4
@dataclasses.dataclass
class Detr(hyperparams.Config):
num_queries: int = 100
hidden_size: int = 256
num_classes: int = 91 # 0: background
num_encoder_layers: int = 6
num_decoder_layers: int = 6
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False))
norm_activation: common.NormActivation = common.NormActivation()
# Make DETRConfig.
num_queries: int = 100
num_hidden: int = 256
@dataclasses.dataclass
class DetrTask(cfg.TaskConfig):
model: Detr = Detr()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[str, List[str]] = 'all' # all, backbone
annotation_file: Optional[str] = None
per_category_metrics: bool = False
COCO_INPUT_PATH_BASE = 'coco'
COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000
@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
num_train_data = 118287
num_train_data = COCO_TRAIN_EXAMPLES
num_steps_per_epoch = num_train_data // train_batch_size
train_steps = 500 * num_steps_per_epoch # 500 epochs
decay_at = train_steps - 100 * num_steps_per_epoch # 400 epochs
config = cfg.ExperimentConfig(
task=DetectionConfig(
task=DetrTask(
init_checkpoint='',
init_checkpoint_modules='backbone',
model=Detr(
num_classes=81,
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation()),
losses=Losses(),
train_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='train',
......@@ -65,9 +112,7 @@ def detr_coco() -> cfg.ExperimentConfig:
tfds_split='validation',
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False
)
),
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=-1,
......@@ -95,8 +140,135 @@ def detr_coco() -> cfg.ExperimentConfig:
'values': [0.0001, 1.0e-05]
}
},
})
})),
restrictions=[
'task.train_data.is_training != None',
])
return config
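As a quick usage sketch (the override values below are illustrative only), a registered experiment like this one can be fetched by name and adjusted before training:

# Illustrative usage of the registered experiment; override values are made up.
from official.core import exp_factory

config = exp_factory.get_exp_config('detr_coco')
config.override({
    'task': {'train_data': {'global_batch_size': 32}},
    'trainer': {'train_steps': 100000},
}, is_strict=True)
config.validate()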
@exp_factory.register_config_factory('detr_coco_tfrecord')
def detr_coco_tfrecord() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=Detr(
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation()),
losses=Losses(),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False,
)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
validation_interval=5 * steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [decay_at],
'values': [0.0001, 1.0e-05]
}
},
})),
restrictions=[
'task.train_data.is_training != None',
])
return config
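For reference, the step arithmetic behind the stepwise schedule above works out as follows (a small sketch using only the constants defined in this file):

# Step counts implied by the config above (pure arithmetic, shown for clarity).
train_batch_size = 64
steps_per_epoch = 118287 // train_batch_size    # COCO_TRAIN_EXAMPLES -> 1848 steps
train_steps = 300 * steps_per_epoch             # 554400 steps (300 epochs)
decay_at = train_steps - 100 * steps_per_epoch  # 369600: lr is 1e-4 before, 1e-5 after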
@exp_factory.register_config_factory('detr_coco_tfds')
def detr_coco_tfds() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='',
init_checkpoint_modules='backbone',
model=Detr(
num_classes=81,
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation()),
losses=Losses(class_offset=1),
train_data=DataConfig(
tfds_name='coco/2017',
tfds_split='train',
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
validation_interval=5 * steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [decay_at],
'values': [0.0001, 1.0e-05]
}
},
})),
restrictions=[
'task.train_data.is_training != None',
])
......
......@@ -27,15 +27,25 @@ from official.projects.detr.dataloaders import coco
class DetrTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('detr_coco',))
def test_detr_configs(self, config_name):
def test_detr_configs_tfds(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetectionConfig)
self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
@parameterized.parameters(('detr_coco_tfrecord'), ('detr_coco_tfds'))
def test_detr_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
......@@ -145,7 +145,7 @@ class COCODataLoader():
self._params.global_batch_size
) if input_context else self._params.global_batch_size
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._params.is_training)
per_replica_batch_size, drop_remainder=self._params.drop_remainder)
return dataset
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
......
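The batching change above matters mainly for evaluation: tying drop_remainder to is_training silently dropped the final partial eval batch, so the loader now honors the config's drop_remainder flag. A minimal tf.data sketch of the difference, with made-up sizes:

# Illustrative only: 10 examples, batch size 4.
import tensorflow as tf

ds = tf.data.Dataset.range(10)
print(len(list(ds.batch(4, drop_remainder=True))))   # 2 batches; 2 examples dropped
print(len(list(ds.batch(4, drop_remainder=False))))  # 3 batches; all examples kept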
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
from typing import Tuple
import tensorflow as tf
from official.vision.dataloaders import parser
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
RESIZE_SCALES = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
class Parser(parser.Parser):
"""Parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
class_offset: int = 0,
output_size: Tuple[int, int] = (1333, 1333),
max_num_boxes: int = 100,
resize_scales: Tuple[int, ...] = RESIZE_SCALES,
aug_rand_hflip=True):
self._class_offset = class_offset
self._output_size = output_size
self._max_num_boxes = max_num_boxes
self._resize_scales = resize_scales
self._aug_rand_hflip = aug_rand_hflip
def _parse_train_data(self, data):
"""Parses data for training and evaluation."""
classes = data['groundtruth_classes'] + self._class_offset
boxes = data['groundtruth_boxes']
is_crowd = data['groundtruth_is_crowd']
# Gets original image.
image = data['image']
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
do_crop = tf.greater(tf.random.uniform([]), 0.5)
if do_crop:
# Rescale
boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side)
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Do cropping
shape = tf.cast(image_info[1], dtype=tf.int32)
h = tf.random.uniform([],
384,
tf.math.minimum(shape[0], 600),
dtype=tf.int32)
w = tf.random.uniform([],
384,
tf.math.minimum(shape[1], 600),
dtype=tf.int32)
i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
image = tf.image.crop_to_bounding_box(image, i, j, h, w)
boxes = tf.clip_by_value(
(boxes[..., :] * tf.cast(
tf.stack([shape[0], shape[1], shape[0], shape[1]]),
dtype=tf.float32) -
tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
scales = tf.constant(self._resize_scales, dtype=tf.float32)
index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
scales = tf.gather(scales, index, axis=0)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(classes,
self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
}
return image, labels
def _parse_eval_data(self, data):
"""Parses data for training and evaluation."""
classes = data['groundtruth_classes']
boxes = data['groundtruth_boxes']
is_crowd = data['groundtruth_is_crowd']
# Gets original image and its size.
image = data['image']
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
scales = tf.constant([self._resize_scales[-1]], tf.float32)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
gt_boxes = boxes
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(classes,
self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
}
labels.update({
'id':
int(data['source_id']),
'image_info':
image_info,
'is_crowd':
preprocess_ops.clip_or_pad_to_fixed_size(is_crowd,
self._max_num_boxes),
'gt_boxes':
preprocess_ops.clip_or_pad_to_fixed_size(gt_boxes,
self._max_num_boxes),
})
return image, labels
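As a side note on the random-crop branch in _parse_train_data: the boxes are taken back to absolute pixels, shifted into the crop's coordinate frame, renormalized by the crop size, and clipped to [0, 1]. A toy sketch with hypothetical numbers:

# Toy illustration of the crop box transform (all sizes are hypothetical).
import tensorflow as tf

H, W = 800, 1200                                  # image shape before cropping
i, j, h, w = 100, 200, 400, 600                   # crop offset and crop size
boxes = tf.constant([[0.25, 0.25, 0.75, 0.75]])   # normalized [ymin, xmin, ymax, xmax]

scale = tf.constant([H, W, H, W], tf.float32)
offset = tf.constant([i, j, i, j], tf.float32)
crop = tf.constant([h, w, h, w], tf.float32)
# Pixels -> crop frame -> renormalize -> clip to the crop window.
cropped = tf.clip_by_value((boxes * scale - offset) / crop, 0.0, 1.0)
# Result: [[0.25, 0.1667, 1.0, 1.0]]; the box spills past the crop and is clipped.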
......@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
--params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
......@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
--params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
......@@ -24,7 +24,6 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.projects.detr.modeling import transformer
from official.vision.modeling.backbones import resnet
def position_embedding_sine(attention_mask,
......@@ -100,7 +99,11 @@ class DETR(tf.keras.Model):
class and box heads.
"""
def __init__(self, num_queries, hidden_size, num_classes,
def __init__(self,
backbone,
num_queries,
hidden_size,
num_classes,
num_encoder_layers=6,
num_decoder_layers=6,
dropout_rate=0.1,
......@@ -114,9 +117,7 @@ class DETR(tf.keras.Model):
self._dropout_rate = dropout_rate
if hidden_size % 2 != 0:
raise ValueError("hidden_size must be a multiple of 2.")
# TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in.
self._backbone = resnet.ResNet(50, bn_trainable=False)
self._backbone = backbone
def build(self, input_shape=None):
self._input_proj = tf.keras.layers.Conv2D(
......@@ -159,6 +160,7 @@ class DETR(tf.keras.Model):
def get_config(self):
return {
"backbone": self._backbone,
"num_queries": self._num_queries,
"hidden_size": self._hidden_size,
"num_classes": self._num_classes,
......
......@@ -15,6 +15,7 @@
"""Tests for tensorflow_models.official.projects.detr.detr."""
import tensorflow as tf
from official.projects.detr.modeling import detr
from official.vision.modeling.backbones import resnet
class DetrTest(tf.test.TestCase):
......@@ -25,7 +26,8 @@ class DetrTest(tf.test.TestCase):
num_classes = 10
image_size = 640
batch_size = 2
model = detr.DETR(num_queries, hidden_size, num_classes)
backbone = resnet.ResNet(50, bn_trainable=False)
model = detr.DETR(backbone, num_queries, hidden_size, num_classes)
outs = model(tf.ones((batch_size, image_size, image_size, 3)))
self.assertLen(outs, 6) # intermediate decoded outputs.
for out in outs:
......@@ -47,6 +49,7 @@ class DetrTest(tf.test.TestCase):
def test_get_from_config_detr(self):
config = {
'backbone': resnet.ResNet(50, bn_trainable=False),
'num_queries': 2,
'hidden_size': 4,
'num_classes': 10,
......
......@@ -13,20 +13,29 @@
# limitations under the License.
"""DETR detection task definition."""
from typing import Optional
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.dataloaders import detr_input
from official.projects.detr.modeling import detr
from official.projects.detr.ops import matchers
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.modeling import backbones
from official.vision.ops import box_ops
@task_factory.register_task_cls(detr_cfg.DetectionConfig)
@task_factory.register_task_cls(detr_cfg.DetrTask)
class DectectionTask(base_task.Task):
"""A single-replica view of training procedure.
......@@ -37,46 +46,103 @@ class DectectionTask(base_task.Task):
def build_model(self):
"""Build DETR model."""
model = detr.DETR(
self._task_config.num_queries,
self._task_config.num_hidden,
self._task_config.num_classes,
self._task_config.num_encoder_layers,
self._task_config.num_decoder_layers)
input_specs = tf.keras.layers.InputSpec(shape=[None] +
self._task_config.model.input_size)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=self._task_config.model.backbone,
norm_activation_config=self._task_config.model.norm_activation)
model = detr.DETR(backbone, self._task_config.model.num_queries,
self._task_config.model.hidden_size,
self._task_config.model.num_classes,
self._task_config.model.num_encoder_layers,
self._task_config.model.num_decoder_layers)
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self._task_config.init_checkpoint:
return
ckpt_dir_or_file = self._task_config.init_checkpoint
# Restoring checkpoint.
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if self._task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self._task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.read(self._task_config.init_ckpt)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
def build_inputs(self, params, input_context=None):
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self,
params,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
return coco.COCODataLoader(params).load(input_context)
if isinstance(params, coco.COCODataConfig):
dataset = coco.COCODataLoader(params).load(input_context)
else:
if params.tfds_name:
decoder = tfds_factory.get_detection_decoder(params.tfds_name)
else:
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
parser = detr_input.Parser(
class_offset=self._task_config.losses.class_offset,
output_size=self._task_config.model.input_size[:2],
)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
# Approximate classification cost with 1 - prob[target class].
# The 1 is a constant that doesn't change the matching; it can be omitted.
# background: 0
cls_cost = self._task_config.lambda_cls * tf.gather(
cls_cost = self._task_config.losses.lambda_cls * tf.gather(
-tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
# Compute the L1 cost between boxes.
paired_differences = self._task_config.lambda_box * tf.abs(
paired_differences = self._task_config.losses.lambda_box * tf.abs(
tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
box_cost = tf.reduce_sum(paired_differences, axis=-1)
# Compute the GIoU cost between boxes.
giou_cost = self._task_config.lambda_giou * -box_ops.bbox_generalized_overlap(
giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
box_ops.cycxhw_to_yxyx(box_outputs),
box_ops.cycxhw_to_yxyx(box_targets))
total_cost = cls_cost + box_cost + giou_cost
max_cost = (
self._task_config.lambda_cls * 0.0 + self._task_config.lambda_box * 4. +
self._task_config.lambda_giou * 0.0)
self._task_config.losses.lambda_cls * 0.0 +
self._task_config.losses.lambda_box * 4. +
self._task_config.losses.lambda_giou * 0.0)
# Set pads to large constant
valid = tf.expand_dims(
......@@ -115,35 +181,26 @@ class DectectionTask(base_task.Task):
# Down-weight background to account for class imbalance.
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=cls_targets, logits=cls_assigned)
cls_loss = self._task_config.lambda_cls * tf.where(
background,
self._task_config.background_cls_weight * xentropy,
xentropy
)
cls_loss = self._task_config.losses.lambda_cls * tf.where(
background, self._task_config.losses.background_cls_weight * xentropy,
xentropy)
cls_weights = tf.where(
background,
self._task_config.background_cls_weight * tf.ones_like(cls_loss),
tf.ones_like(cls_loss)
)
self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
tf.ones_like(cls_loss))
# Box loss is only calculated on non-background class.
l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
box_loss = self._task_config.lambda_box * tf.where(
background,
tf.zeros_like(l_1),
l_1
)
box_loss = self._task_config.losses.lambda_box * tf.where(
background, tf.zeros_like(l_1), l_1)
# Giou loss is only calculated on non-background class.
giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
box_ops.cycxhw_to_yxyx(box_assigned),
box_ops.cycxhw_to_yxyx(box_targets)
))
giou_loss = self._task_config.lambda_giou * tf.where(
background,
tf.zeros_like(giou),
giou
)
giou_loss = self._task_config.losses.lambda_giou * tf.where(
background, tf.zeros_like(giou), giou)
# Consider doing all reduce once in train_step to speed up.
num_boxes_per_replica = tf.reduce_sum(num_boxes)
......@@ -160,6 +217,7 @@ class DectectionTask(base_task.Task):
tf.reduce_sum(giou_loss), num_boxes_sum)
aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
total_loss = cls_loss + box_loss + giou_loss + aux_losses
return total_loss, cls_loss, box_loss, giou_loss
......@@ -172,7 +230,7 @@ class DectectionTask(base_task.Task):
if not training:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file='',
annotation_file=self._task_config.annotation_file,
include_mask=False,
need_rescale_bboxes=True,
per_category_metrics=self._task_config.per_category_metrics)
......
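To make the matching cost in _compute_cost easier to follow, here is a toy, unbatched sketch of the per-image cost matrix (2 queries, 2 targets, random values; the GIoU term and the batching/padding handled by the real code are omitted):

# Toy cost-matrix sketch for the bipartite matching step (illustrative values).
import tensorflow as tf

lambda_cls, lambda_box = 1.0, 5.0                 # defaults from Losses above
cls_outputs = tf.random.normal([2, 3])            # 2 queries, 3 classes (0 = background)
box_outputs = tf.random.uniform([2, 4])           # predicted boxes, cycxhw, normalized
cls_targets = tf.constant([1, 2])                 # ground-truth classes
box_targets = tf.random.uniform([2, 4])           # ground-truth boxes

# Classification cost ~ 1 - prob[target class]; the constant 1 is dropped.
cls_cost = lambda_cls * tf.gather(-tf.nn.softmax(cls_outputs), cls_targets, axis=-1)
# L1 cost between every (query, target) box pair.
box_cost = lambda_box * tf.reduce_sum(
    tf.abs(tf.expand_dims(box_outputs, 1) - tf.expand_dims(box_targets, 0)), axis=-1)
total_cost = cls_cost + box_cost                  # shape [2, 2], fed to the matcher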
......@@ -22,6 +22,7 @@ from official.projects.detr import optimization
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.tasks import detection
from official.vision.configs import backbones
_NUM_EXAMPLES = 10
......@@ -58,9 +59,16 @@ def _as_dataset(self, *args, **kwargs):
class DetectionTest(tf.test.TestCase):
def test_train_step(self):
config = detr_cfg.DetectionConfig(
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,
num_classes=81,
backbone=backbones.Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=10, bn_trainable=False))
),
train_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
......@@ -92,9 +100,16 @@ class DetectionTest(tf.test.TestCase):
task.train_step(next(iterator), model, optimizer)
def test_validation_step(self):
config = detr_cfg.DetectionConfig(
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,
num_classes=81,
backbone=backbones.Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=10, bn_trainable=False))
),
validation_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
......@@ -112,5 +127,77 @@ class DetectionTest(tf.test.TestCase):
state = task.aggregate_logs(step_outputs=logs)
task.reduce_aggregated_logs(state)
class DetectionTFDSTest(tf.test.TestCase):
def test_train_step(self):
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,
backbone=backbones.Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=10, bn_trainable=False))
),
losses=detr_cfg.Losses(class_offset=1),
train_data=detr_cfg.DataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=True,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
opt_cfg = optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [120000],
'values': [0.0001, 1.0e-05]
}
},
})
optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer)
def test_validation_step(self):
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,
backbone=backbones.Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=10, bn_trainable=False))
),
losses=detr_cfg.Losses(class_offset=1),
validation_data=detr_cfg.DataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data)
iterator = iter(dataset)
logs = task.validation_step(next(iterator), model, metrics)
state = task.aggregate_logs(step_outputs=logs)
task.reduce_aggregated_logs(state)
if __name__ == '__main__':
tf.test.main()