Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
94220a58
"vscode:/vscode.git/clone" did not exist on "eb90d3be139cbb353e443460ee13f8fabe098cfb"
Commit
94220a58
authored
Jun 25, 2022
by
Gunho Park
Browse files
TPU compatible
parent
a5bbb547
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
331 additions
and
59 deletions
+331
-59
official/projects/detr/configs/detr.py
official/projects/detr/configs/detr.py
+58
-19
official/projects/detr/dataloaders/detr_input.py
official/projects/detr/dataloaders/detr_input.py
+200
-0
official/projects/detr/do_train.sh
official/projects/detr/do_train.sh
+7
-0
official/projects/detr/ops/matchers.py
official/projects/detr/ops/matchers.py
+1
-31
official/projects/detr/tasks/detection.py
official/projects/detr/tasks/detection.py
+65
-9
No files found.
official/projects/detr/configs/detr.py
View file @
94220a58
...
...
@@ -18,20 +18,55 @@ import dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.projects.detr
import
optimization
from
official.projects.detr.dataloaders
import
coco
import
os
from
official.vision.configs
import
common
# pylint: disable=missing-class-docstring
# Keep for backward compatibility.
@
dataclasses
.
dataclass
class
TfExampleDecoder
(
common
.
TfExampleDecoder
):
"""A simple TF Example decoder config."""
# Keep for backward compatibility.
@
dataclasses
.
dataclass
class
TfExampleDecoderLabelMap
(
common
.
TfExampleDecoderLabelMap
):
"""TF Example decoder with label map config."""
# Keep for backward compatibility.
@
dataclasses
.
dataclass
class
DataDecoder
(
common
.
DataDecoder
):
"""Data decoder config."""
@
dataclasses
.
dataclass
class
DataConfig
(
cfg
.
DataConfig
):
"""Input config for training."""
input_path
:
str
=
''
global_batch_size
:
int
=
0
is_training
:
bool
=
False
dtype
:
str
=
'bfloat16'
decoder
:
common
.
DataDecoder
=
common
.
DataDecoder
()
#parser: Parser = Parser()
shuffle_buffer_size
:
int
=
10000
file_type
:
str
=
'tfrecord'
@
dataclasses
.
dataclass
class
DetectionConfig
(
cfg
.
TaskConfig
):
"""The translation task config."""
annotation_file
:
str
=
''
train_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
validation_data
:
cfg
.
DataConfig
=
cfg
.
DataConfig
()
lambda_cls
:
float
=
1.0
lambda_box
:
float
=
5.0
lambda_giou
:
float
=
2.0
init_ckpt
:
str
=
''
num_classes
:
int
=
81
# 0: background
#init_ckpt: str = ''
init_checkpoint
:
str
=
'gs://ghpark-imagenet-tfrecord/ckpt/resnet50_imagenet'
init_checkpoint_modules
:
str
=
'backbone'
#num_classes: int = 81 # 0: background
num_classes
:
int
=
91
# 0: background
background_cls_weight
:
float
=
0.1
num_encoder_layers
:
int
=
6
num_decoder_layers
:
int
=
6
...
...
@@ -41,40 +76,44 @@ class DetectionConfig(cfg.TaskConfig):
num_hidden
:
int
=
256
per_category_metrics
:
bool
=
False
COCO_INPUT_PATH_BASE
=
'gs://ghpark-tfrecords/coco'
#COCO_TRAIN_EXAMPLES = 118287
COCO_TRAIN_EXAMPLES
=
960
COCO_VAL_EXAMPLES
=
5000
@
exp_factory
.
register_config_factory
(
'detr_coco'
)
def
detr_coco
()
->
cfg
.
ExperimentConfig
:
"""Config to get results that matches the paper."""
train_batch_size
=
64
train_batch_size
=
32
eval_batch_size
=
64
num_train_data
=
118287
num_
steps_per_epoch
=
num_train_data
//
train_batch_size
train_steps
=
5
00
*
num_
steps_per_epoch
# 500 epochs
decay_at
=
train_steps
-
100
*
num_
steps_per_epoch
# 400 epochs
steps_per_epoch
=
COCO_TRAIN_EXAMPLES
//
train_batch_size
train_steps
=
3
00
*
steps_per_epoch
# 500 epochs
decay_at
=
train_steps
-
100
*
steps_per_epoch
# 400 epochs
config
=
cfg
.
ExperimentConfig
(
task
=
DetectionConfig
(
train_data
=
coco
.
COCODataConfig
(
tfds_name
=
'coco/2017'
,
tfds_split
=
'train'
,
annotation_file
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'instances_val2017.json'
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
shuffle_buffer_size
=
1000
,
),
validation_data
=
coco
.
COCODataConfig
(
tfds_name
=
'coco/2017'
,
tfds_split
=
'validation'
,
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
COCO_INPUT_PATH_BASE
,
'val*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
,
drop_remainder
=
False
drop_remainder
=
False
,
)
),
trainer
=
cfg
.
TrainerConfig
(
train_steps
=
train_steps
,
validation_steps
=
-
1
,
steps_per_loop
=
10000
,
summary_interval
=
10000
,
checkpoint_interval
=
10000
,
validation_interval
=
10000
,
validation_steps
=
COCO_VAL_EXAMPLES
//
eval_batch_size
,
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
validation_interval
=
5
*
steps_per_epoch
,
max_to_keep
=
1
,
best_checkpoint_export_subdir
=
'best_ckpt'
,
best_checkpoint_eval_metric
=
'AP'
,
...
...
official/projects/detr/dataloaders/detr_input.py
0 → 100644
View file @
94220a58
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
from
typing
import
Optional
,
Tuple
import
tensorflow
as
tf
from
official.vision.dataloaders
import
parser
from
official.vision.dataloaders
import
utils
from
official.vision.ops
import
box_ops
from
official.vision.ops
import
preprocess_ops
from
official.core
import
input_reader
RESIZE_SCALES
=
(
480
,
512
,
544
,
576
,
608
,
640
,
672
,
704
,
736
,
768
,
800
)
class
Parser
(
parser
.
Parser
):
"""Parse an image and its annotations into a dictionary of tensors."""
def
__init__
(
self
,
output_size
:
Tuple
[
int
,
int
]
=
(
1333
,
1333
),
max_num_boxes
:
int
=
100
,
resize_scales
:
Tuple
[
int
,
...]
=
RESIZE_SCALES
,
aug_rand_hflip
=
True
):
self
.
_output_size
=
output_size
self
.
_max_num_boxes
=
max_num_boxes
self
.
_resize_scales
=
resize_scales
self
.
_aug_rand_hflip
=
aug_rand_hflip
def
_parse_train_data
(
self
,
data
):
"""Parses data for training and evaluation."""
#classes = data['groundtruth_classes'] + 1
classes
=
data
[
'groundtruth_classes'
]
boxes
=
data
[
'groundtruth_boxes'
]
# If not empty, `attributes` is a dict of (name, ground_truth) pairs.
# `ground_gruth` of attributes is assumed in shape [N, attribute_size].
# TODO(xianzhi): support parsing attributes weights.
attributes
=
data
.
get
(
'groundtruth_attributes'
,
{})
is_crowd
=
data
[
'groundtruth_is_crowd'
]
# Gets original image.
image
=
data
[
'image'
]
# Apply autoaug or randaug.
#if self._augmenter is not None:
# image, boxes = self._augmenter.distort_with_boxes(image, boxes)
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
)
image
,
boxes
,
_
=
preprocess_ops
.
random_horizontal_flip
(
image
,
boxes
)
do_crop
=
tf
.
greater
(
tf
.
random
.
uniform
([]),
0.5
)
if
do_crop
:
# Rescale
boxes
=
box_ops
.
denormalize_boxes
(
boxes
,
tf
.
shape
(
image
)[:
2
])
index
=
tf
.
random
.
categorical
(
tf
.
zeros
([
1
,
3
]),
1
)[
0
]
scales
=
tf
.
gather
([
400.0
,
500.0
,
600.0
],
index
,
axis
=
0
)
short_side
=
scales
[
0
]
image
,
image_info
=
preprocess_ops
.
resize_image
(
image
,
short_side
)
boxes
=
preprocess_ops
.
resize_and_crop_boxes
(
boxes
,
image_info
[
2
,
:],
image_info
[
1
,
:],
image_info
[
3
,
:])
boxes
=
box_ops
.
normalize_boxes
(
boxes
,
image_info
[
1
,
:])
# Do croping
shape
=
tf
.
cast
(
image_info
[
1
],
dtype
=
tf
.
int32
)
h
=
tf
.
random
.
uniform
(
[],
384
,
tf
.
math
.
minimum
(
shape
[
0
],
600
),
dtype
=
tf
.
int32
)
w
=
tf
.
random
.
uniform
(
[],
384
,
tf
.
math
.
minimum
(
shape
[
1
],
600
),
dtype
=
tf
.
int32
)
i
=
tf
.
random
.
uniform
([],
0
,
shape
[
0
]
-
h
+
1
,
dtype
=
tf
.
int32
)
j
=
tf
.
random
.
uniform
([],
0
,
shape
[
1
]
-
w
+
1
,
dtype
=
tf
.
int32
)
image
=
tf
.
image
.
crop_to_bounding_box
(
image
,
i
,
j
,
h
,
w
)
boxes
=
tf
.
clip_by_value
(
(
boxes
[...,
:]
*
tf
.
cast
(
tf
.
stack
([
shape
[
0
],
shape
[
1
],
shape
[
0
],
shape
[
1
]]),
dtype
=
tf
.
float32
)
-
tf
.
cast
(
tf
.
stack
([
i
,
j
,
i
,
j
]),
dtype
=
tf
.
float32
))
/
tf
.
cast
(
tf
.
stack
([
h
,
w
,
h
,
w
]),
dtype
=
tf
.
float32
),
0.0
,
1.0
)
scales
=
tf
.
constant
(
self
.
_resize_scales
,
dtype
=
tf
.
float32
)
index
=
tf
.
random
.
categorical
(
tf
.
zeros
([
1
,
11
]),
1
)[
0
]
scales
=
tf
.
gather
(
scales
,
index
,
axis
=
0
)
image_shape
=
tf
.
shape
(
image
)[:
2
]
boxes
=
box_ops
.
denormalize_boxes
(
boxes
,
image_shape
)
gt_boxes
=
boxes
short_side
=
scales
[
0
]
image
,
image_info
=
preprocess_ops
.
resize_image
(
image
,
short_side
,
max
(
self
.
_output_size
))
boxes
=
preprocess_ops
.
resize_and_crop_boxes
(
boxes
,
image_info
[
2
,
:],
image_info
[
1
,
:],
image_info
[
3
,
:])
boxes
=
box_ops
.
normalize_boxes
(
boxes
,
image_info
[
1
,
:])
# Filters out ground truth boxes that are all zeros.
indices
=
box_ops
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
is_crowd
=
tf
.
gather
(
is_crowd
,
indices
)
boxes
=
box_ops
.
yxyx_to_cycxhw
(
boxes
)
image
=
tf
.
image
.
pad_to_bounding_box
(
image
,
0
,
0
,
self
.
_output_size
[
0
],
self
.
_output_size
[
1
])
labels
=
{
'classes'
:
preprocess_ops
.
clip_or_pad_to_fixed_size
(
classes
,
self
.
_max_num_boxes
),
'boxes'
:
preprocess_ops
.
clip_or_pad_to_fixed_size
(
boxes
,
self
.
_max_num_boxes
)
}
return
image
,
labels
def
_parse_eval_data
(
self
,
data
):
"""Parses data for training and evaluation."""
groundtruths
=
{}
classes
=
data
[
'groundtruth_classes'
]
boxes
=
data
[
'groundtruth_boxes'
]
# If not empty, `attributes` is a dict of (name, ground_truth) pairs.
# `ground_gruth` of attributes is assumed in shape [N, attribute_size].
# TODO(xianzhi): support parsing attributes weights.
attributes
=
data
.
get
(
'groundtruth_attributes'
,
{})
is_crowd
=
data
[
'groundtruth_is_crowd'
]
# Gets original image and its size.
image
=
data
[
'image'
]
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
)
scales
=
tf
.
constant
([
self
.
_resize_scales
[
-
1
]],
tf
.
float32
)
image_shape
=
tf
.
shape
(
image
)[:
2
]
boxes
=
box_ops
.
denormalize_boxes
(
boxes
,
image_shape
)
gt_boxes
=
boxes
short_side
=
scales
[
0
]
image
,
image_info
=
preprocess_ops
.
resize_image
(
image
,
short_side
,
max
(
self
.
_output_size
))
boxes
=
preprocess_ops
.
resize_and_crop_boxes
(
boxes
,
image_info
[
2
,
:],
image_info
[
1
,
:],
image_info
[
3
,
:])
boxes
=
box_ops
.
normalize_boxes
(
boxes
,
image_info
[
1
,
:])
# Filters out ground truth boxes that are all zeros.
indices
=
box_ops
.
get_non_empty_box_indices
(
boxes
)
boxes
=
tf
.
gather
(
boxes
,
indices
)
classes
=
tf
.
gather
(
classes
,
indices
)
is_crowd
=
tf
.
gather
(
is_crowd
,
indices
)
boxes
=
box_ops
.
yxyx_to_cycxhw
(
boxes
)
image
=
tf
.
image
.
pad_to_bounding_box
(
image
,
0
,
0
,
self
.
_output_size
[
0
],
self
.
_output_size
[
1
])
labels
=
{
'classes'
:
preprocess_ops
.
clip_or_pad_to_fixed_size
(
classes
,
self
.
_max_num_boxes
),
'boxes'
:
preprocess_ops
.
clip_or_pad_to_fixed_size
(
boxes
,
self
.
_max_num_boxes
)
}
labels
.
update
({
'id'
:
int
(
data
[
'source_id'
]),
'image_info'
:
image_info
,
'is_crowd'
:
preprocess_ops
.
clip_or_pad_to_fixed_size
(
is_crowd
,
self
.
_max_num_boxes
),
'gt_boxes'
:
preprocess_ops
.
clip_or_pad_to_fixed_size
(
gt_boxes
,
self
.
_max_num_boxes
),
})
return
image
,
labels
\ No newline at end of file
official/projects/detr/do_train.sh
0 → 100644
View file @
94220a58
#!/bin/bash
python3 train.py
\
--experiment
=
detr_coco
\
--mode
=
train_and_eval
\
--model_dir
=
gs://ghpark-ckpts/detr/detr_coco/ckpt_03_test
\
--tpu
=
postech-tpu
\
--params_override
=
runtime.distribution_strategy
=
'tpu'
\ No newline at end of file
official/projects/detr/ops/matchers.py
View file @
94220a58
...
...
@@ -13,17 +13,14 @@
# limitations under the License.
"""Tensorflow implementation to solve the Linear Sum Assignment problem.
The Linear Sum Assignment problem involves determining the minimum weight
matching for bipartite graphs. For example, this problem can be defined by
a 2D matrix C, where each element i,j determines the cost of matching worker i
with job j. The solution to the problem is a complete assignment of jobs to
workers, such that no job is assigned to more than one work and no worker is
assigned more than one job, with minimum cost.
This implementation builds off of the Hungarian
Matching Algorithm (https://www.cse.ust.hk/~golin/COMP572/Notes/Matching.pdf).
Based on the original implementation by Jiquan Ngiam <jngiam@google.com>.
"""
import
tensorflow
as
tf
...
...
@@ -32,17 +29,14 @@ from official.modeling import tf_utils
def
_prepare
(
weights
):
"""Prepare the cost matrix.
To speed up computational efficiency of the algorithm, all weights are shifted
to be non-negative. Each element is reduced by the row / column minimum. Note
that neither operation will effect the resulting solution but will provide
a better starting point for the greedy assignment. Note this corresponds to
the pre-processing and step 1 of the Hungarian algorithm from Wikipedia.
Args:
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be use for matching.
Returns:
A prepared weights tensor of the same shape and dtype.
"""
...
...
@@ -55,18 +49,15 @@ def _prepare(weights):
def
_greedy_assignment
(
adj_matrix
):
"""Greedily assigns workers to jobs based on an adjaceny matrix.
Starting with an adjacency matrix representing the available connections
in the bi-partite graph, this function greedily chooses elements such
that each worker is matched to at most one job (or each job is assigned to
at most one worker). Note, if the adjacency matrix has no available values
for a particular row/column, the corresponding job/worker may go unassigned.
Args:
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
matched to the job (column).
Returns:
A bool [batch_size, num_elems, num_elems] tensor, where each element of the
inner matrix represents whether the worker has been matched to the job.
...
...
@@ -119,15 +110,12 @@ def _greedy_assignment(adj_matrix):
def
_find_augmenting_path
(
assignment
,
adj_matrix
):
"""Finds an augmenting path given an assignment and an adjacency matrix.
The augmenting path search starts from the unassigned workers, then goes on
to find jobs (via an unassigned pairing), then back again to workers (via an
existing pairing), and so on. The path alternates between unassigned and
existing pairings. Returns the state after the search.
Note: In the state the worker and job, indices are 1-indexed so that we can
use 0 to represent unreachable nodes. State contains the following keys:
- jobs: A [batch_size, 1, num_elems] tensor containing the highest index
unassigned worker that can reach this job through a path.
- jobs_from_worker: A [batch_size, num_elems] tensor containing the worker
...
...
@@ -138,9 +126,7 @@ def _find_augmenting_path(assignment, adj_matrix):
reached immediately before this worker.
- new_jobs: A bool [batch_size, num_elems] tensor containing True if the
unassigned job can be reached via a path.
State can be used to recover the path via backtracking.
Args:
assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker has been matched
...
...
@@ -148,7 +134,6 @@ def _find_augmenting_path(assignment, adj_matrix):
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
matched to the job (column).
Returns:
A state dict, which represents the outcome of running an augmenting
path search on the graph given the assignment.
...
...
@@ -235,14 +220,12 @@ def _find_augmenting_path(assignment, adj_matrix):
def
_improve_assignment
(
assignment
,
state
):
"""Improves an assignment by backtracking the augmented path using state.
Args:
assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker has been matched
to the job. This may be a partial assignment.
state: A dict, which represents the outcome of running an augmenting path
search on the graph given the assignment.
Returns:
A new assignment matrix of the same shape and type as assignment, where the
assignment has been updated using the augmented path found.
...
...
@@ -317,7 +300,6 @@ def _improve_assignment(assignment, state):
def
_maximum_bipartite_matching
(
adj_matrix
,
assignment
=
None
):
"""Performs maximum bipartite matching using augmented paths.
Args:
adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
element of the inner matrix represents whether the worker (row) can be
...
...
@@ -326,7 +308,6 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
where each element of the inner matrix represents whether the worker has
been matched to the job. This may be a partial assignment. If specified,
this assignment will be used to seed the iterative algorithm.
Returns:
A state dict representing the final augmenting path state search, and
a maximum bipartite matching assignment tensor. Note that the state outcome
...
...
@@ -357,11 +338,9 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
def
_compute_cover
(
state
,
assignment
):
"""Computes a cover for the bipartite graph.
We compute a cover using the construction provided at
https://en.wikipedia.org/wiki/K%C5%91nig%27s_theorem_(graph_theory)#Proof
which uses the outcome from the alternating path search.
Args:
state: A state dict, which represents the outcome of running an augmenting
path search on the graph given the assignment.
...
...
@@ -369,7 +348,6 @@ def _compute_cover(state, assignment):
where each element of the inner matrix represents whether the worker has
been matched to the job. This may be a partial assignment. If specified,
this assignment will be used to seed the iterative algorithm.
Returns:
A tuple of (workers_cover, jobs_cover) corresponding to row and column
covers for the bipartite graph. workers_cover is a boolean tensor of shape
...
...
@@ -390,16 +368,13 @@ def _compute_cover(state, assignment):
def
_update_weights_using_cover
(
workers_cover
,
jobs_cover
,
weights
):
"""Updates weights for hungarian matching using a cover.
We first find the minimum uncovered weight. Then, we subtract this from all
the uncovered weights, and add it to all the doubly covered weights.
Args:
workers_cover: A boolean tensor of shape [batch_size, num_elems, 1].
jobs_cover: A boolean tensor of shape [batch_size, 1, num_elems].
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be use for matching.
Returns:
A new weight matrix with elements adjusted by the cover.
"""
...
...
@@ -423,12 +398,10 @@ def _update_weights_using_cover(workers_cover, jobs_cover, weights):
def
assert_rank
(
tensor
,
expected_rank
,
name
=
None
):
"""Raises an exception if the tensor rank is not of the expected rank.
Args:
tensor: A tf.Tensor to check the rank of.
expected_rank: Python integer or list of integers, expected rank.
name: Optional name of the tensor for the error message.
Raises:
ValueError: If the expected shape doesn't match the actual shape.
"""
...
...
@@ -449,11 +422,9 @@ def assert_rank(tensor, expected_rank, name=None):
def
hungarian_matching
(
weights
):
"""Computes the minimum linear sum assignment using the Hungarian algorithm.
Args:
weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
inner matrix represents weights to be use for matching.
Returns:
A bool [batch_size, num_elems, num_elems] tensor, where each element of the
inner matrix represents whether the worker has been matched to the job.
...
...
@@ -485,5 +456,4 @@ def hungarian_matching(weights):
_update_weights_and_match
,
(
workers_cover
,
jobs_cover
,
weights
,
assignment
),
back_prop
=
False
)
return
weights
,
assignment
return
weights
,
assignment
\ No newline at end of file
official/projects/detr/tasks/detection.py
View file @
94220a58
...
...
@@ -13,18 +13,24 @@
# limitations under the License.
"""DETR detection task definition."""
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Tuple
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.projects.detr.configs
import
detr
as
detr_cfg
from
official.projects.detr.dataloaders
import
coco
from
official.projects.detr.modeling
import
detr
from
official.projects.detr.ops
import
matchers
from
official.vision.evaluation
import
coco_evaluator
from
official.vision.ops
import
box_ops
from
official.vision.dataloaders
import
input_reader_factory
from
official.vision.dataloaders
import
tf_example_decoder
from
official.vision.dataloaders
import
tfds_factory
from
official.vision.dataloaders
import
tf_example_label_map_decoder
from
official.projects.detr.dataloaders
import
detr_input
@
task_factory
.
register_task_cls
(
detr_cfg
.
DetectionConfig
)
class
DectectionTask
(
base_task
.
Task
):
...
...
@@ -47,13 +53,62 @@ class DectectionTask(base_task.Task):
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
"""Loading pretrained checkpoint."""
ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
model
.
backbone
)
status
=
ckpt
.
read
(
self
.
_task_config
.
init_ckpt
)
status
.
expect_partial
().
assert_existing_objects_matched
()
def
build_inputs
(
self
,
params
,
input_context
=
None
):
if
not
self
.
_task_config
.
init_checkpoint
:
return
ckpt_dir_or_file
=
self
.
_task_config
.
init_checkpoint
# Restoring checkpoint.
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
self
.
_task_config
.
init_checkpoint_modules
==
'all'
:
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
assert_consumed
()
elif
self
.
_task_config
.
init_checkpoint_modules
==
'backbone'
:
ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
model
.
backbone
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
"""def build_inputs(self,
params: detr_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
return coco.COCODataLoader(params).load(input_context)"""
def
build_inputs
(
self
,
params
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
"""Build input dataset."""
return
coco
.
COCODataLoader
(
params
).
load
(
input_context
)
if
params
.
tfds_name
:
decoder
=
tfds_factory
.
get_detection_decoder
(
params
.
tfds_name
)
else
:
decoder_cfg
=
params
.
decoder
.
get
()
if
params
.
decoder
.
type
==
'simple_decoder'
:
decoder
=
tf_example_decoder
.
TfExampleDecoder
(
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
)
elif
params
.
decoder
.
type
==
'label_map_decoder'
:
decoder
=
tf_example_label_map_decoder
.
TfExampleDecoderLabelMap
(
label_map
=
decoder_cfg
.
label_map
,
regenerate_source_id
=
decoder_cfg
.
regenerate_source_id
)
else
:
raise
ValueError
(
'Unknown decoder type: {}!'
.
format
(
params
.
decoder
.
type
))
parser
=
detr_input
.
Parser
()
reader
=
input_reader_factory
.
input_reader_generator
(
params
,
dataset_fn
=
dataset_fn
.
pick_dataset_fn
(
params
.
file_type
),
decoder_fn
=
decoder
.
decode
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
))
dataset
=
reader
.
read
(
input_context
=
input_context
)
return
dataset
def
_compute_cost
(
self
,
cls_outputs
,
box_outputs
,
cls_targets
,
box_targets
):
# Approximate classification cost with 1 - prob[target class].
...
...
@@ -160,6 +215,7 @@ class DectectionTask(base_task.Task):
tf
.
reduce_sum
(
giou_loss
),
num_boxes_sum
)
aux_losses
=
tf
.
add_n
(
aux_losses
)
if
aux_losses
else
0.0
total_loss
=
cls_loss
+
box_loss
+
giou_loss
+
aux_losses
return
total_loss
,
cls_loss
,
box_loss
,
giou_loss
...
...
@@ -172,7 +228,7 @@ class DectectionTask(base_task.Task):
if
not
training
:
self
.
coco_metric
=
coco_evaluator
.
COCOEvaluator
(
annotation_file
=
''
,
annotation_file
=
self
.
_task_config
.
annotation_file
,
include_mask
=
False
,
need_rescale_bboxes
=
True
,
per_category_metrics
=
self
.
_task_config
.
per_category_metrics
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment