"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "e307effedc5abcf17aee3950f8988fcd735e9f8b"
Unverified Commit eedfa888 authored by Frederick Liu's avatar Frederick Liu Committed by GitHub
Browse files

Revert "DETR implementation update (#10689)" (#10691)

This reverts commit 5633969b.
parent 5633969b
......@@ -15,63 +15,33 @@
"""DETR configurations."""
import dataclasses
import os
from typing import List, Optional, Union
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import common
from official.vision.configs import backbones
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
input_path: str = ''
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
decoder: common.DataDecoder = common.DataDecoder()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
@dataclasses.dataclass
class Losses(hyperparams.Config):
class_offset: int = 0
class DetectionConfig(cfg.TaskConfig):
"""The translation task config."""
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
lambda_cls: float = 1.0
lambda_box: float = 5.0
lambda_giou: float = 2.0
background_cls_weight: float = 0.1
l2_weight_decay: float = 1e-4
@dataclasses.dataclass
class Detr(hyperparams.Config):
num_queries: int = 100
hidden_size: int = 256
num_classes: int = 91 # 0: background
init_ckpt: str = ''
num_classes: int = 81 # 0: background
background_cls_weight: float = 0.1
num_encoder_layers: int = 6
num_decoder_layers: int = 6
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet(
model_id=50,
bn_trainable=False))
norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class DetrTask(cfg.TaskConfig):
model: Detr = Detr()
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: Union[
str, List[str]] = 'all' # all, backbone
annotation_file: Optional[str] = None
# Make DETRConfig.
num_queries: int = 100
num_hidden: int = 256
per_category_metrics: bool = False
@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
......@@ -82,14 +52,7 @@ def detr_coco() -> cfg.ExperimentConfig:
train_steps = 500 * num_steps_per_epoch # 500 epochs
decay_at = train_steps - 100 * num_steps_per_epoch # 400 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
init_checkpoint_modules='backbone',
model=Detr(
num_classes=81,
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation()),
losses=Losses(),
task=DetectionConfig(
train_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='train',
......@@ -138,140 +101,3 @@ def detr_coco() -> cfg.ExperimentConfig:
'task.train_data.is_training != None',
])
return config
COCO_INPUT_PATH_BASE = ''
COCO_TRAIN_EXAMPLES = 118287
COCO_VAL_EXAMPLES = 5000
@exp_factory.register_config_factory('detr_coco_tfrecord')
def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
init_checkpoint_modules='backbone',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
model=Detr(
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation()),
losses=Losses(),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False,
)
),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
validation_interval=5*steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [decay_at],
'values': [0.0001, 1.0e-05]
}
},
})
),
restrictions=[
'task.train_data.is_training != None',
])
return config
@exp_factory.register_config_factory('detr_coco_tfds')
def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
eval_batch_size = 64
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
config = cfg.ExperimentConfig(
task=DetrTask(
init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
init_checkpoint_modules='backbone',
model=Detr(
num_classes=81,
input_size=[1333, 1333, 3],
norm_activation=common.NormActivation()),
losses=Losses(
class_offset=1
),
train_data=DataConfig(
tfds_name='coco/2017',
tfds_split='train',
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=DataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False
)
),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
validation_interval=5*steps_per_epoch,
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
# Avoid AdamW legacy behavior.
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [decay_at],
'values': [0.0001, 1.0e-05]
}
},
})
),
restrictions=[
'task.train_data.is_training != None',
])
return config
\ No newline at end of file
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
from typing import Optional, Tuple
import tensorflow as tf
from official.vision.dataloaders import parser
from official.vision.dataloaders import utils
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.core import input_reader
RESIZE_SCALES = (
480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
class Parser(parser.Parser):
"""Parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
class_offset: int = 0,
output_size: Tuple[int, int] = (1333, 1333),
max_num_boxes: int = 100,
resize_scales: Tuple[int, ...] = RESIZE_SCALES,
aug_rand_hflip=True):
self._class_offset = class_offset
self._output_size = output_size
self._max_num_boxes = max_num_boxes
self._resize_scales = resize_scales
self._aug_rand_hflip = aug_rand_hflip
def _parse_train_data(self, data):
"""Parses data for training and evaluation."""
classes = data['groundtruth_classes'] + self._class_offset
boxes = data['groundtruth_boxes']
is_crowd = data['groundtruth_is_crowd']
# Gets original image.
image = data['image']
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
do_crop = tf.greater(tf.random.uniform([]), 0.5)
if do_crop:
# Rescale
boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side)
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Do croping
shape = tf.cast(image_info[1], dtype=tf.int32)
h = tf.random.uniform(
[], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
w = tf.random.uniform(
[], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
image = tf.image.crop_to_bounding_box(image, i, j, h, w)
boxes = tf.clip_by_value(
(boxes[..., :] * tf.cast(
tf.stack([shape[0], shape[1], shape[0], shape[1]]),
dtype=tf.float32) -
tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
scales = tf.constant(
self._resize_scales,
dtype=tf.float32)
index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
scales = tf.gather(scales, index, axis=0)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
gt_boxes = boxes
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(
image,
short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(
image, 0, 0, self._output_size[0], self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(
classes, self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
boxes, self._max_num_boxes)
}
return image, labels
def _parse_eval_data(self, data):
"""Parses data for training and evaluation."""
groundtruths = {}
classes = data['groundtruth_classes']
boxes = data['groundtruth_boxes']
is_crowd = data['groundtruth_is_crowd']
# Gets original image and its size.
image = data['image']
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
scales = tf.constant([self._resize_scales[-1]], tf.float32)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
gt_boxes = boxes
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(
image,
short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(
image, 0, 0, self._output_size[0], self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(
classes, self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
boxes, self._max_num_boxes)
}
labels.update({
'id':
int(data['source_id']),
'image_info':
image_info,
'is_crowd':
preprocess_ops.clip_or_pad_to_fixed_size(
is_crowd, self._max_num_boxes),
'gt_boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
gt_boxes, self._max_num_boxes),
})
return image, labels
\ No newline at end of file
......@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
--params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
......@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=/tmp/logging_dir/ \
--params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
--params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
......@@ -100,7 +100,7 @@ class DETR(tf.keras.Model):
class and box heads.
"""
def __init__(self, backbone, num_queries, hidden_size, num_classes,
def __init__(self, num_queries, hidden_size, num_classes,
num_encoder_layers=6,
num_decoder_layers=6,
dropout_rate=0.1,
......@@ -116,9 +116,7 @@ class DETR(tf.keras.Model):
raise ValueError("hidden_size must be a multiple of 2.")
# TODO(frederickliu): Consider using the backbone factory.
# TODO(frederickliu): Add to factory once we get skeleton code in.
#self._backbone = resnet.ResNet(101, bn_trainable=False)
# (gunho) use backbone factory
self._backbone = backbone
self._backbone = resnet.ResNet(50, bn_trainable=False)
def build(self, input_shape=None):
self._input_proj = tf.keras.layers.Conv2D(
......
......@@ -13,14 +13,17 @@
# limitations under the License.
"""Tensorflow implementation to solve the Linear Sum Assignment problem.
The Linear Sum Assignment problem involves determining the minimum weight
matching for bipartite graphs. For example, this problem can be defined by
a 2D matrix C, where each element i,j determines the cost of matching worker i
with job j. The solution to the problem is a complete assignment of jobs to
workers, such that no job is assigned to more than one work and no worker is
assigned more than one job, with minimum cost.
This implementation builds off of the Hungarian
Matching Algorithm (https://www.cse.ust.hk/~golin/COMP572/Notes/Matching.pdf).
Based on the original implementation by Jiquan Ngiam <jngiam@google.com>.
"""
import tensorflow as tf
......@@ -29,14 +32,17 @@ from official.modeling import tf_utils
def _prepare(weights):
"""Prepare the cost matrix.
To speed up computational efficiency of the algorithm, all weights are shifted
to be non-negative. Each element is reduced by the row / column minimum. Note
that neither operation will effect the resulting solution but will provide
a better starting point for the greedy assignment. Note this corresponds to
the pre-processing and step 1 of the Hungarian algorithm from Wikipedia.
Args:
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be use for matching.
Returns:
A prepared weights tensor of the same shape and dtype.
"""
......@@ -49,15 +55,18 @@ def _prepare(weights):
def _greedy_assignment(adj_matrix):
"""Greedily assigns workers to jobs based on an adjaceny matrix.
Starting with an adjacency matrix representing the available connections
in the bi-partite graph, this function greedily chooses elements such
that each worker is matched to at most one job (or each job is assigned to
at most one worker). Note, if the adjacency matrix has no available values
for a particular row/column, the corresponding job/worker may go unassigned.
Args:
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
matched to the job (column).
Returns:
A bool [batch_size, num_elems, num_elems] tensor, where each element of the
inner matrix represents whether the worker has been matched to the job.
......@@ -110,12 +119,15 @@ def _greedy_assignment(adj_matrix):
def _find_augmenting_path(assignment, adj_matrix):
"""Finds an augmenting path given an assignment and an adjacency matrix.
The augmenting path search starts from the unassigned workers, then goes on
to find jobs (via an unassigned pairing), then back again to workers (via an
existing pairing), and so on. The path alternates between unassigned and
existing pairings. Returns the state after the search.
Note: In the state the worker and job, indices are 1-indexed so that we can
use 0 to represent unreachable nodes. State contains the following keys:
- jobs: A [batch_size, 1, num_elems] tensor containing the highest index
unassigned worker that can reach this job through a path.
- jobs_from_worker: A [batch_size, num_elems] tensor containing the worker
......@@ -126,7 +138,9 @@ def _find_augmenting_path(assignment, adj_matrix):
reached immediately before this worker.
- new_jobs: A bool [batch_size, num_elems] tensor containing True if the
unassigned job can be reached via a path.
State can be used to recover the path via backtracking.
Args:
assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker has been matched
......@@ -134,6 +148,7 @@ def _find_augmenting_path(assignment, adj_matrix):
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
matched to the job (column).
Returns:
A state dict, which represents the outcome of running an augmenting
path search on the graph given the assignment.
......@@ -220,12 +235,14 @@ def _find_augmenting_path(assignment, adj_matrix):
def _improve_assignment(assignment, state):
"""Improves an assignment by backtracking the augmented path using state.
Args:
assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker has been matched
to the job. This may be a partial assignment.
state: A dict, which represents the outcome of running an augmenting path
search on the graph given the assignment.
Returns:
A new assignment matrix of the same shape and type as assignment, where the
assignment has been updated using the augmented path found.
......@@ -300,6 +317,7 @@ def _improve_assignment(assignment, state):
def _maximum_bipartite_matching(adj_matrix, assignment=None):
"""Performs maximum bipartite matching using augmented paths.
Args:
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
......@@ -308,6 +326,7 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
where each element of the inner matrix represents whether the worker has
been matched to the job. This may be a partial assignment. If specified,
this assignment will be used to seed the iterative algorithm.
Returns:
A state dict representing the final augmenting path state search, and
a maximum bipartite matching assignment tensor. Note that the state outcome
......@@ -338,9 +357,11 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
def _compute_cover(state, assignment):
"""Computes a cover for the bipartite graph.
We compute a cover using the construction provided at
https://en.wikipedia.org/wiki/K%C5%91nig%27s_theorem_(graph_theory)#Proof
which uses the outcome from the alternating path search.
Args:
state: A state dict, which represents the outcome of running an augmenting
path search on the graph given the assignment.
......@@ -348,6 +369,7 @@ def _compute_cover(state, assignment):
where each element of the inner matrix represents whether the worker has
been matched to the job. This may be a partial assignment. If specified,
this assignment will be used to seed the iterative algorithm.
Returns:
A tuple of (workers_cover, jobs_cover) corresponding to row and column
covers for the bipartite graph. workers_cover is a boolean tensor of shape
......@@ -368,13 +390,16 @@ def _compute_cover(state, assignment):
def _update_weights_using_cover(workers_cover, jobs_cover, weights):
"""Updates weights for hungarian matching using a cover.
We first find the minimum uncovered weight. Then, we subtract this from all
the uncovered weights, and add it to all the doubly covered weights.
Args:
workers_cover: A boolean tensor of shape [batch_size, num_elems, 1].
jobs_cover: A boolean tensor of shape [batch_size, 1, num_elems].
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be use for matching.
Returns:
A new weight matrix with elements adjusted by the cover.
"""
......@@ -398,10 +423,12 @@ def _update_weights_using_cover(workers_cover, jobs_cover, weights):
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
Args:
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
......@@ -422,9 +449,11 @@ def assert_rank(tensor, expected_rank, name=None):
def hungarian_matching(weights):
"""Computes the minimum linear sum assignment using the Hungarian algorithm.
Args:
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be use for matching.
Returns:
A bool [batch_size, num_elems, num_elems] tensor, where each element of the
inner matrix represents whether the worker has been matched to the job.
......@@ -456,4 +485,5 @@ def hungarian_matching(weights):
_update_weights_and_match,
(workers_cover, jobs_cover, weights, assignment),
back_prop=False)
return weights, assignment
\ No newline at end of file
return weights, assignment
......@@ -13,28 +13,20 @@
# limitations under the License.
"""DETR detection task definition."""
from typing import Any, List, Mapping, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.modeling import detr
from official.projects.detr.ops import matchers
from official.vision.evaluation import coco_evaluator
from official.vision.ops import box_ops
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.projects.detr.dataloaders import detr_input
from official.projects.detr.dataloaders import coco
from official.vision.modeling import backbones
@task_factory.register_task_cls(detr_cfg.DetrTask)
@task_factory.register_task_cls(detr_cfg.DetectionConfig)
class DectectionTask(base_task.Task):
"""A single-replica view of training procedure.
......@@ -45,104 +37,46 @@ class DectectionTask(base_task.Task):
def build_model(self):
"""Build DETR model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self._task_config.model.input_size)
backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=self._task_config.model.backbone,
norm_activation_config=self._task_config.model.norm_activation)
model = detr.DETR(
backbone,
self._task_config.model.num_queries,
self._task_config.model.hidden_size,
self._task_config.model.num_classes,
self._task_config.model.num_encoder_layers,
self._task_config.model.num_decoder_layers)
self._task_config.num_queries,
self._task_config.num_hidden,
self._task_config.num_classes,
self._task_config.num_encoder_layers,
self._task_config.num_decoder_layers)
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
if not self._task_config.init_checkpoint:
return
ckpt_dir_or_file = self._task_config.init_checkpoint
# Restoring checkpoint.
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if self._task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self._task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self,
params,
input_context: Optional[tf.distribute.InputContext] = None):
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.read(self._task_config.init_ckpt)
status.expect_partial().assert_existing_objects_matched()
def build_inputs(self, params, input_context=None):
"""Build input dataset."""
if type(params) is coco.COCODataConfig:
dataset = coco.COCODataLoader(params).load(input_context)
else:
if params.tfds_name:
decoder = tfds_factory.get_detection_decoder(params.tfds_name)
else:
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
parser = detr_input.Parser(
class_offset=self._task_config.losses.class_offset,
output_size=self._task_config.model.input_size[:2],
)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
return coco.COCODataLoader(params).load(input_context)
def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
# Approximate classification cost with 1 - prob[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
# background: 0
cls_cost = self._task_config.losses.lambda_cls * tf.gather(
cls_cost = self._task_config.lambda_cls * tf.gather(
-tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)
# Compute the L1 cost between boxes,
paired_differences = self._task_config.losses.lambda_box * tf.abs(
paired_differences = self._task_config.lambda_box * tf.abs(
tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
box_cost = tf.reduce_sum(paired_differences, axis=-1)
# Compute the giou cost betwen boxes
giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
giou_cost = self._task_config.lambda_giou * -box_ops.bbox_generalized_overlap(
box_ops.cycxhw_to_yxyx(box_outputs),
box_ops.cycxhw_to_yxyx(box_targets))
total_cost = cls_cost + box_cost + giou_cost
max_cost = (
self._task_config.losses.lambda_cls * 0.0 + self._task_config.losses.lambda_box * 4. +
self._task_config.losses.lambda_giou * 0.0)
self._task_config.lambda_cls * 0.0 + self._task_config.lambda_box * 4. +
self._task_config.lambda_giou * 0.0)
# Set pads to large constant
valid = tf.expand_dims(
......@@ -181,20 +115,20 @@ class DectectionTask(base_task.Task):
# Down-weight background to account for class imbalance.
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=cls_targets, logits=cls_assigned)
cls_loss = self._task_config.losses.lambda_cls * tf.where(
cls_loss = self._task_config.lambda_cls * tf.where(
background,
self._task_config.losses.background_cls_weight * xentropy,
self._task_config.background_cls_weight * xentropy,
xentropy
)
cls_weights = tf.where(
background,
self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
self._task_config.background_cls_weight * tf.ones_like(cls_loss),
tf.ones_like(cls_loss)
)
# Box loss is only calculated on non-background class.
l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
box_loss = self._task_config.losses.lambda_box * tf.where(
box_loss = self._task_config.lambda_box * tf.where(
background,
tf.zeros_like(l_1),
l_1
......@@ -205,7 +139,7 @@ class DectectionTask(base_task.Task):
box_ops.cycxhw_to_yxyx(box_assigned),
box_ops.cycxhw_to_yxyx(box_targets)
))
giou_loss = self._task_config.losses.lambda_giou * tf.where(
giou_loss = self._task_config.lambda_giou * tf.where(
background,
tf.zeros_like(giou),
giou
......@@ -226,7 +160,6 @@ class DectectionTask(base_task.Task):
tf.reduce_sum(giou_loss), num_boxes_sum)
aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
total_loss = cls_loss + box_loss + giou_loss + aux_losses
return total_loss, cls_loss, box_loss, giou_loss
......@@ -239,7 +172,7 @@ class DectectionTask(base_task.Task):
if not training:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self._task_config.annotation_file,
annotation_file='',
include_mask=False,
need_rescale_bboxes=True,
per_category_metrics=self._task_config.per_category_metrics)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment