"vscode:/vscode.git/clone" did not exist on "eb90d3be139cbb353e443460ee13f8fabe098cfb"
Commit 94220a58 authored by Gunho Park

TPU compatible

parent a5bbb547
......@@ -18,20 +18,55 @@ import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco
import os
from official.vision.configs import common
# pylint: disable=missing-class-docstring
# Keep for backward compatibility.
@dataclasses.dataclass
class TfExampleDecoder(common.TfExampleDecoder):
"""A simple TF Example decoder config."""
# Keep for backward compatibility.
@dataclasses.dataclass
class TfExampleDecoderLabelMap(common.TfExampleDecoderLabelMap):
"""TF Example decoder with label map config."""
# Keep for backward compatibility.
@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
"""Data decoder config."""
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
input_path: str = ''
global_batch_size: int = 0
is_training: bool = False
dtype: str = 'bfloat16'
decoder: common.DataDecoder = common.DataDecoder()
#parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
@dataclasses.dataclass
class DetectionConfig(cfg.TaskConfig):
"""The translation task config."""
annotation_file: str = ''
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
lambda_cls: float = 1.0
lambda_box: float = 5.0
lambda_giou: float = 2.0
init_ckpt: str = ''
num_classes: int = 81 # 0: background
#init_ckpt: str = ''
init_checkpoint: str = 'gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet'
init_checkpoint_modules: str = 'backbone'
#num_classes: int = 81 # 0: background
num_classes: int = 91 # 0: background
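# COCO category ids are sparse in [1, 90], so 91 classes cover the raw ids plus
# the background class (0) without remapping labels.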
background_cls_weight: float = 0.1
num_encoder_layers: int = 6
num_decoder_layers: int = 6
......@@ -41,40 +76,44 @@ class DetectionConfig(cfg.TaskConfig):
num_hidden: int = 256
per_category_metrics: bool = False
COCO_INPUT_PATH_BASE = 'gs://ghpark-tfrecords/coco'
#COCO_TRAIN_EXAMPLES = 118287
COCO_TRAIN_EXAMPLES = 960
COCO_VAL_EXAMPLES = 5000
@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
"""Config to get results that matches the paper."""
train_batch_size = 64
train_batch_size = 32
eval_batch_size = 64
num_train_data = 118287
num_steps_per_epoch = num_train_data // train_batch_size
train_steps = 500 * num_steps_per_epoch # 500 epochs
decay_at = train_steps - 100 * num_steps_per_epoch # 400 epochs
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
train_steps = 300 * steps_per_epoch # 300 epochs
decay_at = train_steps - 100 * steps_per_epoch # 200 epochs
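# With COCO_TRAIN_EXAMPLES = 960 and train_batch_size = 32 as set above, this
# gives steps_per_epoch = 30, train_steps = 9000 and decay_at = 6000.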
config = cfg.ExperimentConfig(
task=DetectionConfig(
train_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='train',
annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
'instances_val2017.json'),
train_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
shuffle_buffer_size=1000,
),
validation_data=coco.COCODataConfig(
tfds_name='coco/2017',
tfds_split='validation',
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
drop_remainder=False
drop_remainder=False,
)
),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=-1,
steps_per_loop=10000,
summary_interval=10000,
checkpoint_interval=10000,
validation_interval=10000,
validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
validation_interval=5*steps_per_epoch,
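# With steps_per_epoch = 30, summaries and checkpoints are written every epoch
# and evaluation runs every 5 epochs.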
max_to_keep=1,
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='AP',
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
from typing import Optional, Tuple
import tensorflow as tf
from official.vision.dataloaders import parser
from official.vision.dataloaders import utils
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.core import input_reader
RESIZE_SCALES = (
480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
class Parser(parser.Parser):
"""Parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
output_size: Tuple[int, int] = (1333, 1333),
max_num_boxes: int = 100,
resize_scales: Tuple[int, ...] = RESIZE_SCALES,
aug_rand_hflip=True):
self._output_size = output_size
self._max_num_boxes = max_num_boxes
self._resize_scales = resize_scales
self._aug_rand_hflip = aug_rand_hflip
def _parse_train_data(self, data):
"""Parses data for training and evaluation."""
#classes = data['groundtruth_classes'] + 1
classes = data['groundtruth_classes']
boxes = data['groundtruth_boxes']
# If not empty, `attributes` is a dict of (name, ground_truth) pairs.
# `ground_truth` of attributes is assumed in shape [N, attribute_size].
# TODO(xianzhi): support parsing attributes weights.
attributes = data.get('groundtruth_attributes', {})
is_crowd = data['groundtruth_is_crowd']
# Gets original image.
image = data['image']
# Apply autoaug or randaug.
#if self._augmenter is not None:
# image, boxes = self._augmenter.distort_with_boxes(image, boxes)
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)
do_crop = tf.greater(tf.random.uniform([]), 0.5)
if do_crop:
# Rescale
boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(image, short_side)
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Do cropping
shape = tf.cast(image_info[1], dtype=tf.int32)
h = tf.random.uniform(
[], 384, tf.math.minimum(shape[0], 600), dtype=tf.int32)
w = tf.random.uniform(
[], 384, tf.math.minimum(shape[1], 600), dtype=tf.int32)
i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
image = tf.image.crop_to_bounding_box(image, i, j, h, w)
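# Re-express the normalized boxes in the cropped window: scale to pixel
# coordinates, shift by the crop offset (i, j), then re-normalize by the crop
# size (h, w) and clip to [0, 1].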
boxes = tf.clip_by_value(
(boxes[..., :] * tf.cast(
tf.stack([shape[0], shape[1], shape[0], shape[1]]),
dtype=tf.float32) -
tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
scales = tf.constant(
self._resize_scales,
dtype=tf.float32)
index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
scales = tf.gather(scales, index, axis=0)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
gt_boxes = boxes
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(
image,
short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(
image, 0, 0, self._output_size[0], self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(
classes, self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
boxes, self._max_num_boxes)
}
return image, labels
def _parse_eval_data(self, data):
"""Parses data for training and evaluation."""
groundtruths = {}
classes = data['groundtruth_classes']
boxes = data['groundtruth_boxes']
# If not empty, `attributes` is a dict of (name, ground_truth) pairs.
# `ground_truth` of attributes is assumed in shape [N, attribute_size].
# TODO(xianzhi): support parsing attributes weights.
attributes = data.get('groundtruth_attributes', {})
is_crowd = data['groundtruth_is_crowd']
# Gets original image and its size.
image = data['image']
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image)
scales = tf.constant([self._resize_scales[-1]], tf.float32)
image_shape = tf.shape(image)[:2]
boxes = box_ops.denormalize_boxes(boxes, image_shape)
gt_boxes = boxes
short_side = scales[0]
image, image_info = preprocess_ops.resize_image(
image,
short_side,
max(self._output_size))
boxes = preprocess_ops.resize_and_crop_boxes(boxes,
image_info[2, :],
image_info[1, :],
image_info[3, :])
boxes = box_ops.normalize_boxes(boxes, image_info[1, :])
# Filters out ground truth boxes that are all zeros.
indices = box_ops.get_non_empty_box_indices(boxes)
boxes = tf.gather(boxes, indices)
classes = tf.gather(classes, indices)
is_crowd = tf.gather(is_crowd, indices)
boxes = box_ops.yxyx_to_cycxhw(boxes)
image = tf.image.pad_to_bounding_box(
image, 0, 0, self._output_size[0], self._output_size[1])
labels = {
'classes':
preprocess_ops.clip_or_pad_to_fixed_size(
classes, self._max_num_boxes),
'boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
boxes, self._max_num_boxes)
}
labels.update({
'id':
# 'source_id' from the TF Example decoder is a string tensor; convert it to
# an int64 id for the evaluator.
utils.process_source_id(data['source_id']),
'image_info':
image_info,
'is_crowd':
preprocess_ops.clip_or_pad_to_fixed_size(
is_crowd, self._max_num_boxes),
'gt_boxes':
preprocess_ops.clip_or_pad_to_fixed_size(
gt_boxes, self._max_num_boxes),
})
return image, labels
\ No newline at end of file
#!/bin/bash
python3 train.py \
--experiment=detr_coco \
--mode=train_and_eval \
--model_dir=gs://ghpark-ckpts/detr/detr_coco/ckpt_03_test \
--tpu=postech-tpu \
--params_override=runtime.distribution_strategy='tpu'
\ No newline at end of file
......@@ -13,17 +13,14 @@
# limitations under the License.
"""Tensorflow implementation to solve the Linear Sum Assignment problem.
The Linear Sum Assignment problem involves determining the minimum weight
matching for bipartite graphs. For example, this problem can be defined by
a 2D matrix C, where each element i,j determines the cost of matching worker i
with job j. The solution to the problem is a complete assignment of jobs to
workers, such that no job is assigned to more than one worker and no worker is
assigned more than one job, with minimum cost.
This implementation builds off of the Hungarian
Matching Algorithm (https://www.cse.ust.hk/~golin/COMP572/Notes/Matching.pdf).
Based on the original implementation by Jiquan Ngiam <jngiam@google.com>.
"""
import tensorflow as tf
......@@ -32,17 +29,14 @@ from official.modeling import tf_utils
def _prepare(weights):
"""Prepare the cost matrix.
To improve the computational efficiency of the algorithm, all weights are shifted
to be non-negative. Each element is reduced by the row / column minimum. Note
that neither operation will affect the resulting solution but will provide
a better starting point for the greedy assignment. Note this corresponds to
the pre-processing and step 1 of the Hungarian algorithm from Wikipedia.
Args:
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be used for matching.
Returns:
A prepared weights tensor of the same shape and dtype.
"""
......@@ -55,18 +49,15 @@ def _prepare(weights):
def _greedy_assignment(adj_matrix):
"""Greedily assigns workers to jobs based on an adjaceny matrix.
Starting with an adjacency matrix representing the available connections
in the bi-partite graph, this function greedily chooses elements such
that each worker is matched to at most one job (or each job is assigned to
at most one worker). Note, if the adjacency matrix has no available values
for a particular row/column, the corresponding job/worker may go unassigned.
Args:
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
matched to the job (column).
Returns:
A bool [batch_size, num_elems, num_elems] tensor, where each element of the
inner matrix represents whether the worker has been matched to the job.
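# A small worked example (illustrative values; assumes the greedy pass gives
# each worker the first still-available job):
#
#   adj_matrix = [[True, True ],   worker 0 takes job 0; worker 1's only edge
#                 [True, False]]   (job 0) is taken, so it stays unassigned.
#                                  The augmenting-path search below then moves
#                                  worker 0 to job 1 and frees job 0 for
#                                  worker 1, completing the matching.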
......@@ -119,15 +110,12 @@ def _greedy_assignment(adj_matrix):
def _find_augmenting_path(assignment, adj_matrix):
"""Finds an augmenting path given an assignment and an adjacency matrix.
The augmenting path search starts from the unassigned workers, then goes on
to find jobs (via an unassigned pairing), then back again to workers (via an
existing pairing), and so on. The path alternates between unassigned and
existing pairings. Returns the state after the search.
Note: In the state, the worker and job indices are 1-indexed so that we can
use 0 to represent unreachable nodes. State contains the following keys:
- jobs: A [batch_size, 1, num_elems] tensor containing the highest index
unassigned worker that can reach this job through a path.
- jobs_from_worker: A [batch_size, num_elems] tensor containing the worker
......@@ -138,9 +126,7 @@ def _find_augmenting_path(assignment, adj_matrix):
reached immediately before this worker.
- new_jobs: A bool [batch_size, num_elems] tensor containing True if the
unassigned job can be reached via a path.
State can be used to recover the path via backtracking.
Args:
assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker has been matched
......@@ -148,7 +134,6 @@ def _find_augmenting_path(assignment, adj_matrix):
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
matched to the job (column).
Returns:
A state dict, which represents the outcome of running an augmenting
path search on the graph given the assignment.
......@@ -235,14 +220,12 @@ def _find_augmenting_path(assignment, adj_matrix):
def _improve_assignment(assignment, state):
"""Improves an assignment by backtracking the augmented path using state.
Args:
assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker has been matched
to the job. This may be a partial assignment.
state: A dict, which represents the outcome of running an augmenting path
search on the graph given the assignment.
Returns:
A new assignment matrix of the same shape and type as assignment, where the
assignment has been updated using the augmented path found.
......@@ -317,7 +300,6 @@ def _improve_assignment(assignment, state):
def _maximum_bipartite_matching(adj_matrix, assignment=None):
"""Performs maximum bipartite matching using augmented paths.
Args:
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
......@@ -326,7 +308,6 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
where each element of the inner matrix represents whether the worker has
been matched to the job. This may be a partial assignment. If specified,
this assignment will be used to seed the iterative algorithm.
Returns:
A state dict representing the final augmenting path state search, and
a maximum bipartite matching assignment tensor. Note that the state outcome
......@@ -357,11 +338,9 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
def _compute_cover(state, assignment):
"""Computes a cover for the bipartite graph.
We compute a cover using the construction provided at
https://en.wikipedia.org/wiki/K%C5%91nig%27s_theorem_(graph_theory)#Proof
which uses the outcome from the alternating path search.
Args:
state: A state dict, which represents the outcome of running an augmenting
path search on the graph given the assignment.
......@@ -369,7 +348,6 @@ def _compute_cover(state, assignment):
where each element of the inner matrix represents whether the worker has
been matched to the job. This may be a partial assignment. If specified,
this assignment will be used to seed the iterative algorithm.
Returns:
A tuple of (workers_cover, jobs_cover) corresponding to row and column
covers for the bipartite graph. workers_cover is a boolean tensor of shape
......@@ -390,16 +368,13 @@ def _compute_cover(state, assignment):
def _update_weights_using_cover(workers_cover, jobs_cover, weights):
"""Updates weights for hungarian matching using a cover.
We first find the minimum uncovered weight. Then, we subtract this from all
the uncovered weights, and add it to all the doubly covered weights.
Args:
workers_cover: A boolean tensor of shape [batch_size, num_elems, 1].
jobs_cover: A boolean tensor of shape [batch_size, 1, num_elems].
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be used for matching.
Returns:
A new weight matrix with elements adjusted by the cover.
"""
......@@ -423,12 +398,10 @@ def _update_weights_using_cover(workers_cover, jobs_cover, weights):
def assert_rank(tensor, expected_rank, name=None):
"""Raises an exception if the tensor rank is not of the expected rank.
Args:
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected rank doesn't match the actual rank.
"""
......@@ -449,11 +422,9 @@ def assert_rank(tensor, expected_rank, name=None):
def hungarian_matching(weights):
"""Computes the minimum linear sum assignment using the Hungarian algorithm.
Args:
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be used for matching.
Returns:
A bool [batch_size, num_elems, num_elems] tensor, where each element of the
inner matrix represents whether the worker has been matched to the job.
......@@ -485,5 +456,4 @@ def hungarian_matching(weights):
_update_weights_and_match,
(workers_cover, jobs_cover, weights, assignment),
back_prop=False)
return weights, assignment
return weights, assignment
\ No newline at end of file
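A minimal usage sketch of hungarian_matching on a single small cost matrix,
assuming the function returns the (weights, assignment) pair shown above:

import tensorflow as tf

from official.projects.detr.ops import matchers

# One 3x3 cost matrix; entry [i, j] is the cost of matching worker i to job j.
weights = tf.constant([[[4., 1., 3.],
                        [2., 0., 5.],
                        [3., 2., 2.]]])
_, assignment = matchers.hungarian_matching(weights)
# assignment is a bool [1, 3, 3] tensor selecting (0, 1), (1, 0) and (2, 2),
# the minimum-cost matching with total cost 1 + 2 + 2 = 5.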
......@@ -13,18 +13,24 @@
# limitations under the License.
"""DETR detection task definition."""
from typing import Any, List, Mapping, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.modeling import detr
from official.projects.detr.ops import matchers
from official.vision.evaluation import coco_evaluator
from official.vision.ops import box_ops
from official.vision.dataloaders import input_reader_factory
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.projects.detr.dataloaders import detr_input
@task_factory.register_task_cls(detr_cfg.DetectionConfig)
class DectectionTask(base_task.Task):
......@@ -47,13 +53,62 @@ class DectectionTask(base_task.Task):
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.read(self._task_config.init_ckpt)
status.expect_partial().assert_existing_objects_matched()
def build_inputs(self, params, input_context=None):
if not self._task_config.init_checkpoint:
return
ckpt_dir_or_file = self._task_config.init_checkpoint
# Restoring checkpoint.
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if self._task_config.init_checkpoint_modules == 'all':
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
elif self._task_config.init_checkpoint_modules == 'backbone':
ckpt = tf.train.Checkpoint(backbone=model.backbone)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
"""def build_inputs(self,
params: detr_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
return coco.COCODataLoader(params).load(input_context)"""
def build_inputs(self,
params,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
return coco.COCODataLoader(params).load(input_context)
if params.tfds_name:
decoder = tfds_factory.get_detection_decoder(params.tfds_name)
else:
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
decoder = tf_example_decoder.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id)
elif params.decoder.type == 'label_map_decoder':
decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
label_map=decoder_cfg.label_map,
regenerate_source_id=decoder_cfg.regenerate_source_id)
else:
raise ValueError('Unknown decoder type: {}!'.format(
params.decoder.type))
parser = detr_input.Parser()
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
# Approximate classification cost with 1 - prob[target class].
......@@ -160,6 +215,7 @@ class DectectionTask(base_task.Task):
tf.reduce_sum(giou_loss), num_boxes_sum)
aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
total_loss = cls_loss + box_loss + giou_loss + aux_losses
return total_loss, cls_loss, box_loss, giou_loss
......@@ -172,7 +228,7 @@ class DectectionTask(base_task.Task):
if not training:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file='',
annotation_file=self._task_config.annotation_file,
include_mask=False,
need_rescale_bboxes=True,
per_category_metrics=self._task_config.per_category_metrics)
......