Commit 1887e546 authored by Frederick

Merge pull request #10696 from gunho1123:master

PiperOrigin-RevId: 459562542
parents 283a0015 578b320d
@@ -15,44 +15,91 @@
"""DETR configurations."""

import dataclasses
import os
from typing import List, Optional, Union

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco
from official.vision.configs import backbones
from official.vision.configs import common


@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  input_path: str = ''
  tfds_name: str = ''
  tfds_split: str = 'train'
  global_batch_size: int = 0
  is_training: bool = False
  dtype: str = 'bfloat16'
  decoder: common.DataDecoder = common.DataDecoder()
  shuffle_buffer_size: int = 10000
  file_type: str = 'tfrecord'
  drop_remainder: bool = True


@dataclasses.dataclass
class Losses(hyperparams.Config):
  class_offset: int = 0
  lambda_cls: float = 1.0
  lambda_box: float = 5.0
  lambda_giou: float = 2.0
  background_cls_weight: float = 0.1
  l2_weight_decay: float = 1e-4


@dataclasses.dataclass
class Detr(hyperparams.Config):
  num_queries: int = 100
  hidden_size: int = 256
  num_classes: int = 91  # 0: background
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False))
  norm_activation: common.NormActivation = common.NormActivation()
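Note: with the backbone now an ordinary backbones.Backbone config rather than a hard-coded ResNet-50, it can be swapped from the config alone. A minimal sketch (assuming the standard vision backbone factory, which also provides ResNet-101):

model = Detr(
    num_classes=81,
    input_size=[1333, 1333, 3],
    backbone=backbones.Backbone(
        type='resnet', resnet=backbones.ResNet(model_id=101)))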
@dataclasses.dataclass
class DetrTask(cfg.TaskConfig):
  model: Detr = Detr()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
  losses: Losses = Losses()
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'all'  # all or backbone
  annotation_file: Optional[str] = None
  per_category_metrics: bool = False


COCO_INPUT_PATH_BASE = 'coco'
COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000


@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
  """Config to get results that match the paper."""
  train_batch_size = 64
  eval_batch_size = 64
  num_train_data = COCO_TRAIN_EXAMPLES
  num_steps_per_epoch = num_train_data // train_batch_size
  train_steps = 500 * num_steps_per_epoch  # 500 epochs
  decay_at = train_steps - 100 * num_steps_per_epoch  # 400 epochs
  config = cfg.ExperimentConfig(
      task=DetrTask(
          init_checkpoint='',
          init_checkpoint_modules='backbone',
          model=Detr(
              num_classes=81,
              input_size=[1333, 1333, 3],
              norm_activation=common.NormActivation()),
          losses=Losses(),
          train_data=coco.COCODataConfig(
              tfds_name='coco/2017',
              tfds_split='train',
@@ -65,9 +112,7 @@ def detr_coco() -> cfg.ExperimentConfig:
              tfds_split='validation',
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
@@ -95,8 +140,135 @@ def detr_coco() -> cfg.ExperimentConfig:
                      'values': [0.0001, 1.0e-05]
                  }
              },
          })),
      restrictions=[
          'task.train_data.is_training != None',
      ])
  return config
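For a quick round-trip check, the registered experiment can be pulled straight from the factory; this is the same lookup the config tests below perform. A minimal sketch:

from official.core import exp_factory

config = exp_factory.get_exp_config('detr_coco')
config.validate()  # enforces the restrictions declared above
config.task.train_data.global_batch_size = 32  # fields remain overridable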
@exp_factory.register_config_factory('detr_coco_tfrecord')
def detr_coco_tfrecord() -> cfg.ExperimentConfig:
  """Config to get results that match the paper."""
  train_batch_size = 64
  eval_batch_size = 64
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  train_steps = 300 * steps_per_epoch  # 300 epochs
  decay_at = train_steps - 100 * steps_per_epoch  # 200 epochs
  config = cfg.ExperimentConfig(
      task=DetrTask(
          init_checkpoint='',
          init_checkpoint_modules='backbone',
          annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
                                       'instances_val2017.json'),
          model=Detr(
              input_size=[1333, 1333, 3],
              norm_activation=common.NormActivation()),
          losses=Losses(),
          train_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              shuffle_buffer_size=1000,
          ),
          validation_data=DataConfig(
              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False,
          )),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          validation_interval=5 * steps_per_epoch,
          max_to_keep=1,
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_eval_metric='AP',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'detr_adamw',
                  'detr_adamw': {
                      'weight_decay_rate': 1e-4,
                      'global_clipnorm': 0.1,
                      # Avoid AdamW legacy behavior.
                      'gradient_clip_norm': 0.0
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      'boundaries': [decay_at],
                      'values': [0.0001, 1.0e-05]
                  }
              },
          })),
      restrictions=[
          'task.train_data.is_training != None',
      ])
  return config
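Overrides work the same way for this experiment; a sketch using Config.override from hyperparams (assuming the usual model-garden override semantics; the checkpoint path is the one used in the train scripts below):

config = exp_factory.get_exp_config('detr_coco_tfrecord')
config.override({
    'task': {
        'init_checkpoint': 'gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
        'init_checkpoint_modules': 'backbone',
    }
}, is_strict=True)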
@exp_factory.register_config_factory('detr_coco_tfds')
def detr_coco_tfds() -> cfg.ExperimentConfig:
  """Config to get results that match the paper."""
  train_batch_size = 64
  eval_batch_size = 64
  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
  train_steps = 300 * steps_per_epoch  # 300 epochs
  decay_at = train_steps - 100 * steps_per_epoch  # 200 epochs
  config = cfg.ExperimentConfig(
      task=DetrTask(
          init_checkpoint='',
          init_checkpoint_modules='backbone',
          model=Detr(
              num_classes=81,
              input_size=[1333, 1333, 3],
              norm_activation=common.NormActivation()),
          losses=Losses(class_offset=1),
          train_data=DataConfig(
              tfds_name='coco/2017',
              tfds_split='train',
              is_training=True,
              global_batch_size=train_batch_size,
              shuffle_buffer_size=1000,
          ),
          validation_data=DataConfig(
              tfds_name='coco/2017',
              tfds_split='validation',
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          validation_interval=5 * steps_per_epoch,
          max_to_keep=1,
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_eval_metric='AP',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'detr_adamw',
                  'detr_adamw': {
                      'weight_decay_rate': 1e-4,
                      'global_clipnorm': 0.1,
                      # Avoid AdamW legacy behavior.
                      'gradient_clip_norm': 0.0
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      'boundaries': [decay_at],
                      'values': [0.0001, 1.0e-05]
                  }
              },
          })),
      restrictions=[
          'task.train_data.is_training != None',
      ])
  return config
...
@@ -27,15 +27,25 @@ from official.projects.detr.dataloaders import coco


class DetrTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('detr_coco',))
  def test_detr_configs_tfds(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.DetrTask)
    self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()

  @parameterized.parameters(('detr_coco_tfrecord'), ('detr_coco_tfds'))
  def test_detr_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.DetrTask)
    self.assertIsInstance(config.task.train_data, cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
@@ -145,7 +145,7 @@ class COCODataLoader():
        self._params.global_batch_size
    ) if input_context else self._params.global_batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._params.drop_remainder)
    return dataset
  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
...
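This batching change matters mainly for evaluation: keying drop_remainder off the config instead of is_training keeps the final partial batch, so no validation examples are silently skipped. A self-contained illustration of the tf.data semantics:

import tensorflow as tf

# drop_remainder=False keeps the last short batch (sizes 4, 4, 2);
# drop_remainder=True would drop the final two examples.
ds = tf.data.Dataset.range(10).batch(4, drop_remainder=False)
print([batch.numpy().tolist() for batch in ds])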
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
from typing import Tuple
import tensorflow as tf
from official.vision.dataloaders import parser
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
RESIZE_SCALES = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
class Parser(parser.Parser):
"""Parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
class_offset: int = 0,
output_size: Tuple[int, int] = (1333, 1333),
max_num_boxes: int = 100,
resize_scales: Tuple[int, ...] = RESIZE_SCALES,
aug_rand_hflip=True):
self._class_offset = class_offset
self._output_size = output_size
self._max_num_boxes = max_num_boxes
self._resize_scales = resize_scales
self._aug_rand_hflip = aug_rand_hflip
def _parse_train_data(self, data):
"""Parses data for training and evaluation."""
classes = data['groundtruth_classes'] + self._class_offset
boxes = data['groundtruth_boxes']
is_crowd = data['groundtruth_is_crowd']
# Gets original image.
image = data['image']
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
do_crop = tf.greater(tf.random.uniform([]), 0.5)
if do_crop:
# Rescale
boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side)
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Do croping
shape = tf.cast(image_info[1], dtype=tf.int32)
h = tf.random.uniform([],
384,
tf.math.minimum(shape[0], 600),
dtype=tf.int32)
w = tf.random.uniform([],
384,
tf.math.minimum(shape[1], 600),
dtype=tf.int32)
i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
image = tf.image.crop_to_bounding_box(image, i, j, h, w)
boxes = tf.clip_by_value(
(boxes[..., :] * tf.cast(
tf.stack([shape[0], shape[1], shape[0], shape[1]]),
dtype=tf.float32) -
tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
scales = tf.constant(self._resize_scales, dtype=tf.float32)
index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
scales = tf.gather(scales, index, axis=0)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(classes,
self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
}
return image, labels
def _parse_eval_data(self, data):
"""Parses data for training and evaluation."""
classes = data['groundtruth_classes']
boxes = data['groundtruth_boxes']
is_crowd = data['groundtruth_is_crowd']
# Gets original image and its size.
image = data['image']
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
scales = tf.constant([self._resize_scales[-1]], tf.float32)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
gt_boxes = boxes
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(classes,
self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(boxes, self._max_num_boxes)
}
labels.update({
'id':
int(data['source_id']),
'image_info':
image_info,
'is_crowd':
preprocess_ops.clip_or_pad_to_fixed_size(is_crowd,
self._max_num_boxes),
'gt_boxes':
preprocess_ops.clip_or_pad_to_fixed_size(gt_boxes,
self._max_num_boxes),
})
return image, labels
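A minimal usage sketch of the new parser (hypothetical standalone wiring; the task's build_inputs below goes through input_reader_factory instead):

from official.projects.detr.dataloaders import detr_input
from official.vision.dataloaders import tfds_factory

decoder = tfds_factory.get_detection_decoder('coco/2017')
parser = detr_input.Parser(class_offset=1, output_size=(1333, 1333))
parse_fn = parser.parse_fn(is_training=True)
# dataset = raw_dataset.map(decoder.decode).map(parse_fn)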
@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
  --experiment=detr_coco \
  --mode=train_and_eval \
  --model_dir=/tmp/logging_dir/ \
  --params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
  --experiment=detr_coco \
  --mode=train_and_eval \
  --model_dir=/tmp/logging_dir/ \
  --params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
@@ -24,7 +24,6 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.projects.detr.modeling import transformer


def position_embedding_sine(attention_mask,
@@ -100,7 +99,11 @@ class DETR(tf.keras.Model):
    class and box heads.
  """

  def __init__(self,
               backbone,
               num_queries,
               hidden_size,
               num_classes,
               num_encoder_layers=6,
               num_decoder_layers=6,
               dropout_rate=0.1,
@@ -114,9 +117,7 @@ class DETR(tf.keras.Model):
    self._dropout_rate = dropout_rate
    if hidden_size % 2 != 0:
      raise ValueError("hidden_size must be a multiple of 2.")
    self._backbone = backbone

  def build(self, input_shape=None):
    self._input_proj = tf.keras.layers.Conv2D(
@@ -159,6 +160,7 @@ class DETR(tf.keras.Model):
  def get_config(self):
    return {
        "backbone": self._backbone,
        "num_queries": self._num_queries,
        "hidden_size": self._hidden_size,
        "num_classes": self._num_classes,
...
@@ -15,6 +15,7 @@
"""Tests for tensorflow_models.official.projects.detr.detr."""
import tensorflow as tf

from official.projects.detr.modeling import detr
from official.vision.modeling.backbones import resnet


class DetrTest(tf.test.TestCase):
@@ -25,7 +26,8 @@ class DetrTest(tf.test.TestCase):
    num_classes = 10
    image_size = 640
    batch_size = 2
    backbone = resnet.ResNet(50, bn_trainable=False)
    model = detr.DETR(backbone, num_queries, hidden_size, num_classes)
    outs = model(tf.ones((batch_size, image_size, image_size, 3)))
    self.assertLen(outs, 6)  # intermediate decoded outputs.
    for out in outs:
@@ -47,6 +49,7 @@ class DetrTest(tf.test.TestCase):
  def test_get_from_config_detr(self):
    config = {
        'backbone': resnet.ResNet(50, bn_trainable=False),
        'num_queries': 2,
        'hidden_size': 4,
        'num_classes': 10,
...
@@ -13,20 +13,29 @@
# limitations under the License.
"""DETR detection task definition."""
from typing import Optional

from absl import logging
import tensorflow as tf

from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.dataloaders import detr_input
from official.projects.detr.modeling import detr
from official.projects.detr.ops import matchers
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.modeling import backbones
from official.vision.ops import box_ops


@task_factory.register_task_cls(detr_cfg.DetrTask)
class DectectionTask(base_task.Task):
  """A single-replica view of training procedure.
@@ -37,46 +46,103 @@ class DectectionTask(base_task.Task):

  def build_model(self):
    """Build DETR model."""
    input_specs = tf.keras.layers.InputSpec(shape=[None] +
                                            self._task_config.model.input_size)

    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=self._task_config.model.backbone,
        norm_activation_config=self._task_config.model.norm_activation)

    model = detr.DETR(backbone, self._task_config.model.num_queries,
                      self._task_config.model.hidden_size,
                      self._task_config.model.num_classes,
                      self._task_config.model.num_encoder_layers,
                      self._task_config.model.num_decoder_layers)
    return model
  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint."""
    if not self._task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self._task_config.init_checkpoint

    # Restoring checkpoint.
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    if self._task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    elif self._task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self,
                   params,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds the input dataset."""
    if isinstance(params, coco.COCODataConfig):
      dataset = coco.COCODataLoader(params).load(input_context)
    else:
      if params.tfds_name:
        decoder = tfds_factory.get_detection_decoder(params.tfds_name)
      else:
        decoder_cfg = params.decoder.get()
        if params.decoder.type == 'simple_decoder':
          decoder = tf_example_decoder.TfExampleDecoder(
              regenerate_source_id=decoder_cfg.regenerate_source_id)
        elif params.decoder.type == 'label_map_decoder':
          decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
              label_map=decoder_cfg.label_map,
              regenerate_source_id=decoder_cfg.regenerate_source_id)
        else:
          raise ValueError('Unknown decoder type: {}!'.format(
              params.decoder.type))

      parser = detr_input.Parser(
          class_offset=self._task_config.losses.class_offset,
          output_size=self._task_config.model.input_size[:2])

      reader = input_reader_factory.input_reader_generator(
          params,
          dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
          decoder_fn=decoder.decode,
          parser_fn=parser.parse_fn(params.is_training))
      dataset = reader.read(input_context=input_context)

    return dataset
  def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
    # Approximate classification cost with 1 - prob[target class].
    # The 1 is a constant that doesn't change the matching, so it is omitted.
    # background: 0
    cls_cost = self._task_config.losses.lambda_cls * tf.gather(
        -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)

    # Compute the L1 cost between boxes.
    paired_differences = self._task_config.losses.lambda_box * tf.abs(
        tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
    box_cost = tf.reduce_sum(paired_differences, axis=-1)

    # Compute the GIoU cost between boxes.
    giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
        box_ops.cycxhw_to_yxyx(box_outputs),
        box_ops.cycxhw_to_yxyx(box_targets))
    total_cost = cls_cost + box_cost + giou_cost

    max_cost = (
        self._task_config.losses.lambda_cls * 0.0 +
        self._task_config.losses.lambda_box * 4. +
        self._task_config.losses.lambda_giou * 0.0)

    # Set pads to large constant.
    valid = tf.expand_dims(
@@ -115,35 +181,26 @@ class DectectionTask(base_task.Task):
    # Down-weight background to account for class imbalance.
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=cls_targets, logits=cls_assigned)
    cls_loss = self._task_config.losses.lambda_cls * tf.where(
        background, self._task_config.losses.background_cls_weight * xentropy,
        xentropy)
    cls_weights = tf.where(
        background,
        self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
        tf.ones_like(cls_loss))

    # Box loss is only calculated on non-background class.
    l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
    box_loss = self._task_config.losses.lambda_box * tf.where(
        background, tf.zeros_like(l_1), l_1)

    # GIoU loss is only calculated on non-background class.
    giou = tf.linalg.diag_part(1.0 - box_ops.bbox_generalized_overlap(
        box_ops.cycxhw_to_yxyx(box_assigned),
        box_ops.cycxhw_to_yxyx(box_targets)
    ))
    giou_loss = self._task_config.losses.lambda_giou * tf.where(
        background, tf.zeros_like(giou), giou)

    # Consider doing all reduce once in train_step to speed up.
    num_boxes_per_replica = tf.reduce_sum(num_boxes)
@@ -160,6 +217,7 @@ class DectectionTask(base_task.Task):
        tf.reduce_sum(giou_loss), num_boxes_sum)
    aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0

    total_loss = cls_loss + box_loss + giou_loss + aux_losses
    return total_loss, cls_loss, box_loss, giou_loss
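With the default Losses above this is DETR's Hungarian-matched objective: roughly total = 1.0*CE + 5.0*L1 + 2.0*GIoU per matched query, while background queries contribute only the 0.1-weighted cross-entropy. A sketch of the per-query arithmetic (not the task's literal code path):

lambda_cls, lambda_box, lambda_giou, bg_weight = 1.0, 5.0, 2.0, 0.1

def per_query_loss(ce, l1, giou, is_background):
  if is_background:
    return lambda_cls * bg_weight * ce  # box and GIoU terms are zeroed
  return lambda_cls * ce + lambda_box * l1 + lambda_giou * giou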
@@ -172,7 +230,7 @@ class DectectionTask(base_task.Task):
    if not training:
      self.coco_metric = coco_evaluator.COCOEvaluator(
          annotation_file=self._task_config.annotation_file,
          include_mask=False,
          need_rescale_bboxes=True,
          per_category_metrics=self._task_config.per_category_metrics)
...
@@ -22,6 +22,7 @@ from official.projects.detr import optimization
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.tasks import detection
from official.vision.configs import backbones

_NUM_EXAMPLES = 10

@@ -58,9 +59,16 @@ def _as_dataset(self, *args, **kwargs):

class DetectionTest(tf.test.TestCase):

  def test_train_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            num_classes=81,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        train_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
@@ -92,9 +100,16 @@ class DetectionTest(tf.test.TestCase):
      task.train_step(next(iterator), model, optimizer)

  def test_validation_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            num_classes=81,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        validation_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
@@ -112,5 +127,77 @@ class DetectionTest(tf.test.TestCase):
      state = task.aggregate_logs(step_outputs=logs)
      task.reduce_aggregated_logs(state)
class DetectionTFDSTest(tf.test.TestCase):

  def test_train_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        losses=detr_cfg.Losses(class_offset=1),
        train_data=detr_cfg.DataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=True,
            global_batch_size=2,
        ))
    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DectectionTask(config)
      model = task.build_model()
      dataset = task.build_inputs(config.train_data)
      iterator = iter(dataset)
      opt_cfg = optimization.OptimizationConfig({
          'optimizer': {
              'type': 'detr_adamw',
              'detr_adamw': {
                  'weight_decay_rate': 1e-4,
                  'global_clipnorm': 0.1,
              }
          },
          'learning_rate': {
              'type': 'stepwise',
              'stepwise': {
                  'boundaries': [120000],
                  'values': [0.0001, 1.0e-05]
              }
          },
      })
      optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
      task.train_step(next(iterator), model, optimizer)

  def test_validation_step(self):
    config = detr_cfg.DetrTask(
        model=detr_cfg.Detr(
            input_size=[1333, 1333, 3],
            num_encoder_layers=1,
            num_decoder_layers=1,
            backbone=backbones.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=10, bn_trainable=False))
        ),
        losses=detr_cfg.Losses(class_offset=1),
        validation_data=detr_cfg.DataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=False,
            global_batch_size=2,
        ))
    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DectectionTask(config)
      model = task.build_model()
      metrics = task.build_metrics(training=False)
      dataset = task.build_inputs(config.validation_data)
      iterator = iter(dataset)
      logs = task.validation_step(next(iterator), model, metrics)
      state = task.aggregate_logs(step_outputs=logs)
      task.reduce_aggregated_logs(state)


if __name__ == '__main__':
  tf.test.main()