Commit ca8c44d5, authored Jan 26, 2022 by Frederick Liu, committed by A. Unique TensorFlower on Jan 26, 2022.

Internal change

PiperOrigin-RevId: 424422082

parent a7894f9e
Showing 17 changed files with 3174 additions and 0 deletions (+3174, -0).
- official/projects/detr/README.md (+46, -0)
- official/projects/detr/configs/detr.py (+103, -0)
- official/projects/detr/configs/detr_test.py (+41, -0)
- official/projects/detr/dataloaders/coco.py (+157, -0)
- official/projects/detr/dataloaders/coco_test.py (+111, -0)
- official/projects/detr/experiments/detr_r50_300epochs.sh (+6, -0)
- official/projects/detr/experiments/detr_r50_500epochs.sh (+6, -0)
- official/projects/detr/modeling/detr.py (+273, -0)
- official/projects/detr/modeling/detr_test.py (+64, -0)
- official/projects/detr/modeling/transformer.py (+846, -0)
- official/projects/detr/modeling/transformer_test.py (+263, -0)
- official/projects/detr/ops/matchers.py (+489, -0)
- official/projects/detr/ops/matchers_test.py (+95, -0)
- official/projects/detr/optimization.py (+147, -0)
- official/projects/detr/tasks/detection.py (+341, -0)
- official/projects/detr/tasks/detection_test.py (+116, -0)
- official/projects/detr/train.py (+70, -0)
**official/projects/detr/README.md** (new file, mode 100644)

# End-to-End Object Detection with Transformers (DETR)

TensorFlow 2 implementation of [End-to-End Object Detection with Transformers](https://arxiv.org/abs/2005.12872).
⚠️ Disclaimer: The datasets hyperlinked from this page are neither owned nor
distributed by Google. They are made available by third parties. Please review
the terms and conditions made available by the third parties before using the
data.
## Scripts:
You can find the scripts to reproduce the following experiments in
detr/experiments.
## DETR [COCO](https://cocodataset.org) ([ImageNet](https://www.image-net.org) pretrained)

| Model | Resolution | Batch size | Epochs | Decay@ | Params (M) | Box AP | Dashboard | Checkpoint | Experiment |
| --------- | :--------: | ----------:| ------:| -----: | ---------: | -----: | --------: | ---------: | ---------: |
| DETR-ResNet-50 | 1333x1333 | 64 | 300 | 200 | 41 | 40.6 | [tensorboard](https://tensorboard.dev/experiment/o2IEZnniRYu6pqViBeopIg/#scalars) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_300.tar.gz) | detr_r50_300epochs.sh |
| DETR-ResNet-50 | 1333x1333 | 64 | 500 | 400 | 41 | 42.0 | [tensorboard](https://tensorboard.dev/experiment/YFMDKpESR4yjocPh5HgfRw/) | [ckpt](https://storage.googleapis.com/tf_model_garden/vision/detr/detr_resnet_50_500.tar.gz) | detr_r50_500epochs.sh |
| DETR-ResNet-50 | 1333x1333 | 64 | 300 | 200 | 41 | 40.6 | paper | NA | NA |
| DETR-ResNet-50 | 1333x1333 | 64 | 500 | 400 | 41 | 42.0 | paper | NA | NA |
| DETR-DC5-ResNet-50 | 1333x1333 | 64 | 500 | 400 | 41 | 43.3 | paper | NA | NA |
## Need contribution:

* Add DC5 support and update experiment table.
## Citing TensorFlow Model Garden
If you find this codebase helpful in your research, please cite this repository.
```
@misc{tensorflowmodelgarden2020,
author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
Frederick Liu and Jaeyoun Kim and Jing Li},
title = {{TensorFlow Model Garden}},
howpublished = {\url{https://github.com/tensorflow/models}},
year = {2020}
}
```
**official/projects/detr/configs/detr.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DETR configurations."""
import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr import optimization
from official.projects.detr.dataloaders import coco


@dataclasses.dataclass
class DetectionConfig(cfg.TaskConfig):
  """The detection task config."""
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
  lambda_cls: float = 1.0
  lambda_box: float = 5.0
  lambda_giou: float = 2.0
  init_ckpt: str = ''
  num_classes: int = 81  # 0: background
  background_cls_weight: float = 0.1
  num_encoder_layers: int = 6
  num_decoder_layers: int = 6
  # Make DETRConfig.
  num_queries: int = 100
  num_hidden: int = 256
  per_category_metrics: bool = False


@exp_factory.register_config_factory('detr_coco')
def detr_coco() -> cfg.ExperimentConfig:
  """Config to get results that matches the paper."""
  train_batch_size = 64
  eval_batch_size = 64
  num_train_data = 118287
  num_steps_per_epoch = num_train_data // train_batch_size
  train_steps = 500 * num_steps_per_epoch  # 500 epochs
  decay_at = train_steps - 100 * num_steps_per_epoch  # 400 epochs
  config = cfg.ExperimentConfig(
      task=DetectionConfig(
          train_data=coco.COCODataConfig(
              tfds_name='coco/2017',
              tfds_split='train',
              is_training=True,
              global_batch_size=train_batch_size,
              shuffle_buffer_size=1000,
          ),
          validation_data=coco.COCODataConfig(
              tfds_name='coco/2017',
              tfds_split='validation',
              is_training=False,
              global_batch_size=eval_batch_size,
              drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=10000,
          summary_interval=10000,
          checkpoint_interval=10000,
          validation_interval=10000,
          max_to_keep=1,
          best_checkpoint_export_subdir='best_ckpt',
          best_checkpoint_eval_metric='AP',
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'detr_adamw',
                  'detr_adamw': {
                      'weight_decay_rate': 1e-4,
                      'global_clipnorm': 0.1,
                      # Avoid AdamW legacy behavior.
                      'gradient_clip_norm': 0.0
                  }
              },
              'learning_rate': {
                  'type': 'stepwise',
                  'stepwise': {
                      'boundaries': [decay_at],
                      'values': [0.0001, 1.0e-05]
                  }
              },
          })),
      restrictions=[
          'task.train_data.is_training != None',
      ])
  return config
```
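For orientation, here is a minimal sketch (not part of this commit) of how a registered experiment such as `detr_coco` is typically retrieved and overridden through the Model Garden `exp_factory`; the checkpoint path is a placeholder.

```python
from official.core import exp_factory

# Retrieve the experiment registered above and override one field.
config = exp_factory.get_exp_config('detr_coco')
config.task.init_ckpt = 'gs://path/to/resnet50/ckpt'  # placeholder path
config.validate()  # enforces restrictions such as train_data.is_training
```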
**official/projects/detr/configs/detr_test.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for detr."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.detr.configs import detr as exp_cfg
from official.projects.detr.dataloaders import coco


class DetrTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('detr_coco',))
  def test_detr_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.DetectionConfig)
    self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
```
**official/projects/detr/dataloaders/coco.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO data loader for DETR."""
import dataclasses
from typing import Optional, Tuple
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import input_reader
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops


@dataclasses.dataclass
class COCODataConfig(cfg.DataConfig):
  """Data config for COCO."""
  output_size: Tuple[int, int] = (1333, 1333)
  max_num_boxes: int = 100
  resize_scales: Tuple[int, ...] = (
      480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)


class COCODataLoader():
  """A class to load dataset for COCO detection task."""

  def __init__(self, params: COCODataConfig):
    self._params = params

  def preprocess(self, inputs):
    """Preprocess COCO for DETR."""
    image = inputs['image']
    boxes = inputs['objects']['bbox']
    classes = inputs['objects']['label'] + 1
    is_crowd = inputs['objects']['is_crowd']

    image = preprocess_ops.normalize_image(image)
    if self._params.is_training:
      image, boxes, _ = preprocess_ops.random_horizontal_flip(image, boxes)

      do_crop = tf.greater(tf.random.uniform([]), 0.5)
      if do_crop:
        # Rescale
        boxes = box_ops.denormalize_boxes(boxes, tf.shape(image)[:2])
        index = tf.random.categorical(tf.zeros([1, 3]), 1)[0]
        scales = tf.gather([400.0, 500.0, 600.0], index, axis=0)
        short_side = scales[0]
        image, image_info = preprocess_ops.resize_image(image, short_side)
        boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                     image_info[2, :],
                                                     image_info[1, :],
                                                     image_info[3, :])
        boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

        # Do cropping
        shape = tf.cast(image_info[1], dtype=tf.int32)
        h = tf.random.uniform([], 384, tf.math.minimum(shape[0], 600),
                              dtype=tf.int32)
        w = tf.random.uniform([], 384, tf.math.minimum(shape[1], 600),
                              dtype=tf.int32)
        i = tf.random.uniform([], 0, shape[0] - h + 1, dtype=tf.int32)
        j = tf.random.uniform([], 0, shape[1] - w + 1, dtype=tf.int32)
        image = tf.image.crop_to_bounding_box(image, i, j, h, w)
        boxes = tf.clip_by_value(
            (boxes[..., :] * tf.cast(
                tf.stack([shape[0], shape[1], shape[0], shape[1]]),
                dtype=tf.float32) -
             tf.cast(tf.stack([i, j, i, j]), dtype=tf.float32)) /
            tf.cast(tf.stack([h, w, h, w]), dtype=tf.float32), 0.0, 1.0)
      scales = tf.constant(self._params.resize_scales, dtype=tf.float32)
      index = tf.random.categorical(tf.zeros([1, 11]), 1)[0]
      scales = tf.gather(scales, index, axis=0)
    else:
      scales = tf.constant([self._params.resize_scales[-1]], tf.float32)

    image_shape = tf.shape(image)[:2]
    boxes = box_ops.denormalize_boxes(boxes, image_shape)
    gt_boxes = boxes
    short_side = scales[0]
    image, image_info = preprocess_ops.resize_image(
        image, short_side, max(self._params.output_size))
    boxes = preprocess_ops.resize_and_crop_boxes(boxes,
                                                 image_info[2, :],
                                                 image_info[1, :],
                                                 image_info[3, :])
    boxes = box_ops.normalize_boxes(boxes, image_info[1, :])

    # Filters out ground truth boxes that are all zeros.
    indices = box_ops.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    is_crowd = tf.gather(is_crowd, indices)
    boxes = box_ops.yxyx_to_cycxhw(boxes)

    image = tf.image.pad_to_bounding_box(image, 0, 0,
                                         self._params.output_size[0],
                                         self._params.output_size[1])
    labels = {
        'classes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                classes, self._params.max_num_boxes),
        'boxes':
            preprocess_ops.clip_or_pad_to_fixed_size(
                boxes, self._params.max_num_boxes)
    }
    if not self._params.is_training:
      labels.update({
          'id':
              inputs['image/id'],
          'image_info':
              image_info,
          'is_crowd':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  is_crowd, self._params.max_num_boxes),
          'gt_boxes':
              preprocess_ops.clip_or_pad_to_fixed_size(
                  gt_boxes, self._params.max_num_boxes),
      })

    return image, labels

  def _transform_and_batch_fn(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Preprocess and batch."""
    dataset = dataset.map(
        self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._params.global_batch_size
    ) if input_context else self._params.global_batch_size
    dataset = dataset.batch(
        per_replica_batch_size, drop_remainder=self._params.is_training)
    return dataset

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=None,
        transform_and_batch_fn=self._transform_and_batch_fn)
    return reader.read(input_context)
```
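A short usage sketch (not part of the commit), assuming a local TFDS `coco/2017` installation. Note that `image_info` here follows the `preprocess_ops.resize_image` convention, if memory serves, of a `[4, 2]` tensor holding original size, resized size, scale, and offset rows, which is why the loader indexes rows 1-3 above.

```python
# Build the loader from a config and pull one padded batch.
params = COCODataConfig(
    tfds_name='coco/2017',
    tfds_split='validation',
    is_training=False,
    global_batch_size=2)
loader = COCODataLoader(params)
dataset = loader.load()
images, labels = next(iter(dataset))
# images: [2, 1333, 1333, 3]; labels['boxes']: [2, 100, 4] in cycxhw format.
```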
**official/projects/detr/dataloaders/coco_test.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.dataloaders.coco."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from official.projects.detr.dataloaders import coco


def _gen_fn():
  h = np.random.randint(0, 300)
  w = np.random.randint(0, 300)
  num_boxes = np.random.randint(0, 50)
  return {
      'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
      'image/id': np.random.randint(0, 100),
      'image/filename': 'test',
      'objects': {
          'is_crowd': np.ones(shape=(num_boxes), dtype=np.bool),
          'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
          'label': np.ones(shape=(num_boxes), dtype=np.int64),
          'id': np.ones(shape=(num_boxes), dtype=np.int64),
          'area': np.ones(shape=(num_boxes), dtype=np.int64),
      }
  }


class CocoDataloaderTest(tf.test.TestCase, parameterized.TestCase):

  def test_load_dataset(self):
    output_size = 1280
    max_num_boxes = 100
    batch_size = 2
    data_config = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=False,
        global_batch_size=batch_size,
        output_size=(output_size, output_size),
        max_num_boxes=max_num_boxes,
    )

    num_examples = 10

    def as_dataset(self, *args, **kwargs):
      del args
      del kwargs
      return tf.data.Dataset.from_generator(
          lambda: (_gen_fn() for i in range(num_examples)),
          output_types=self.info.features.dtype,
          output_shapes=self.info.features.shape,
      )

    with tfds.testing.mock_data(num_examples=num_examples,
                                as_dataset_fn=as_dataset):
      dataset = coco.COCODataLoader(data_config).load()
      dataset_iter = iter(dataset)
      images, labels = next(dataset_iter)
      self.assertEqual(images.shape, (batch_size, output_size, output_size, 3))
      self.assertEqual(labels['classes'].shape, (batch_size, max_num_boxes))
      self.assertEqual(labels['boxes'].shape, (batch_size, max_num_boxes, 4))
      self.assertEqual(labels['id'].shape, (batch_size,))
      self.assertEqual(labels['image_info'].shape, (batch_size, 4, 2))
      self.assertEqual(labels['is_crowd'].shape, (batch_size, max_num_boxes))

  @parameterized.named_parameters(
      ('training', True),
      ('validation', False))
  def test_preprocess(self, is_training):
    output_size = 1280
    max_num_boxes = 100
    batch_size = 2
    data_config = coco.COCODataConfig(
        tfds_name='coco/2017',
        tfds_split='validation',
        is_training=is_training,
        global_batch_size=batch_size,
        output_size=(output_size, output_size),
        max_num_boxes=max_num_boxes,
    )

    dl = coco.COCODataLoader(data_config)
    inputs = _gen_fn()
    image, label = dl.preprocess(inputs)
    self.assertEqual(image.shape, (output_size, output_size, 3))
    self.assertEqual(label['classes'].shape, (max_num_boxes))
    self.assertEqual(label['boxes'].shape, (max_num_boxes, 4))
    if not is_training:
      self.assertDTypeEqual(label['id'], int)
      self.assertEqual(label['image_info'].shape, (4, 2))
      self.assertEqual(label['is_crowd'].shape, (max_num_boxes))


if __name__ == '__main__':
  tf.test.main()
```
**official/projects/detr/experiments/detr_r50_300epochs.sh** (new file, mode 100644)

```bash
#!/bin/bash
python3 official/projects/detr/train.py \
  --experiment=detr_coco \
  --mode=train_and_eval \
  --model_dir=/tmp/logging_dir/ \
  --params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400',trainer.train_steps=554400
```
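The `trainer.train_steps=554400` override corresponds to 300 epochs: the config computes 118287 // 64 = 1848 steps per epoch, and 1848 × 300 = 554400. Without the override, the config's default of 500 epochs (924000 steps) applies, as in the script below.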
**official/projects/detr/experiments/detr_r50_500epochs.sh** (new file, mode 100644)

```bash
#!/bin/bash
python3 official/projects/detr/train.py \
  --experiment=detr_coco \
  --mode=train_and_eval \
  --model_dir=/tmp/logging_dir/ \
  --params_override=task.init_ckpt='gs://tf_model_garden/vision/resnet50_imagenet/ckpt-62400'
```
**official/projects/detr/modeling/detr.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements End-to-End Object Detection with Transformers.

Model paper: https://arxiv.org/abs/2005.12872

This module does not support Keras de/serialization. Please use
tf.train.Checkpoint for object based saving and loading and tf.saved_model.save
for graph serialization.
"""
import math

import tensorflow as tf

from official.modeling import tf_utils
from official.projects.detr.modeling import transformer
from official.vision.beta.modeling.backbones import resnet


def position_embedding_sine(attention_mask,
                            num_pos_features=256,
                            temperature=10000.,
                            normalize=True,
                            scale=2 * math.pi):
  """Sine-based positional embeddings for 2D images.

  Args:
    attention_mask: a `bool` Tensor specifying the size of the input image to
      the Transformer and which elements are padded, of size [batch_size,
      height, width]
    num_pos_features: an `int` specifying the number of positional features,
      should be equal to the hidden size of the Transformer network
    temperature: a `float` specifying the temperature of the positional
      embedding. Any type that is converted to a `float` can also be accepted.
    normalize: a `bool` determining whether the positional embeddings should be
      normalized between [0, scale] before application of the sine and cos
      functions.
    scale: a `float` if normalize is True specifying the scale embeddings
      before application of the embedding function.

  Returns:
    embeddings: a `float` tensor of the same shape as input_tensor specifying
      the positional embeddings based on sine features.
  """
  if num_pos_features % 2 != 0:
    raise ValueError(
        "Number of embedding features (num_pos_features) must be even when "
        "column and row embeddings are concatenated.")
  num_pos_features = num_pos_features // 2

  # Produce row and column embeddings based on total size of the image
  # <tf.float>[batch_size, height, width]
  attention_mask = tf.cast(attention_mask, tf.float32)
  row_embedding = tf.cumsum(attention_mask, 1)
  col_embedding = tf.cumsum(attention_mask, 2)

  if normalize:
    eps = 1e-6
    row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale
    col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale

  dim_t = tf.range(num_pos_features, dtype=row_embedding.dtype)
  dim_t = tf.pow(temperature, 2 * (dim_t // 2) / num_pos_features)

  # Creates positional embeddings for each row and column position
  # <tf.float>[batch_size, height, width, num_pos_features]
  pos_row = tf.expand_dims(row_embedding, -1) / dim_t
  pos_col = tf.expand_dims(col_embedding, -1) / dim_t
  pos_row = tf.stack(
      [tf.sin(pos_row[:, :, :, 0::2]),
       tf.cos(pos_row[:, :, :, 1::2])], axis=4)
  pos_col = tf.stack(
      [tf.sin(pos_col[:, :, :, 0::2]),
       tf.cos(pos_col[:, :, :, 1::2])], axis=4)

  # final_shape = pos_row.shape.as_list()[:3] + [-1]
  final_shape = tf_utils.get_shape_list(pos_row)[:3] + [-1]
  pos_row = tf.reshape(pos_row, final_shape)
  pos_col = tf.reshape(pos_col, final_shape)
  output = tf.concat([pos_row, pos_col], -1)

  embeddings = tf.cast(output, tf.float32)
  return embeddings


class DETR(tf.keras.Model):
  """DETR model with Keras.

  DETR consists of backbone, query embedding, DETRTransformer,
  class and box heads.
  """

  def __init__(self,
               num_queries,
               hidden_size,
               num_classes,
               num_encoder_layers=6,
               num_decoder_layers=6,
               dropout_rate=0.1,
               **kwargs):
    super().__init__(**kwargs)
    self._num_queries = num_queries
    self._hidden_size = hidden_size
    self._num_classes = num_classes
    self._num_encoder_layers = num_encoder_layers
    self._num_decoder_layers = num_decoder_layers
    self._dropout_rate = dropout_rate
    if hidden_size % 2 != 0:
      raise ValueError("hidden_size must be a multiple of 2.")
    # TODO(frederickliu): Consider using the backbone factory.
    # TODO(frederickliu): Add to factory once we get skeleton code in.
    self._backbone = resnet.ResNet(50, bn_trainable=False)

  def build(self, input_shape=None):
    self._input_proj = tf.keras.layers.Conv2D(
        self._hidden_size, 1, name="detr/conv2d")
    self._transformer = DETRTransformer(
        num_encoder_layers=self._num_encoder_layers,
        num_decoder_layers=self._num_decoder_layers,
        dropout_rate=self._dropout_rate)
    self._query_embeddings = self.add_weight(
        "detr/query_embeddings",
        shape=[self._num_queries, self._hidden_size],
        initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1.),
        dtype=tf.float32)
    sqrt_k = math.sqrt(1.0 / self._hidden_size)
    self._class_embed = tf.keras.layers.Dense(
        self._num_classes,
        kernel_initializer=tf.keras.initializers.RandomUniform(
            -sqrt_k, sqrt_k),
        name="detr/cls_dense")
    self._bbox_embed = [
        tf.keras.layers.Dense(
            self._hidden_size,
            activation="relu",
            kernel_initializer=tf.keras.initializers.RandomUniform(
                -sqrt_k, sqrt_k),
            name="detr/box_dense_0"),
        tf.keras.layers.Dense(
            self._hidden_size,
            activation="relu",
            kernel_initializer=tf.keras.initializers.RandomUniform(
                -sqrt_k, sqrt_k),
            name="detr/box_dense_1"),
        tf.keras.layers.Dense(
            4,
            kernel_initializer=tf.keras.initializers.RandomUniform(
                -sqrt_k, sqrt_k),
            name="detr/box_dense_2")
    ]
    self._sigmoid = tf.keras.layers.Activation("sigmoid")
    super().build(input_shape)

  @property
  def backbone(self) -> tf.keras.Model:
    return self._backbone

  def get_config(self):
    return {
        "num_queries": self._num_queries,
        "hidden_size": self._hidden_size,
        "num_classes": self._num_classes,
        "num_encoder_layers": self._num_encoder_layers,
        "num_decoder_layers": self._num_decoder_layers,
        "dropout_rate": self._dropout_rate,
    }

  @classmethod
  def from_config(cls, config):
    return cls(**config)

  def call(self, inputs):
    batch_size = tf.shape(inputs)[0]
    mask = tf.expand_dims(
        tf.cast(tf.not_equal(tf.reduce_sum(inputs, axis=-1), 0), inputs.dtype),
        axis=-1)
    features = self._backbone(inputs)["5"]
    shape = tf.shape(features)
    mask = tf.image.resize(
        mask, shape[1:3], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    pos_embed = position_embedding_sine(
        mask[:, :, :, 0], num_pos_features=self._hidden_size)
    pos_embed = tf.reshape(pos_embed, [batch_size, -1, self._hidden_size])

    features = tf.reshape(
        self._input_proj(features), [batch_size, -1, self._hidden_size])
    mask = tf.reshape(mask, [batch_size, -1])

    decoded_list = self._transformer({
        "inputs":
            features,
        "targets":
            tf.tile(
                tf.expand_dims(self._query_embeddings, axis=0),
                (batch_size, 1, 1)),
        "pos_embed": pos_embed,
        "mask": mask,
    })
    out_list = []
    for decoded in decoded_list:
      decoded = tf.stack(decoded)
      output_class = self._class_embed(decoded)
      box_out = decoded
      for layer in self._bbox_embed:
        box_out = layer(box_out)
      output_coord = self._sigmoid(box_out)
      out = {"cls_outputs": output_class, "box_outputs": output_coord}
      out_list.append(out)

    return out_list


class DETRTransformer(tf.keras.layers.Layer):
  """Encoder and Decoder of DETR."""

  def __init__(self, num_encoder_layers=6, num_decoder_layers=6,
               dropout_rate=0.1, **kwargs):
    super().__init__(**kwargs)
    self._dropout_rate = dropout_rate
    self._num_encoder_layers = num_encoder_layers
    self._num_decoder_layers = num_decoder_layers

  def build(self, input_shape=None):
    self._encoder = transformer.TransformerEncoder(
        attention_dropout_rate=self._dropout_rate,
        dropout_rate=self._dropout_rate,
        intermediate_dropout=self._dropout_rate,
        norm_first=False,
        num_layers=self._num_encoder_layers,
    )
    self._decoder = transformer.TransformerDecoder(
        attention_dropout_rate=self._dropout_rate,
        dropout_rate=self._dropout_rate,
        intermediate_dropout=self._dropout_rate,
        norm_first=False,
        num_layers=self._num_decoder_layers)
    super().build(input_shape)

  def get_config(self):
    return {
        "num_encoder_layers": self._num_encoder_layers,
        "num_decoder_layers": self._num_decoder_layers,
        "dropout_rate": self._dropout_rate,
    }

  def call(self, inputs):
    sources = inputs["inputs"]
    targets = inputs["targets"]
    pos_embed = inputs["pos_embed"]
    mask = inputs["mask"]
    input_shape = tf_utils.get_shape_list(sources)
    source_attention_mask = tf.tile(
        tf.expand_dims(mask, axis=1), [1, input_shape[1], 1])
    memory = self._encoder(
        sources, attention_mask=source_attention_mask, pos_embed=pos_embed)

    target_shape = tf_utils.get_shape_list(targets)
    cross_attention_mask = tf.tile(
        tf.expand_dims(mask, axis=1), [1, target_shape[1], 1])
    target_shape = tf.shape(targets)
    decoded = self._decoder(
        tf.zeros_like(targets),
        memory,
        # TODO(b/199545430): self_attention_mask could be set to None when this
        # bug is resolved. Passing ones for now.
        self_attention_mask=tf.ones(
            (target_shape[0], target_shape[1], target_shape[1])),
        cross_attention_mask=cross_attention_mask,
        return_all_decoder_outputs=True,
        input_pos_embed=targets,
        memory_pos_embed=pos_embed)
    return decoded
```
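A quick shape check for `position_embedding_sine` (a sketch, not part of the commit): with an all-ones mask, the cumulative sums reduce to a regular 2D sine/cosine grid, half the features encoding rows and half encoding columns.

```python
import tensorflow as tf

mask = tf.ones((2, 20, 20))  # batch of 20x20 feature maps, no padded pixels
embed = position_embedding_sine(mask, num_pos_features=256)
print(embed.shape)  # (2, 20, 20, 256): 128 row features + 128 column features
```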
**official/projects/detr/modeling/detr_test.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.detr."""
import tensorflow as tf

from official.projects.detr.modeling import detr


class DetrTest(tf.test.TestCase):

  def test_forward(self):
    num_queries = 10
    hidden_size = 128
    num_classes = 10
    image_size = 640
    batch_size = 2
    model = detr.DETR(num_queries, hidden_size, num_classes)
    outs = model(tf.ones((batch_size, image_size, image_size, 3)))
    self.assertLen(outs, 6)  # intermediate decoded outputs.
    for out in outs:
      self.assertAllEqual(
          tf.shape(out['cls_outputs']), (batch_size, num_queries, num_classes))
      self.assertAllEqual(
          tf.shape(out['box_outputs']), (batch_size, num_queries, 4))

  def test_get_from_config_detr_transformer(self):
    config = {
        'num_encoder_layers': 1,
        'num_decoder_layers': 2,
        'dropout_rate': 0.5,
    }
    detr_model = detr.DETRTransformer.from_config(config)
    retrieved_config = detr_model.get_config()

    self.assertEqual(config, retrieved_config)

  def test_get_from_config_detr(self):
    config = {
        'num_queries': 2,
        'hidden_size': 4,
        'num_classes': 10,
        'num_encoder_layers': 4,
        'num_decoder_layers': 5,
        'dropout_rate': 0.5,
    }
    detr_model = detr.DETR.from_config(config)
    retrieved_config = detr_model.get_config()

    self.assertEqual(config, retrieved_config)


if __name__ == '__main__':
  tf.test.main()
```
**official/projects/detr/modeling/transformer.py** (new file, mode 100644)

*(Diff collapsed in the original view; contents not shown.)*
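Judging from the tests below, `transformer.py` provides `TransformerEncoderBlock`, `TransformerEncoder`, `TransformerDecoderBlock`, and `TransformerDecoder`: variants of the Model Garden NLP transformer layers extended to accept DETR's positional embeddings (`pos_embed`, `input_pos_embed`, `memory_pos_embed`) alongside the usual attention masks.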
**official/projects/detr/modeling/transformer_test.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transformer."""
import tensorflow as tf

from official.projects.detr.modeling import transformer


class TransformerTest(tf.test.TestCase):

  def test_transformer_encoder_block(self):
    batch_size = 2
    sequence_length = 100
    feature_size = 256
    num_attention_heads = 2
    inner_dim = 256
    inner_activation = 'relu'
    model = transformer.TransformerEncoderBlock(num_attention_heads, inner_dim,
                                                inner_activation)
    input_tensor = tf.ones((batch_size, sequence_length, feature_size))
    attention_mask = tf.ones((batch_size, sequence_length, sequence_length),
                             dtype=tf.int64)
    pos_embed = tf.ones((batch_size, sequence_length, feature_size))

    out = model([input_tensor, attention_mask, pos_embed])
    self.assertAllEqual(
        tf.shape(out), (batch_size, sequence_length, feature_size))

  def test_transformer_encoder_block_get_config(self):
    num_attention_heads = 2
    inner_dim = 256
    inner_activation = 'relu'
    model = transformer.TransformerEncoderBlock(num_attention_heads, inner_dim,
                                                inner_activation)
    config = model.get_config()
    expected_config = {
        'name': 'transformer_encoder_block',
        'trainable': True,
        'dtype': 'float32',
        'num_attention_heads': 2,
        'inner_dim': 256,
        'inner_activation': 'relu',
        'output_dropout': 0.0,
        'attention_dropout': 0.0,
        'output_range': None,
        'kernel_initializer': {
            'class_name': 'GlorotUniform',
            'config': {'seed': None}
        },
        'bias_initializer': {
            'class_name': 'Zeros',
            'config': {}
        },
        'kernel_regularizer': None,
        'bias_regularizer': None,
        'activity_regularizer': None,
        'kernel_constraint': None,
        'bias_constraint': None,
        'use_bias': True,
        'norm_first': False,
        'norm_epsilon': 1e-12,
        'inner_dropout': 0.0,
        'attention_initializer': {
            'class_name': 'GlorotUniform',
            'config': {'seed': None}
        },
        'attention_axes': None
    }
    self.assertAllEqual(expected_config, config)

  def test_transformer_encoder(self):
    batch_size = 2
    sequence_length = 100
    feature_size = 256
    num_layers = 2
    num_attention_heads = 2
    intermediate_size = 256
    model = transformer.TransformerEncoder(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size)
    input_tensor = tf.ones((batch_size, sequence_length, feature_size))
    attention_mask = tf.ones((batch_size, sequence_length, sequence_length),
                             dtype=tf.int64)
    pos_embed = tf.ones((batch_size, sequence_length, feature_size))
    out = model(input_tensor, attention_mask, pos_embed)
    self.assertAllEqual(
        tf.shape(out), (batch_size, sequence_length, feature_size))

  def test_transformer_encoder_get_config(self):
    num_layers = 2
    num_attention_heads = 2
    intermediate_size = 256
    model = transformer.TransformerEncoder(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size)
    config = model.get_config()
    expected_config = {
        'name': 'transformer_encoder',
        'trainable': True,
        'dtype': 'float32',
        'num_layers': 2,
        'num_attention_heads': 2,
        'intermediate_size': 256,
        'activation': 'relu',
        'dropout_rate': 0.0,
        'attention_dropout_rate': 0.0,
        'use_bias': False,
        'norm_first': True,
        'norm_epsilon': 1e-06,
        'intermediate_dropout': 0.0
    }
    self.assertAllEqual(expected_config, config)

  def test_transformer_decoder_block(self):
    batch_size = 2
    sequence_length = 100
    memory_length = 200
    feature_size = 256
    num_attention_heads = 2
    intermediate_size = 256
    intermediate_activation = 'relu'
    model = transformer.TransformerDecoderBlock(num_attention_heads,
                                                intermediate_size,
                                                intermediate_activation)
    input_tensor = tf.ones((batch_size, sequence_length, feature_size))
    memory = tf.ones((batch_size, memory_length, feature_size))
    attention_mask = tf.ones((batch_size, sequence_length, memory_length),
                             dtype=tf.int64)
    self_attention_mask = tf.ones(
        (batch_size, sequence_length, sequence_length), dtype=tf.int64)
    input_pos_embed = tf.ones((batch_size, sequence_length, feature_size))
    memory_pos_embed = tf.ones((batch_size, memory_length, feature_size))

    out, _ = model([
        input_tensor, memory, attention_mask, self_attention_mask,
        input_pos_embed, memory_pos_embed
    ])
    self.assertAllEqual(
        tf.shape(out), (batch_size, sequence_length, feature_size))

  def test_transformer_decoder_block_get_config(self):
    num_attention_heads = 2
    intermediate_size = 256
    intermediate_activation = 'relu'
    model = transformer.TransformerDecoderBlock(num_attention_heads,
                                                intermediate_size,
                                                intermediate_activation)
    config = model.get_config()
    expected_config = {
        'name': 'transformer_decoder_block',
        'trainable': True,
        'dtype': 'float32',
        'num_attention_heads': 2,
        'intermediate_size': 256,
        'intermediate_activation': 'relu',
        'dropout_rate': 0.0,
        'attention_dropout_rate': 0.0,
        'kernel_initializer': {
            'class_name': 'GlorotUniform',
            'config': {'seed': None}
        },
        'bias_initializer': {
            'class_name': 'Zeros',
            'config': {}
        },
        'kernel_regularizer': None,
        'bias_regularizer': None,
        'activity_regularizer': None,
        'kernel_constraint': None,
        'bias_constraint': None,
        'use_bias': True,
        'norm_first': False,
        'norm_epsilon': 1e-12,
        'intermediate_dropout': 0.0,
        'attention_initializer': {
            'class_name': 'GlorotUniform',
            'config': {'seed': None}
        }
    }
    self.assertAllEqual(expected_config, config)

  def test_transformer_decoder(self):
    batch_size = 2
    sequence_length = 100
    memory_length = 200
    feature_size = 256
    num_layers = 2
    num_attention_heads = 2
    intermediate_size = 256
    model = transformer.TransformerDecoder(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size)
    input_tensor = tf.ones((batch_size, sequence_length, feature_size))
    memory = tf.ones((batch_size, memory_length, feature_size))
    attention_mask = tf.ones((batch_size, sequence_length, memory_length),
                             dtype=tf.int64)
    self_attention_mask = tf.ones(
        (batch_size, sequence_length, sequence_length), dtype=tf.int64)
    input_pos_embed = tf.ones((batch_size, sequence_length, feature_size))
    memory_pos_embed = tf.ones((batch_size, memory_length, feature_size))

    outs = model(
        input_tensor,
        memory,
        self_attention_mask,
        attention_mask,
        return_all_decoder_outputs=True,
        input_pos_embed=input_pos_embed,
        memory_pos_embed=memory_pos_embed)
    self.assertLen(outs, 2)  # intermediate decoded outputs.
    for out in outs:
      self.assertAllEqual(
          tf.shape(out), (batch_size, sequence_length, feature_size))

  def test_transformer_decoder_get_config(self):
    num_layers = 2
    num_attention_heads = 2
    intermediate_size = 256
    model = transformer.TransformerDecoder(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size)
    config = model.get_config()
    expected_config = {
        'name': 'transformer_decoder',
        'trainable': True,
        'dtype': 'float32',
        'num_layers': 2,
        'num_attention_heads': 2,
        'intermediate_size': 256,
        'activation': 'relu',
        'dropout_rate': 0.0,
        'attention_dropout_rate': 0.0,
        'use_bias': False,
        'norm_first': True,
        'norm_epsilon': 1e-06,
        'intermediate_dropout': 0.0
    }
    self.assertAllEqual(expected_config, config)


if __name__ == '__main__':
  tf.test.main()
```
**official/projects/detr/ops/matchers.py** (new file, mode 100644)

*(Diff collapsed in the original view; contents not shown.)*
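Judging from the tests below, `matchers.py` implements `hungarian_matching`, a batched linear sum assignment solver built on a maximum bipartite matching routine and designed to run on TPUs; it returns the cost weights and a boolean adjacency (assignment) matrix.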
**official/projects/detr/ops/matchers_test.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.official.projects.detr.ops.matchers."""
import numpy as np
from scipy import optimize
import tensorflow as tf

from official.projects.detr.ops import matchers


class MatchersOpsTest(tf.test.TestCase):

  def testLinearSumAssignment(self):
    """Check a simple 2D test case of the Linear Sum Assignment problem.

    Ensures that the implementation of the matching algorithm is correct
    and functional on TPUs.
    """
    cost_matrix = np.array([[[4, 1, 3], [2, 0, 5], [3, 2, 2]]],
                           dtype=np.float32)
    _, adjacency_matrix = matchers.hungarian_matching(tf.constant(cost_matrix))
    adjacency_output = adjacency_matrix.numpy()

    correct_output = np.array([
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 1],
    ], dtype=bool)
    self.assertAllEqual(adjacency_output[0], correct_output)

  def testBatchedLinearSumAssignment(self):
    """Check a batched case of the Linear Sum Assignment Problem.

    Ensures that a correct solution is found for all inputted problems within
    a batch.
    """
    cost_matrix = np.array([
        [[4, 1, 3], [2, 0, 5], [3, 2, 2]],
        [[1, 4, 3], [0, 2, 5], [2, 3, 2]],
        [[1, 3, 4], [0, 5, 2], [2, 2, 3]],
    ], dtype=np.float32)
    _, adjacency_matrix = matchers.hungarian_matching(tf.constant(cost_matrix))
    adjacency_output = adjacency_matrix.numpy()

    # Hand solved correct output for the linear sum assignment problem
    correct_output = np.array([
        [[0, 1, 0], [1, 0, 0], [0, 0, 1]],
        [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
        [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
    ], dtype=bool)
    self.assertAllClose(adjacency_output, correct_output)

  def testMaximumBipartiteMatching(self):
    """Check that the maximum bipartite match assigns the correct numbers."""
    adj_matrix = tf.cast([[
        [1, 0, 0, 0, 1],
        [0, 1, 0, 1, 0],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
    ]], tf.bool)
    _, assignment = matchers._maximum_bipartite_matching(adj_matrix)
    self.assertEqual(np.sum(assignment.numpy()), 5)

  def testAssignmentMatchesScipy(self):
    """Check that the Linear Sum Assignment matches the Scipy implementation."""
    batch_size, num_elems = 2, 25
    weights = tf.random.uniform((batch_size, num_elems, num_elems),
                                minval=0., maxval=1.)
    weights, assignment = matchers.hungarian_matching(weights)

    for idx in range(batch_size):
      _, scipy_assignment = optimize.linear_sum_assignment(
          weights.numpy()[idx])
      hungarian_assignment = np.where(assignment.numpy()[idx])[1]

      self.assertAllEqual(hungarian_assignment, scipy_assignment)


if __name__ == '__main__':
  tf.test.main()
```
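To see why the first test's expected assignment is optimal, the same 3x3 cost matrix can be hand-checked against SciPy (a sanity sketch, not part of the commit); rows are predictions and columns are targets.

```python
import numpy as np
from scipy import optimize

cost = np.array([[4, 1, 3], [2, 0, 5], [3, 2, 2]], dtype=np.float32)
rows, cols = optimize.linear_sum_assignment(cost)
print(list(zip(rows, cols)))  # [(0, 1), (1, 0), (2, 2)]: total cost 1 + 2 + 2 = 5
```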
**official/projects/detr/optimization.py** (new file, mode 100644)

```python
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized optimizer to match paper results."""
import dataclasses

import tensorflow as tf

from official.modeling import optimization
from official.nlp import optimization as nlp_optimization


@dataclasses.dataclass
class DETRAdamWConfig(optimization.AdamWeightDecayConfig):
  pass


@dataclasses.dataclass
class OptimizerConfig(optimization.OptimizerConfig):
  detr_adamw: DETRAdamWConfig = DETRAdamWConfig()


@dataclasses.dataclass
class OptimizationConfig(optimization.OptimizationConfig):
  """Configuration for optimizer and learning rate schedule.

  Attributes:
    optimizer: optimizer oneof config.
    ema: optional exponential moving average optimizer config, if specified,
      ema optimizer will be used.
    learning_rate: learning rate oneof config.
    warmup: warmup oneof config.
  """
  optimizer: OptimizerConfig = OptimizerConfig()


# TODO(frederickliu): figure out how to make this configurable.
# TODO(frederickliu): Study if this is needed.
class _DETRAdamW(nlp_optimization.AdamWeightDecay):
  """Custom AdamW to support different lr scaling for backbone.

  The code is copied from AdamWeightDecay and Adam with learning scaling.
  """

  def _resource_apply_dense(self, grad, var, apply_state=None):
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    apply_state = kwargs['apply_state']
    if 'detr' not in var.name:
      lr_t *= 0.1
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                      self._fallback_apply_state(var_device, var_dtype))
      m = self.get_slot(var, 'm')
      v = self.get_slot(var, 'v')
      lr = coefficients[
          'lr_t'] * 0.1 if 'detr' not in var.name else coefficients['lr_t']
      if not self.amsgrad:
        return tf.raw_ops.ResourceApplyAdam(
            var=var.handle,
            m=m.handle,
            v=v.handle,
            beta1_power=coefficients['beta_1_power'],
            beta2_power=coefficients['beta_2_power'],
            lr=lr,
            beta1=coefficients['beta_1_t'],
            beta2=coefficients['beta_2_t'],
            epsilon=coefficients['epsilon'],
            grad=grad,
            use_locking=self._use_locking)
      else:
        vhat = self.get_slot(var, 'vhat')
        return tf.raw_ops.ResourceApplyAdamWithAmsgrad(
            var=var.handle,
            m=m.handle,
            v=v.handle,
            vhat=vhat.handle,
            beta1_power=coefficients['beta_1_power'],
            beta2_power=coefficients['beta_2_power'],
            lr=lr,
            beta1=coefficients['beta_1_t'],
            beta2=coefficients['beta_2_t'],
            epsilon=coefficients['epsilon'],
            grad=grad,
            use_locking=self._use_locking)

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
    apply_state = kwargs['apply_state']
    if 'detr' not in var.name:
      lr_t *= 0.1
    decay = self._decay_weights_op(var, lr_t, apply_state)
    with tf.control_dependencies([decay]):
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                      self._fallback_apply_state(var_device, var_dtype))

      # m_t = beta1 * m + (1 - beta1) * g_t
      m = self.get_slot(var, 'm')
      m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
      m_t = tf.compat.v1.assign(
          m, m * coefficients['beta_1_t'], use_locking=self._use_locking)
      with tf.control_dependencies([m_t]):
        m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

      # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
      v = self.get_slot(var, 'v')
      v_scaled_g_values = (grad * grad) * coefficients['one_minus_beta_2_t']
      v_t = tf.compat.v1.assign(
          v, v * coefficients['beta_2_t'], use_locking=self._use_locking)
      with tf.control_dependencies([v_t]):
        v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
      lr = coefficients[
          'lr_t'] * 0.1 if 'detr' not in var.name else coefficients['lr_t']
      if not self.amsgrad:
        v_sqrt = tf.sqrt(v_t)
        var_update = tf.compat.v1.assign_sub(
            var,
            lr * m_t / (v_sqrt + coefficients['epsilon']),
            use_locking=self._use_locking)
        return tf.group(*[var_update, m_t, v_t])
      else:
        v_hat = self.get_slot(var, 'vhat')
        v_hat_t = tf.maximum(v_hat, v_t)
        with tf.control_dependencies([v_hat_t]):
          v_hat_t = tf.compat.v1.assign(
              v_hat, v_hat_t, use_locking=self._use_locking)
        v_hat_sqrt = tf.sqrt(v_hat_t)
        var_update = tf.compat.v1.assign_sub(
            var,
            lr * m_t / (v_hat_sqrt + coefficients['epsilon']),
            use_locking=self._use_locking)
        return tf.group(*[var_update, m_t, v_t, v_hat_t])


optimization.register_optimizer_cls('detr_adamw', _DETRAdamW)
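The net effect of `_DETRAdamW`, as far as the code above shows: any variable whose name does not contain 'detr' (in practice the ResNet backbone, since the DETR heads, queries, and transformer are created under 'detr'-prefixed names) has both its weight decay and its Adam update applied at 0.1x the learning rate. Combined with the stepwise schedule in the config, this matches the paper's 1e-4 transformer / 1e-5 backbone learning rates.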
official/projects/detr/tasks/detection.py
0 → 100644
View file @
ca8c44d5
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DETR detection task definition."""
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
task_factory
from
official.projects.detr.configs
import
detr
as
detr_cfg
from
official.projects.detr.dataloaders
import
coco
from
official.projects.detr.modeling
import
detr
from
official.projects.detr.ops
import
matchers
from
official.vision.beta.evaluation
import
coco_evaluator
from
official.vision.beta.ops
import
box_ops
@
task_factory
.
register_task_cls
(
detr_cfg
.
DetectionConfig
)
class
DectectionTask
(
base_task
.
Task
):
"""A single-replica view of training procedure.
DETR task provides artifacts for training/evalution procedures, including
loading/iterating over Datasets, initializing the model, calculating the loss,
post-processing, and customized metrics with reduction.
"""
def
build_model
(
self
):
"""Build DETR model."""
model
=
detr
.
DETR
(
self
.
_task_config
.
num_queries
,
self
.
_task_config
.
num_hidden
,
self
.
_task_config
.
num_classes
,
self
.
_task_config
.
num_encoder_layers
,
self
.
_task_config
.
num_decoder_layers
)
return
model
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
"""Loading pretrained checkpoint."""
ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
model
.
backbone
)
status
=
ckpt
.
read
(
self
.
_task_config
.
init_ckpt
)
status
.
expect_partial
().
assert_existing_objects_matched
()
def
build_inputs
(
self
,
params
,
input_context
=
None
):
"""Build input dataset."""
return
coco
.
COCODataLoader
(
params
).
load
(
input_context
)
def
_compute_cost
(
self
,
cls_outputs
,
box_outputs
,
cls_targets
,
box_targets
):
# Approximate classification cost with 1 - prob[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
# background: 0
cls_cost
=
self
.
_task_config
.
lambda_cls
*
tf
.
gather
(
-
tf
.
nn
.
softmax
(
cls_outputs
),
cls_targets
,
batch_dims
=
1
,
axis
=-
1
)
# Compute the L1 cost between boxes,
paired_differences
=
self
.
_task_config
.
lambda_box
*
tf
.
abs
(
tf
.
expand_dims
(
box_outputs
,
2
)
-
tf
.
expand_dims
(
box_targets
,
1
))
box_cost
=
tf
.
reduce_sum
(
paired_differences
,
axis
=-
1
)
# Compute the giou cost betwen boxes
giou_cost
=
self
.
_task_config
.
lambda_giou
*
-
box_ops
.
bbox_generalized_overlap
(
box_ops
.
cycxhw_to_yxyx
(
box_outputs
),
box_ops
.
cycxhw_to_yxyx
(
box_targets
))
total_cost
=
cls_cost
+
box_cost
+
giou_cost
max_cost
=
(
self
.
_task_config
.
lambda_cls
*
0.0
+
self
.
_task_config
.
lambda_box
*
4.
+
self
.
_task_config
.
lambda_giou
*
0.0
)
# Set pads to large constant
valid
=
tf
.
expand_dims
(
tf
.
cast
(
tf
.
not_equal
(
cls_targets
,
0
),
dtype
=
total_cost
.
dtype
),
axis
=
1
)
total_cost
=
(
1
-
valid
)
*
max_cost
+
valid
*
total_cost
# Set inf of nan to large constant
total_cost
=
tf
.
where
(
tf
.
logical_or
(
tf
.
math
.
is_nan
(
total_cost
),
tf
.
math
.
is_inf
(
total_cost
)),
max_cost
*
tf
.
ones_like
(
total_cost
,
dtype
=
total_cost
.
dtype
),
total_cost
)
return
total_cost
def
build_losses
(
self
,
outputs
,
labels
,
aux_losses
=
None
):
"""Build DETR losses."""
cls_outputs
=
outputs
[
'cls_outputs'
]
box_outputs
=
outputs
[
'box_outputs'
]
cls_targets
=
labels
[
'classes'
]
box_targets
=
labels
[
'boxes'
]
cost
=
self
.
_compute_cost
(
cls_outputs
,
box_outputs
,
cls_targets
,
box_targets
)
_
,
indices
=
matchers
.
hungarian_matching
(
cost
)
indices
=
tf
.
stop_gradient
(
indices
)
target_index
=
tf
.
math
.
argmax
(
indices
,
axis
=
1
)
cls_assigned
=
tf
.
gather
(
cls_outputs
,
target_index
,
batch_dims
=
1
,
axis
=
1
)
box_assigned
=
tf
.
gather
(
box_outputs
,
target_index
,
batch_dims
=
1
,
axis
=
1
)
background
=
tf
.
equal
(
cls_targets
,
0
)
num_boxes
=
tf
.
reduce_sum
(
tf
.
cast
(
tf
.
logical_not
(
background
),
tf
.
float32
),
axis
=-
1
)
# Down-weight background to account for class imbalance.
xentropy
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
labels
=
cls_targets
,
logits
=
cls_assigned
)
cls_loss
=
self
.
_task_config
.
lambda_cls
*
tf
.
where
(
background
,
self
.
_task_config
.
background_cls_weight
*
xentropy
,
xentropy
)
cls_weights
=
tf
.
where
(
background
,
self
.
_task_config
.
background_cls_weight
*
tf
.
ones_like
(
cls_loss
),
tf
.
ones_like
(
cls_loss
)
)
# Box loss is only calculated on non-background class.
l_1
=
tf
.
reduce_sum
(
tf
.
abs
(
box_assigned
-
box_targets
),
axis
=-
1
)
box_loss
=
self
.
_task_config
.
lambda_box
*
tf
.
where
(
background
,
tf
.
zeros_like
(
l_1
),
l_1
)
# Giou loss is only calculated on non-background class.
giou
=
tf
.
linalg
.
diag_part
(
1.0
-
box_ops
.
bbox_generalized_overlap
(
box_ops
.
cycxhw_to_yxyx
(
box_assigned
),
box_ops
.
cycxhw_to_yxyx
(
box_targets
)
))
giou_loss
=
self
.
_task_config
.
lambda_giou
*
tf
.
where
(
background
,
tf
.
zeros_like
(
giou
),
giou
)
# Consider doing all reduce once in train_step to speed up.
num_boxes_per_replica
=
tf
.
reduce_sum
(
num_boxes
)
cls_weights_per_replica
=
tf
.
reduce_sum
(
cls_weights
)
replica_context
=
tf
.
distribute
.
get_replica_context
()
num_boxes_sum
,
cls_weights_sum
=
replica_context
.
all_reduce
(
tf
.
distribute
.
ReduceOp
.
SUM
,
[
num_boxes_per_replica
,
cls_weights_per_replica
])
cls_loss
=
tf
.
math
.
divide_no_nan
(
tf
.
reduce_sum
(
cls_loss
),
cls_weights_sum
)
box_loss
=
tf
.
math
.
divide_no_nan
(
tf
.
reduce_sum
(
box_loss
),
num_boxes_sum
)
giou_loss
=
tf
.
math
.
divide_no_nan
(
tf
.
reduce_sum
(
giou_loss
),
num_boxes_sum
)
aux_losses
=
tf
.
add_n
(
aux_losses
)
if
aux_losses
else
0.0
total_loss
=
cls_loss
+
box_loss
+
giou_loss
+
aux_losses
return
total_loss
,
cls_loss
,
box_loss
,
giou_loss
def
build_metrics
(
self
,
training
=
True
):
"""Build detection metrics."""
metrics
=
[]
metric_names
=
[
'cls_loss'
,
'box_loss'
,
'giou_loss'
]
for
name
in
metric_names
:
metrics
.
append
(
tf
.
keras
.
metrics
.
Mean
(
name
,
dtype
=
tf
.
float32
))
if
not
training
:
self
.
coco_metric
=
coco_evaluator
.
COCOEvaluator
(
annotation_file
=
''
,
include_mask
=
False
,
need_rescale_bboxes
=
True
,
per_category_metrics
=
self
.
_task_config
.
per_category_metrics
)
return
metrics
  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)

      loss = 0.0
      cls_loss = 0.0
      box_loss = 0.0
      giou_loss = 0.0
      for output in outputs:
        # Computes per-replica loss.
        layer_loss, layer_cls_loss, layer_box_loss, layer_giou_loss = (
            self.build_losses(
                outputs=output, labels=labels, aux_losses=model.losses))
        loss += layer_loss
        cls_loss += layer_cls_loss
        box_loss += layer_box_loss
        giou_loss += layer_giou_loss

      # Consider moving scaling logic from build_losses to here.
      scaled_loss = loss
      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    # Multiply for logging.
    # Since we expect the gradient replica sum to happen in the optimizer,
    # the loss is scaled with global num_boxes and weights.
    # To have it more interpretable/comparable we scale it back when logging.
    num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
    loss *= num_replicas_in_sync
    cls_loss *= num_replicas_in_sync
    box_loss *= num_replicas_in_sync
    giou_loss *= num_replicas_in_sync

    # Trainer class handles loss metric for you.
    logs = {self.loss: loss}

    all_losses = {
        'cls_loss': cls_loss,
        'box_loss': box_loss,
        'giou_loss': giou_loss,
    }

    # Metric results will be added to logs for you.
    if metrics:
      for m in metrics:
        m.update_state(all_losses[m.name])
    return logs
  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    outputs = model(features, training=False)[-1]
    loss, cls_loss, box_loss, giou_loss = self.build_losses(
        outputs=outputs, labels=labels, aux_losses=model.losses)

    # Multiply for logging.
    # Since we expect the gradient replica sum to happen in the optimizer,
    # the loss is scaled with global num_boxes and weights.
    # To have it more interpretable/comparable we scale it back when logging.
    num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
    loss *= num_replicas_in_sync
    cls_loss *= num_replicas_in_sync
    box_loss *= num_replicas_in_sync
    giou_loss *= num_replicas_in_sync

    # Evaluator class handles loss metric for you.
    logs = {self.loss: loss}

    # Normalized cycxhw boxes are de-normalized to absolute yxyx by
    # multiplying with the scaled input size held in row 1 of `image_info`
    # (row 0 holds the original image size, used for the ground truths below).
    predictions = {
        'detection_boxes':
            box_ops.cycxhw_to_yxyx(outputs['box_outputs']) * tf.expand_dims(
                tf.concat([
                    labels['image_info'][:, 1:2, 0],
                    labels['image_info'][:, 1:2, 1],
                    labels['image_info'][:, 1:2, 0],
                    labels['image_info'][:, 1:2, 1]
                ], axis=1), axis=1),
        'detection_scores':
            tf.math.reduce_max(
                tf.nn.softmax(outputs['cls_outputs'])[:, :, 1:], axis=-1),
        'detection_classes':
            tf.math.argmax(outputs['cls_outputs'][:, :, 1:], axis=-1) + 1,
        # Fix this. It's not being used at the moment.
        'num_detections':
            tf.reduce_sum(
                tf.cast(
                    tf.math.greater(
                        tf.math.reduce_max(outputs['cls_outputs'], axis=-1),
                        0), tf.int32),
                axis=-1),
        'source_id': labels['id'],
        'image_info': labels['image_info']
    }

    ground_truths = {
        'source_id': labels['id'],
        'height': labels['image_info'][:, 0:1, 0],
        'width': labels['image_info'][:, 0:1, 1],
        'num_detections':
            tf.reduce_sum(
                tf.cast(tf.math.greater(labels['classes'], 0), tf.int32),
                axis=-1),
        'boxes': labels['gt_boxes'],
        'classes': labels['classes'],
        'is_crowds': labels['is_crowd']
    }
    logs.update({'predictions': predictions, 'ground_truths': ground_truths})

    all_losses = {
        'cls_loss': cls_loss,
        'box_loss': box_loss,
        'giou_loss': giou_loss,
    }

    # Metric results will be added to logs for you.
    if metrics:
      for m in metrics:
        m.update_state(all_losses[m.name])
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    if state is None:
      self.coco_metric.reset_states()
      state = self.coco_metric
    state.update_state(step_outputs['ground_truths'],
                       step_outputs['predictions'])
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    return aggregated_logs.result()
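
The evaluator consumes absolute yxyx boxes, so `validation_step` multiplies the normalized cycxhw outputs by the scaled image size. Below is a minimal standalone sketch of that conversion and of the `divide_no_nan` guard used in `build_losses`; the local `cycxhw_to_yxyx` is an illustrative re-implementation assumed to match the behavior of `box_ops.cycxhw_to_yxyx`, not the Model Garden code itself.

import tensorflow as tf

def cycxhw_to_yxyx(boxes):
  # Assumed convention: [center_y, center_x, height, width] ->
  # [ymin, xmin, ymax, xmax].
  cy, cx, h, w = tf.unstack(boxes, axis=-1)
  return tf.stack(
      [cy - h / 2.0, cx - w / 2.0, cy + h / 2.0, cx + w / 2.0], axis=-1)

# One normalized box centered at (0.5, 0.5), height 0.4, width 0.2.
boxes = tf.constant([[0.5, 0.5, 0.4, 0.2]])
print(cycxhw_to_yxyx(boxes))  # [[0.3, 0.4, 0.7, 0.6]]

# De-normalize against a 640x480 (height x width) input, mirroring the
# [h, w, h, w] concat that validation_step builds from image_info.
print(cycxhw_to_yxyx(boxes) * tf.constant([640.0, 480.0, 640.0, 480.0]))

# divide_no_nan keeps the per-box means finite when a replica sees only
# background queries (zero matched boxes).
print(tf.math.divide_no_nan(tf.constant(2.4), tf.constant(0.0)))  # 0.0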
official/projects/detr/tasks/detection_test.py 0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for detection."""
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from official.projects.detr import optimization
from official.projects.detr.configs import detr as detr_cfg
from official.projects.detr.dataloaders import coco
from official.projects.detr.tasks import detection

_NUM_EXAMPLES = 10


def _gen_fn():
  h = np.random.randint(0, 300)
  w = np.random.randint(0, 300)
  num_boxes = np.random.randint(0, 50)
  return {
      'image': np.ones(shape=(h, w, 3), dtype=np.uint8),
      'image/id': np.random.randint(0, 100),
      'image/filename': 'test',
      'objects': {
          'is_crowd': np.ones(shape=(num_boxes), dtype=np.bool),
          'bbox': np.ones(shape=(num_boxes, 4), dtype=np.float32),
          'label': np.ones(shape=(num_boxes), dtype=np.int64),
          'id': np.ones(shape=(num_boxes), dtype=np.int64),
          'area': np.ones(shape=(num_boxes), dtype=np.int64),
      }
  }


def _as_dataset(self, *args, **kwargs):
  del args
  del kwargs
  return tf.data.Dataset.from_generator(
      lambda: (_gen_fn() for i in range(_NUM_EXAMPLES)),
      output_types=self.info.features.dtype,
      output_shapes=self.info.features.shape,
  )


class DetectionTest(tf.test.TestCase):

  def test_train_step(self):
    config = detr_cfg.DetectionConfig(
        num_encoder_layers=1,
        num_decoder_layers=1,
        train_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=True,
            global_batch_size=2,
        ))

    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DectectionTask(config)
      model = task.build_model()
      dataset = task.build_inputs(config.train_data)
      iterator = iter(dataset)
      opt_cfg = optimization.OptimizationConfig({
          'optimizer': {
              'type': 'detr_adamw',
              'detr_adamw': {
                  'weight_decay_rate': 1e-4,
                  'global_clipnorm': 0.1,
              }
          },
          'learning_rate': {
              'type': 'stepwise',
              'stepwise': {
                  'boundaries': [120000],
                  'values': [0.0001, 1.0e-05]
              }
          },
      })
      optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
      task.train_step(next(iterator), model, optimizer)

  def test_validation_step(self):
    config = detr_cfg.DetectionConfig(
        num_encoder_layers=1,
        num_decoder_layers=1,
        validation_data=coco.COCODataConfig(
            tfds_name='coco/2017',
            tfds_split='validation',
            is_training=False,
            global_batch_size=2,
        ))

    with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
      task = detection.DectectionTask(config)
      model = task.build_model()
      metrics = task.build_metrics(training=False)
      dataset = task.build_inputs(config.validation_data)
      iterator = iter(dataset)
      logs = task.validation_step(next(iterator), model, metrics)
      state = task.aggregate_logs(step_outputs=logs)
      task.reduce_aggregated_logs(state)


if __name__ == '__main__':
  tf.test.main()
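
The test above patches TFDS so no real COCO data is needed. Below is a minimal sketch of the same generator-backed `tf.data` pattern in isolation, using a toy spec (the 32x32 shape and field names are placeholders, not the DETR schema) and the newer `output_signature` argument in place of `output_types`/`output_shapes`.

import numpy as np
import tensorflow as tf

def gen():
  # Yield dictionary examples, as the mocked TFDS `as_dataset` does.
  for _ in range(10):
    yield {'image': np.ones((32, 32, 3), dtype=np.uint8),
           'label': np.int64(1)}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'image': tf.TensorSpec(shape=(32, 32, 3), dtype=tf.uint8),
        'label': tf.TensorSpec(shape=(), dtype=tf.int64),
    })
for example in ds.take(1):
  print(example['image'].shape, example['label'].numpy())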
official/projects/detr/train.py 0 → 100644
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
# pylint: disable=unused-import
from official.projects.detr.configs import detr
from official.projects.detr.tasks import detection
# pylint: enable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)

  train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
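
The mixed-precision comment in `main` is the reason `train_step` special-cases `LossScaleOptimizer`: with float16 compute, the loss is scaled up before differentiation and the gradients scaled back down before the update. Below is a minimal self-contained sketch of that round trip on a toy variable (not the DETR task or model).

import tensorflow as tf

var = tf.Variable(1.0)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1))

with tf.GradientTape() as tape:
  loss = var * var  # toy loss; gradient is 2 * var
  # Scale the loss so small float16 gradients do not underflow to zero.
  scaled_loss = optimizer.get_scaled_loss(loss)
grads = tape.gradient(scaled_loss, [var])
# Undo the scaling before the update, exactly as train_step does.
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(zip(grads, [var]))
print(var.numpy())  # 1.0 - 0.1 * 2.0 = 0.8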