ModelZoo / ResNet50_tensorflow / Commits

Commit eedfa888 (Unverified)
Authored Jul 05, 2022 by Frederick Liu; committed by GitHub, Jul 05, 2022
Revert "DETR implementation update (#10689)" (#10691)
This reverts commit 5633969b.
parent 5633969b

Showing 7 changed files with 72 additions and 471 deletions (+72 −471):
official/projects/detr/configs/detr.py                     +12  -186
official/projects/detr/dataloaders/detr_input.py            +0  -186
official/projects/detr/experiments/detr_r50_300epochs.sh    +1    -1
official/projects/detr/experiments/detr_r50_500epochs.sh    +1    -1
official/projects/detr/modeling/detr.py                     +2    -4
official/projects/detr/ops/matchers.py                     +31    -1
official/projects/detr/tasks/detection.py                  +25   -92
official/projects/detr/configs/detr.py

@@ -15,63 +15,33 @@
 """DETR configurations."""

 import dataclasses
-import os
-from typing import List, Optional, Union

 from official.core import config_definitions as cfg
 from official.core import exp_factory
-from official.modeling import hyperparams
-from official.vision.configs import common
-from official.vision.configs import backbones
 from official.projects.detr import optimization
 from official.projects.detr.dataloaders import coco


-@dataclasses.dataclass
-class DataConfig(cfg.DataConfig):
-  """Input config for training."""
-  input_path: str = ''
-  global_batch_size: int = 0
-  is_training: bool = False
-  dtype: str = 'bfloat16'
-  decoder: common.DataDecoder = common.DataDecoder()
-  shuffle_buffer_size: int = 10000
-  file_type: str = 'tfrecord'
-
-
 @dataclasses.dataclass
-class Losses(hyperparams.Config):
-  class_offset: int = 0
+class DetectionConfig(cfg.TaskConfig):
+  """The translation task config."""
+  train_data: cfg.DataConfig = cfg.DataConfig()
+  validation_data: cfg.DataConfig = cfg.DataConfig()
   lambda_cls: float = 1.0
   lambda_box: float = 5.0
   lambda_giou: float = 2.0
-  background_cls_weight: float = 0.1
-  l2_weight_decay: float = 1e-4
-
-
-@dataclasses.dataclass
-class Detr(hyperparams.Config):
-  num_queries: int = 100
-  hidden_size: int = 256
-  num_classes: int = 91  # 0: background
+  init_ckpt: str = ''
+  num_classes: int = 81  # 0: background
+  background_cls_weight: float = 0.1
   num_encoder_layers: int = 6
   num_decoder_layers: int = 6
-  input_size: List[int] = dataclasses.field(default_factory=list)
-  backbone: backbones.Backbone = backbones.Backbone(
-      type='resnet', resnet=backbones.ResNet(model_id=50, bn_trainable=False))
-  norm_activation: common.NormActivation = common.NormActivation()
-
-
-@dataclasses.dataclass
-class DetrTask(cfg.TaskConfig):
-  model: Detr = Detr()
-  train_data: cfg.DataConfig = cfg.DataConfig()
-  validation_data: cfg.DataConfig = cfg.DataConfig()
-  losses: Losses = Losses()
-  init_checkpoint: Optional[str] = None
-  init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone
-  annotation_file: Optional[str] = None
+  # Make DETRConfig.
+  num_queries: int = 100
+  num_hidden: int = 256
   per_category_metrics: bool = False


 @exp_factory.register_config_factory('detr_coco')
 def detr_coco() -> cfg.ExperimentConfig:
   """Config to get results that matches the paper."""

@@ -82,14 +52,7 @@ def detr_coco() -> cfg.ExperimentConfig:
   train_steps = 500 * num_steps_per_epoch  # 500 epochs
   decay_at = train_steps - 100 * num_steps_per_epoch  # 400 epochs
   config = cfg.ExperimentConfig(
-      task=DetrTask(
-          init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
-          init_checkpoint_modules='backbone',
-          model=Detr(num_classes=81,
-                     input_size=[1333, 1333, 3],
-                     norm_activation=common.NormActivation()),
-          losses=Losses(),
+      task=DetectionConfig(
           train_data=coco.COCODataConfig(
               tfds_name='coco/2017',
               tfds_split='train',

@@ -138,140 +101,3 @@ def detr_coco() -> cfg.ExperimentConfig:
           'task.train_data.is_training != None',
       ])
   return config
-
-
-COCO_INPUT_PATH_BASE = ''
-COCO_TRAIN_EXAMPLES = 118287
-COCO_VAL_EXAMPLES = 5000
-
-
-@exp_factory.register_config_factory('detr_coco_tfrecord')
-def detr_coco() -> cfg.ExperimentConfig:
-  """Config to get results that matches the paper."""
-  train_batch_size = 64
-  eval_batch_size = 64
-  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
-  train_steps = 300 * steps_per_epoch  # 300 epochs
-  decay_at = train_steps - 100 * steps_per_epoch  # 200 epochs
-  config = cfg.ExperimentConfig(
-      task=DetrTask(
-          init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
-          init_checkpoint_modules='backbone',
-          annotation_file=os.path.join(COCO_INPUT_PATH_BASE,
-                                       'instances_val2017.json'),
-          model=Detr(input_size=[1333, 1333, 3],
-                     norm_activation=common.NormActivation()),
-          losses=Losses(),
-          train_data=DataConfig(
-              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'train*'),
-              is_training=True,
-              global_batch_size=train_batch_size,
-              shuffle_buffer_size=1000,
-          ),
-          validation_data=DataConfig(
-              input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
-              is_training=False,
-              global_batch_size=eval_batch_size,
-              drop_remainder=False,
-          )),
-      trainer=cfg.TrainerConfig(
-          train_steps=train_steps,
-          validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
-          steps_per_loop=steps_per_epoch,
-          summary_interval=steps_per_epoch,
-          checkpoint_interval=steps_per_epoch,
-          validation_interval=5 * steps_per_epoch,
-          max_to_keep=1,
-          best_checkpoint_export_subdir='best_ckpt',
-          best_checkpoint_eval_metric='AP',
-          optimizer_config=optimization.OptimizationConfig({
-              'optimizer': {
-                  'type': 'detr_adamw',
-                  'detr_adamw': {
-                      'weight_decay_rate': 1e-4,
-                      'global_clipnorm': 0.1,
-                      # Avoid AdamW legacy behavior.
-                      'gradient_clip_norm': 0.0
-                  }
-              },
-              'learning_rate': {
-                  'type': 'stepwise',
-                  'stepwise': {
-                      'boundaries': [decay_at],
-                      'values': [0.0001, 1.0e-05]
-                  }
-              },
-          })),
-      restrictions=[
-          'task.train_data.is_training != None',
-      ])
-  return config
-
-
-@exp_factory.register_config_factory('detr_coco_tfds')
-def detr_coco() -> cfg.ExperimentConfig:
-  """Config to get results that matches the paper."""
-  train_batch_size = 64
-  eval_batch_size = 64
-  steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
-  train_steps = 300 * steps_per_epoch  # 300 epochs
-  decay_at = train_steps - 100 * steps_per_epoch  # 200 epochs
-  config = cfg.ExperimentConfig(
-      task=DetrTask(
-          init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',
-          init_checkpoint_modules='backbone',
-          model=Detr(num_classes=81,
-                     input_size=[1333, 1333, 3],
-                     norm_activation=common.NormActivation()),
-          losses=Losses(class_offset=1),
-          train_data=DataConfig(
-              tfds_name='coco/2017',
-              tfds_split='train',
-              is_training=True,
-              global_batch_size=train_batch_size,
-              shuffle_buffer_size=1000,
-          ),
-          validation_data=DataConfig(
-              tfds_name='coco/2017',
-              tfds_split='validation',
-              is_training=False,
-              global_batch_size=eval_batch_size,
-              drop_remainder=False
-          )),
-      trainer=cfg.TrainerConfig(
-          train_steps=train_steps,
-          validation_steps=COCO_VAL_EXAMPLES // eval_batch_size,
-          steps_per_loop=steps_per_epoch,
-          summary_interval=steps_per_epoch,
-          checkpoint_interval=steps_per_epoch,
-          validation_interval=5 * steps_per_epoch,
-          max_to_keep=1,
-          best_checkpoint_export_subdir='best_ckpt',
-          best_checkpoint_eval_metric='AP',
-          optimizer_config=optimization.OptimizationConfig({
-              'optimizer': {
-                  'type': 'detr_adamw',
-                  'detr_adamw': {
-                      'weight_decay_rate': 1e-4,
-                      'global_clipnorm': 0.1,
-                      # Avoid AdamW legacy behavior.
-                      'gradient_clip_norm': 0.0
-                  }
-              },
-              'learning_rate': {
-                  'type': 'stepwise',
-                  'stepwise': {
-                      'boundaries': [decay_at],
-                      'values': [0.0001, 1.0e-05]
-                  }
-              },
-          })),
-      restrictions=[
-          'task.train_data.is_training != None',
-      ])
-  return config
\ No newline at end of file
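For reference, a minimal sketch (not part of this commit) of how a registered experiment name is consumed; after this revert only 'detr_coco' remains registered, while 'detr_coco_tfrecord' and 'detr_coco_tfds' disappear with the deleted code. The override values below are illustrative.

# Importing the configs module runs the register_config_factory decorators.
from official.core import exp_factory
from official.projects.detr.configs import detr  # noqa: F401

config = exp_factory.get_exp_config('detr_coco')
# Point the restored init_ckpt field at the pretrained backbone (illustrative value).
config.override({'task': {
    'init_ckpt': 'gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'}})
config.validate()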
official/projects/detr/dataloaders/detr_input.py
deleted (100644 → 0); the entire file below is removed:

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""COCO data loader for DETR."""
from typing import Optional, Tuple

import tensorflow as tf

from official.vision.dataloaders import parser
from official.vision.dataloaders import utils
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.core import input_reader

RESIZE_SCALES = (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)


class Parser(parser.Parser):
  """Parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               class_offset: int = 0,
               output_size: Tuple[int, int] = (1333, 1333),
               max_num_boxes: int = 100,
               resize_scales: Tuple[int, ...] = RESIZE_SCALES,
               aug_rand_hflip=True):
    self._class_offset = class_offset
    self._output_size = output_size
    self._max_num_boxes = max_num_boxes
    self._resize_scales = resize_scales
    self._aug_rand_hflip = aug_rand_hflip

  def _parse_train_data(self, data):
    """Parses data for training and evaluation."""
    classes = data['groundtruth_classes'] + self._class_offset
    boxes = data['groundtruth_boxes']
    is_crowd = data['groundtruth_is_crowd']

    # Gets original image.
    image = data['image']

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image)
    image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

    do_crop = tf.greater(tf.random.uniform([]), 0.5)
    if do_crop:
      # Rescale
      boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
      index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
      scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
      short_side = scales[0]
      image, image_info = preprocess_ops.resize_image(image, short_side)
      boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                   image_info[1, :],
                                                   image_info[3, :])
      boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

      # Do croping
      shape = tf.cast(image_info[1], dtype=tf.int32)
      h = tf.random.uniform([], 384, tf.math.minimum(shape[0], 600),
                            dtype=tf.int32)
      w = tf.random.uniform([], 384, tf.math.minimum(shape[1], 600),
                            dtype=tf.int32)
      i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
      j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
      image = tf.image.crop_to_bounding_box(image, i, j, h, w)
      boxes = tf.clip_by_value(
          (boxes[..., :] * tf.cast(
              tf.stack([shape[0], shape[1], shape[0], shape[1]]),
              dtype=tf.float32) -
           tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
          tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)

    scales = tf.constant(self._resize_scales, dtype=tf.float32)
    index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
    scales = tf.gather(scales, index, axis=0)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(image, short_side,
                                                    max(self._output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
                                         self._output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(classes,
                                                     self._max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(boxes,
                                                     self._max_num_boxes)
    }
    return image, labels

  def _parse_eval_data(self, data):
    """Parses data for training and evaluation."""
    groundtruths = {}
    classes = data['groundtruth_classes']
    boxes = data['groundtruth_boxes']
    is_crowd = data['groundtruth_is_crowd']

    # Gets original image and its size.
    image = data['image']

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image)
    scales = tf.constant([self._resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(image, short_side,
                                                    max(self._output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes, image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(image, 0, 0, self._output_size[0],
                                         self._output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(classes,
                                                     self._max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(boxes,
                                                     self._max_num_boxes)
    }
    labels.update({
        'id': int(data['source_id']),
        'image_info': image_info,
        'is_crowd':
            preprocess_ops.clip_or_pad_to_fixed_size(is_crowd,
                                                     self._max_num_boxes),
        'gt_boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(gt_boxes,
                                                     self._max_num_boxes),
    })
    return image, labels
\ No newline at end of file
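A toy check (illustrative, made-up numbers) of the crop transform in _parse_train_data above: normalized boxes are scaled to absolute pixels, shifted by the crop origin (i, j), renormalized by the crop size (h, w), and clipped to [0, 1].

import numpy as np

H, W = 800, 800                  # image shape before cropping (illustrative)
i, j, h, w = 100, 50, 400, 400   # crop offset and size (illustrative)

box = np.array([0.25, 0.125, 0.75, 0.625])   # ymin, xmin, ymax, xmax in [0, 1]
absolute = box * np.array([H, W, H, W])      # -> [200, 100, 600, 500] pixels
shifted = absolute - np.array([i, j, i, j])  # -> [100,  50, 500, 450]
renormed = np.clip(shifted / np.array([h, w, h, w]), 0.0, 1.0)
print(renormed)  # [0.25  0.125 1.    1.   ] -- box extends past the crop, so it is clipped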
official/projects/detr/experiments/detr_r50_300epochs.sh

@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
   --experiment=detr_coco \
   --mode=train_and_eval \
   --model_dir=/tmp/logging_dir/ \
-  --params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
+  --params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
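As a sanity check (illustrative, not part of the commit), the trainer.train_steps=554400 override matches 300 epochs of COCO at global batch size 64, using the same integer division as the deleted config code above.

COCO_TRAIN_EXAMPLES = 118287
train_batch_size = 64
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size  # 1848
assert 300 * steps_per_epoch == 554400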
official/projects/detr/experiments/detr_r50_500epochs.sh

@@ -3,4 +3,4 @@ python3 official/projects/detr/train.py \
   --experiment=detr_coco \
   --mode=train_and_eval \
   --model_dir=/tmp/logging_dir/ \
-  --params_override=task.init_checkpoint='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
+  --params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
official/projects/detr/modeling/detr.py

@@ -100,7 +100,7 @@ class DETR(tf.keras.Model):
       class and box heads.
   """

-  def __init__(self, backbone, num_queries, hidden_size, num_classes,
+  def __init__(self, num_queries, hidden_size, num_classes,
                num_encoder_layers=6,
                num_decoder_layers=6,
                dropout_rate=0.1,

@@ -116,9 +116,7 @@ class DETR(tf.keras.Model):
       raise ValueError("hidden_size must be a multiple of 2.")
     # TODO(frederickliu): Consider using the backbone factory.
     # TODO(frederickliu): Add to factory once we get skeleton code in.
-    # (gunho) use backbone factory
-    self._backbone = backbone
+    #self._backbone = resnet.ResNet(101, bn_trainable=False)
+    self._backbone = resnet.ResNet(50, bn_trainable=False)

   def build(self, input_shape=None):
     self._input_proj = tf.keras.layers.Conv2D(
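A brief sketch of the constructor change (the argument values below are illustrative, taken from the config defaults). Pre-revert, callers passed an externally built backbone; post-revert the model builds its own hard-coded ResNet-50, so the backbone argument disappears.

from official.projects.detr.modeling import detr

# Post-revert call, matching build_model() in tasks/detection.py below:
model = detr.DETR(num_queries=100, hidden_size=256, num_classes=81,
                  num_encoder_layers=6, num_decoder_layers=6)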
official/projects/detr/ops/matchers.py

The +31 −1 in this file restores blank lines inside docstrings; the hunk line counts imply the blank-line placements marked with + below.

@@ -13,14 +13,17 @@
 # limitations under the License.
 """Tensorflow implementation to solve the Linear Sum Assignment problem.
+
 The Linear Sum Assignment problem involves determining the minimum weight
 matching for bipartite graphs. For example, this problem can be defined by
 a 2D matrix C, where each element i,j determines the cost of matching worker i
 with job j. The solution to the problem is a complete assignment of jobs to
 workers, such that no job is assigned to more than one work and no worker is
 assigned more than one job, with minimum cost.
+
 This implementation builds off of the Hungarian
 Matching Algorithm (https://www.cse.ust.hk/~golin/COMP572/Notes/Matching.pdf).
+
 Based on the original implementation by Jiquan Ngiam <jngiam@google.com>.
 """

 import tensorflow as tf

@@ -29,14 +32,17 @@ from official.modeling import tf_utils
 def _prepare(weights):
   """Prepare the cost matrix.
+
   To speed up computational efficiency of the algorithm, all weights are shifted
   to be non-negative. Each element is reduced by the row / column minimum. Note
   that neither operation will effect the resulting solution but will provide
   a better starting point for the greedy assignment. Note this corresponds to
   the pre-processing and step 1 of the Hungarian algorithm from Wikipedia.
+
   Args:
     weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
       inner matrix represents weights to be use for matching.
+
   Returns:
     A prepared weights tensor of the same shape and dtype.
   """

@@ -49,15 +55,18 @@ def _prepare(weights):
 def _greedy_assignment(adj_matrix):
   """Greedily assigns workers to jobs based on an adjaceny matrix.
+
   Starting with an adjacency matrix representing the available connections
   in the bi-partite graph, this function greedily chooses elements such
   that each worker is matched to at most one job (or each job is assigned to
   at most one worker). Note, if the adjacency matrix has no available values
   for a particular row/column, the corresponding job/worker may go unassigned.
+
   Args:
     adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
       element of the inner matrix represents whether the worker (row) can be
       matched to the job (column).
+
   Returns:
     A bool [batch_size, num_elems, num_elems] tensor, where each element of the
     inner matrix represents whether the worker has been matched to the job.

@@ -110,12 +119,15 @@ def _greedy_assignment(adj_matrix):
 def _find_augmenting_path(assignment, adj_matrix):
   """Finds an augmenting path given an assignment and an adjacency matrix.
+
   The augmenting path search starts from the unassigned workers, then goes on
   to find jobs (via an unassigned pairing), then back again to workers (via an
   existing pairing), and so on. The path alternates between unassigned and
   existing pairings. Returns the state after the search.
+
   Note: In the state the worker and job, indices are 1-indexed so that we can
   use 0 to represent unreachable nodes. State contains the following keys:
+
   - jobs: A [batch_size, 1, num_elems] tensor containing the highest index
     unassigned worker that can reach this job through a path.
   - jobs_from_worker: A [batch_size, num_elems] tensor containing the worker

@@ -126,7 +138,9 @@ def _find_augmenting_path(assignment, adj_matrix):
     reached immediately before this worker.
   - new_jobs: A bool [batch_size, num_elems] tensor containing True if the
     unassigned job can be reached via a path.
+
   State can be used to recover the path via backtracking.
+
   Args:
     assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
       element of the inner matrix represents whether the worker has been matched

@@ -134,6 +148,7 @@ def _find_augmenting_path(assignment, adj_matrix):
     adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
       element of the inner matrix represents whether the worker (row) can be
       matched to the job (column).
+
   Returns:
     A state dict, which represents the outcome of running an augmenting
       path search on the graph given the assignment.

@@ -220,12 +235,14 @@ def _find_augmenting_path(assignment, adj_matrix):
 def _improve_assignment(assignment, state):
   """Improves an assignment by backtracking the augmented path using state.
+
   Args:
     assignment: A bool [batch_size, num_elems, num_elems] tensor, where each
       element of the inner matrix represents whether the worker has been matched
       to the job. This may be a partial assignment.
     state: A dict, which represents the outcome of running an augmenting path
       search on the graph given the assignment.
+
   Returns:
     A new assignment matrix of the same shape and type as assignment, where the
     assignment has been updated using the augmented path found.

@@ -300,6 +317,7 @@ def _improve_assignment(assignment, state):
 def _maximum_bipartite_matching(adj_matrix, assignment=None):
   """Performs maximum bipartite matching using augmented paths.
+
   Args:
     adj_matrix: A bool [batch_size, num_elems, num_elems] tensor, where each
       element of the inner matrix represents whether the worker (row) can be

@@ -308,6 +326,7 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
       where each element of the inner matrix represents whether the worker has
       been matched to the job. This may be a partial assignment. If specified,
       this assignment will be used to seed the iterative algorithm.
+
   Returns:
     A state dict representing the final augmenting path state search, and
     a maximum bipartite matching assignment tensor. Note that the state outcome

@@ -338,9 +357,11 @@ def _maximum_bipartite_matching(adj_matrix, assignment=None):
 def _compute_cover(state, assignment):
   """Computes a cover for the bipartite graph.
+
   We compute a cover using the construction provided at
   https://en.wikipedia.org/wiki/K%C5%91nig%27s_theorem_(graph_theory)#Proof
   which uses the outcome from the alternating path search.
+
   Args:
     state: A state dict, which represents the outcome of running an augmenting
       path search on the graph given the assignment.

@@ -348,6 +369,7 @@ def _compute_cover(state, assignment):
       where each element of the inner matrix represents whether the worker has
       been matched to the job. This may be a partial assignment. If specified,
       this assignment will be used to seed the iterative algorithm.
+
   Returns:
     A tuple of (workers_cover, jobs_cover) corresponding to row and column
     covers for the bipartite graph. workers_cover is a boolean tensor of shape

@@ -368,13 +390,16 @@ def _compute_cover(state, assignment):
 def _update_weights_using_cover(workers_cover, jobs_cover, weights):
   """Updates weights for hungarian matching using a cover.
+
   We first find the minimum uncovered weight. Then, we subtract this from all
   the uncovered weights, and add it to all the doubly covered weights.
+
   Args:
     workers_cover: A boolean tensor of shape [batch_size, num_elems, 1].
     jobs_cover: A boolean tensor of shape [batch_size, 1, num_elems].
     weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
       inner matrix represents weights to be use for matching.
+
   Returns:
     A new weight matrix with elements adjusted by the cover.
   """

@@ -398,10 +423,12 @@ def _update_weights_using_cover(workers_cover, jobs_cover, weights):
 def assert_rank(tensor, expected_rank, name=None):
   """Raises an exception if the tensor rank is not of the expected rank.
+
   Args:
     tensor: A tf.Tensor to check the rank of.
     expected_rank: Python integer or list of integers, expected rank.
     name: Optional name of the tensor for the error message.
+
   Raises:
     ValueError: If the expected shape doesn't match the actual shape.
   """

@@ -422,9 +449,11 @@ def assert_rank(tensor, expected_rank, name=None):
 def hungarian_matching(weights):
   """Computes the minimum linear sum assignment using the Hungarian algorithm.
+
   Args:
     weights: A float32 [batch_size, num_elems, num_elems] tensor, where each
       inner matrix represents weights to be use for matching.
+
   Returns:
     A bool [batch_size, num_elems, num_elems] tensor, where each element of the
     inner matrix represents whether the worker has been matched to the job.

@@ -456,4 +485,5 @@ def hungarian_matching(weights):
       _update_weights_and_match,
       (workers_cover, jobs_cover, weights, assignment),
       back_prop=False)
+
   return weights, assignment
\ No newline at end of file
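A minimal usage sketch (illustrative, not part of the commit): solve a small batched assignment problem and cross-check the total cost against scipy's reference solver. Optimal assignments need not be unique, so the check compares costs rather than the matrices themselves.

import numpy as np
import tensorflow as tf
from scipy.optimize import linear_sum_assignment

from official.projects.detr.ops import matchers

cost = tf.constant([[[4.0, 1.0, 3.0],
                     [2.0, 0.0, 5.0],
                     [3.0, 2.0, 2.0]]])  # [batch=1, workers=3, jobs=3]

# Per the final hunk above, hungarian_matching returns the adjusted weight
# matrix and a boolean [batch, num_elems, num_elems] assignment.
_, assignment = matchers.hungarian_matching(cost)
tf_cost = np.sum(cost[0].numpy() * assignment[0].numpy())

# scipy's reference LAP solver on the same (unbatched) matrix.
rows, cols = linear_sum_assignment(cost[0].numpy())
assert np.isclose(tf_cost, cost[0].numpy()[rows, cols].sum())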
official/projects/detr/tasks/detection.py

@@ -13,28 +13,20 @@
 # limitations under the License.
 """DETR detection task definition."""
-from typing import Any, List, Mapping, Optional, Tuple
-
-from absl import logging
 import tensorflow as tf
-
-from official.common import dataset_fn
 from official.core import base_task
 from official.core import task_factory
 from official.projects.detr.configs import detr as detr_cfg
-from official.projects.detr.dataloaders import coco
 from official.projects.detr.modeling import detr
 from official.projects.detr.ops import matchers
 from official.vision.evaluation import coco_evaluator
 from official.vision.ops import box_ops
-from official.vision.dataloaders import input_reader_factory
-from official.vision.dataloaders import tf_example_decoder
-from official.vision.dataloaders import tfds_factory
-from official.vision.dataloaders import tf_example_label_map_decoder
-from official.projects.detr.dataloaders import detr_input
-from official.vision.modeling import backbones
+from official.projects.detr.dataloaders import coco


-@task_factory.register_task_cls(detr_cfg.DetrTask)
+@task_factory.register_task_cls(detr_cfg.DetectionConfig)
 class DectectionTask(base_task.Task):
   """A single-replica view of training procedure.

@@ -45,104 +37,46 @@ class DectectionTask(base_task.Task):
   def build_model(self):
     """Build DETR model."""
-    input_specs = tf.keras.layers.InputSpec(shape=[None] +
-                                            self._task_config.model.input_size)
-
-    backbone = backbones.factory.build_backbone(
-        input_specs=input_specs,
-        backbone_config=self._task_config.model.backbone,
-        norm_activation_config=self._task_config.model.norm_activation)
-
-    model = detr.DETR(backbone,
-                      self._task_config.model.num_queries,
-                      self._task_config.model.hidden_size,
-                      self._task_config.model.num_classes,
-                      self._task_config.model.num_encoder_layers,
-                      self._task_config.model.num_decoder_layers)
+    model = detr.DETR(self._task_config.num_queries,
+                      self._task_config.num_hidden,
+                      self._task_config.num_classes,
+                      self._task_config.num_encoder_layers,
+                      self._task_config.num_decoder_layers)
     return model

   def initialize(self, model: tf.keras.Model):
     """Loading pretrained checkpoint."""
-    if not self._task_config.init_checkpoint:
+    if not self._task_config.init_ckpt:
       return
-
-    ckpt_dir_or_file = self._task_config.init_checkpoint
-
-    # Restoring checkpoint.
-    if tf.io.gfile.isdir(ckpt_dir_or_file):
-      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
-
-    if self._task_config.init_checkpoint_modules == 'all':
-      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.assert_consumed()
-    elif self._task_config.init_checkpoint_modules == 'backbone':
-      ckpt = tf.train.Checkpoint(backbone=model.backbone)
-      status = ckpt.restore(ckpt_dir_or_file)
-      status.expect_partial().assert_existing_objects_matched()
-
-    logging.info('Finished loading pretrained checkpoint from %s',
-                 ckpt_dir_or_file)
+    ckpt = tf.train.Checkpoint(backbone=model.backbone)
+    status = ckpt.read(self._task_config.init_ckpt)
+    status.expect_partial().assert_existing_objects_matched()

-  def build_inputs(self, params, input_context: Optional[
-      tf.distribute.InputContext] = None):
+  def build_inputs(self, params, input_context=None):
     """Build input dataset."""
-    if type(params) is coco.COCODataConfig:
-      return coco.COCODataLoader(params).load(input_context)
-
-    if params.tfds_name:
-      decoder = tfds_factory.get_detection_decoder(params.tfds_name)
-    else:
-      decoder_cfg = params.decoder.get()
-      if params.decoder.type == 'simple_decoder':
-        decoder = tf_example_decoder.TfExampleDecoder(
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
-      elif params.decoder.type == 'label_map_decoder':
-        decoder = tf_example_label_map_decoder.TfExampleDecoderLabelMap(
-            label_map=decoder_cfg.label_map,
-            regenerate_source_id=decoder_cfg.regenerate_source_id)
-      else:
-        raise ValueError('Unknown decoder type: {}!'.format(
-            params.decoder.type))
-
-    parser = detr_input.Parser(
-        class_offset=self._task_config.losses.class_offset,
-        output_size=self._task_config.model.input_size[:2],
-    )
-
-    reader = input_reader_factory.input_reader_generator(
-        params,
-        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
-        decoder_fn=decoder.decode,
-        parser_fn=parser.parse_fn(params.is_training))
-    dataset = reader.read(input_context=input_context)
-
+    dataset = coco.COCODataLoader(params).load(input_context)
     return dataset

   def _compute_cost(self, cls_outputs, box_outputs, cls_targets, box_targets):
     # Approximate classification cost with 1 - prob[target class].
     # The 1 is a constant that doesn't change the matching, it can be ommitted.
     # background: 0
-    cls_cost = self._task_config.losses.lambda_cls * tf.gather(
+    cls_cost = self._task_config.lambda_cls * tf.gather(
         -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)

     # Compute the L1 cost between boxes,
-    paired_differences = self._task_config.losses.lambda_box * tf.abs(
+    paired_differences = self._task_config.lambda_box * tf.abs(
         tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1))
     box_cost = tf.reduce_sum(paired_differences, axis=-1)

     # Compute the giou cost betwen boxes
-    giou_cost = self._task_config.losses.lambda_giou * -box_ops.bbox_generalized_overlap(
+    giou_cost = self._task_config.lambda_giou * -box_ops.bbox_generalized_overlap(
         box_ops.cycxhw_to_yxyx(box_outputs),
         box_ops.cycxhw_to_yxyx(box_targets))

     total_cost = cls_cost + box_cost + giou_cost

     max_cost = (
-        self._task_config.losses.lambda_cls * 0.0 +
-        self._task_config.losses.lambda_box * 4. +
-        self._task_config.losses.lambda_giou * 0.0)
+        self._task_config.lambda_cls * 0.0 +
+        self._task_config.lambda_box * 4. +
+        self._task_config.lambda_giou * 0.0)

     # Set pads to large constant
     valid = tf.expand_dims(

@@ -181,20 +115,20 @@ class DectectionTask(base_task.Task):
     # Down-weight background to account for class imbalance.
     xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
         labels=cls_targets, logits=cls_assigned)
-    cls_loss = self._task_config.losses.lambda_cls * tf.where(
-        background, self._task_config.losses.background_cls_weight * xentropy,
+    cls_loss = self._task_config.lambda_cls * tf.where(
+        background, self._task_config.background_cls_weight * xentropy,
         xentropy)
     cls_weights = tf.where(
         background,
-        self._task_config.losses.background_cls_weight * tf.ones_like(cls_loss),
+        self._task_config.background_cls_weight * tf.ones_like(cls_loss),
         tf.ones_like(cls_loss))

     # Box loss is only calculated on non-background class.
     l_1 = tf.reduce_sum(tf.abs(box_assigned - box_targets), axis=-1)
-    box_loss = self._task_config.losses.lambda_box * tf.where(
+    box_loss = self._task_config.lambda_box * tf.where(
         background, tf.zeros_like(l_1), l_1)

@@ -205,7 +139,7 @@ class DectectionTask(base_task.Task):
         box_ops.cycxhw_to_yxyx(box_assigned),
         box_ops.cycxhw_to_yxyx(box_targets)))
-    giou_loss = self._task_config.losses.lambda_giou * tf.where(
+    giou_loss = self._task_config.lambda_giou * tf.where(
         background, tf.zeros_like(giou), giou

@@ -226,7 +160,6 @@ class DectectionTask(base_task.Task):
         tf.reduce_sum(giou_loss), num_boxes_sum)

     aux_losses = tf.add_n(aux_losses) if aux_losses else 0.0
-
     total_loss = cls_loss + box_loss + giou_loss + aux_losses
     return total_loss, cls_loss, box_loss, giou_loss

@@ -239,7 +172,7 @@ class DectectionTask(base_task.Task):
     if not training:
       self.coco_metric = coco_evaluator.COCOEvaluator(
-          annotation_file=self._task_config.annotation_file,
+          annotation_file='',
           include_mask=False,
           need_rescale_bboxes=True,
          per_category_metrics=self._task_config.per_category_metrics)
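A schematic sketch (shapes and toy values are illustrative, not from the diff) of the pairwise matching cost assembled in _compute_cost above, which is then fed to matchers.hungarian_matching: cost = lambda_cls * (-softmax(cls_outputs)[target]) + lambda_box * sum(|box_pred - box_target|) + lambda_giou * (-GIoU(box_pred, box_target)).

import tensorflow as tf

lambda_cls, lambda_box = 1.0, 5.0  # defaults from configs/detr.py above

cls_outputs = tf.random.normal([2, 100, 81])  # [batch, num_queries, classes]
cls_targets = tf.zeros([2, 100], tf.int32)    # padded targets, background = 0
# Pairwise classification cost: entry [b, q, t] is -prob of target t's class
# under prediction q; shape [2, 100, 100].
cls_cost = lambda_cls * tf.gather(
    -tf.nn.softmax(cls_outputs), cls_targets, batch_dims=1, axis=-1)

box_outputs = tf.random.uniform([2, 100, 4])  # cycxhw, normalized
box_targets = tf.random.uniform([2, 100, 4])
# Pairwise L1 box cost, also [2, 100, 100].
box_cost = lambda_box * tf.reduce_sum(
    tf.abs(tf.expand_dims(box_outputs, 2) - tf.expand_dims(box_targets, 1)),
    axis=-1)

print(cls_cost.shape, box_cost.shape)  # (2, 100, 100) (2, 100, 100)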