ModelZoo / ResNet50_tensorflow / Commits / 73a2818d

Commit 73a2818d, authored Oct 21, 2020 by vishnubanna

PR1 darknet

Parent: 7beddae1

Changes: 22 files. Showing 20 changed files (page 1 of 2) with 1432 additions and 0 deletions (+1432, -0):
official/vision/beta/projects/yolo/common/__init__.py                           +1    -0
official/vision/beta/projects/yolo/common/registry_imports.py                   +22   -0
official/vision/beta/projects/yolo/configs/backbones.py                         +16   -0
official/vision/beta/projects/yolo/configs/experiments/darknet53.yaml           +52   -0
official/vision/beta/projects/yolo/modeling/__init__.py                         +1    -0
official/vision/beta/projects/yolo/modeling/backbones/Darknet.py                +346  -0
official/vision/beta/projects/yolo/modeling/backbones/__init__.py               +0    -0
official/vision/beta/projects/yolo/modeling/building_blocks/_CSPConnect.py      +73   -0
official/vision/beta/projects/yolo/modeling/building_blocks/_CSPDownSample.py   +84   -0
official/vision/beta/projects/yolo/modeling/building_blocks/_CSPTiny.py         +154  -0
official/vision/beta/projects/yolo/modeling/building_blocks/_DarkConv.py        +175  -0
official/vision/beta/projects/yolo/modeling/building_blocks/_DarkResidual.py    +158  -0
official/vision/beta/projects/yolo/modeling/building_blocks/_DarkTiny.py        +103  -0
official/vision/beta/projects/yolo/modeling/building_blocks/__init__.py         +7    -0
official/vision/beta/projects/yolo/modeling/tests/README.md                     +1    -0
official/vision/beta/projects/yolo/modeling/tests/__init__.py                   +0    -0
official/vision/beta/projects/yolo/modeling/tests/test_CSPConnect.py            +55   -0
official/vision/beta/projects/yolo/modeling/tests/test_CSPDownSample.py         +53   -0
official/vision/beta/projects/yolo/modeling/tests/test_DarkConv.py              +71   -0
official/vision/beta/projects/yolo/modeling/tests/test_DarkResidual.py          +60   -0
official/vision/beta/projects/yolo/common/__init__.py  (new file, mode 100644, empty)

official/vision/beta/projects/yolo/common/registry_imports.py  (new file, mode 100644)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.nlp import tasks as nlp_task
from official.utils.testing import mock_task
from official.vision import beta
from official.vision.beta.projects import yolo
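
For context, these are side-effect imports: each module registers its tasks and model builders with the Model Garden registries at import time. A minimal sketch of how that is relied on (the surrounding setup is an assumption, not part of this commit):

# Side-effect import: pulling in the yolo project runs the module-level
# registration decorators, e.g. factory.register_backbone_builder('darknet')
# in modeling/backbones/Darknet.py.
from official.vision.beta.projects.yolo.common import registry_imports  # pylint: disable=unused-import
from official.vision.beta.modeling.backbones import factory

# After the import above, the 'darknet' key resolves inside
# factory.build_backbone(...), as the commented-out __main__ block at the
# bottom of Darknet.py demonstrates.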
official/vision/beta/projects/yolo/configs/backbones.py  (new file, mode 100644)

"""Backbones configurations."""
# Import libraries
import dataclasses
from typing import Optional

from official.modeling import hyperparams
from official.vision.beta.configs import backbones


@dataclasses.dataclass
class DarkNet(hyperparams.Config):
  """DarkNet config."""
  model_id: str = "darknet53"


@dataclasses.dataclass
class Backbone(backbones.Backbone):
  darknet: DarkNet = DarkNet()
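
For orientation, the oneof-style config composes like this (mirroring the commented-out __main__ block in Darknet.py later in this commit):

from official.vision.beta.projects.yolo.configs import backbones

# 'type' selects the oneof member; model_id selects the darknet variant.
cfg = backbones.Backbone(type="darknet",
                         darknet=backbones.DarkNet(model_id="cspdarknet53"))
print(cfg.type, cfg.darknet.model_id)  # -> darknet cspdarknet53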
official/vision/beta/projects/yolo/configs/experiments/darknet53.yaml  (new file, mode 100644)

runtime:
  distribution_strategy: 'mirrored'
  mixed_precision_dtype: 'float32'
  loss_scale: 'dynamic'
task:
  model:
    num_classes: 1001
    input_size: [224, 224, 3]
    backbone:
      type: 'darknet'
  losses:
    l2_weight_decay: 0.0005
    one_hot: True
    label_smoothing: 0.1
  train_data:
    tfds_name: 'imagenet_a'
    tfds_split: 'test'
    tfds_download: True
    is_training: True
    global_batch_size: 128
    dtype: 'float16'
  validation_data:
    tfds_name: 'imagenet_a'
    tfds_split: 'test'
    tfds_download: True
    is_training: False
    global_batch_size: 128
    dtype: 'float16'
    drop_remainder: False
trainer:
  train_steps: 800000  # in the paper
  validation_steps: 400  # size of validation data
  validation_interval: 10000
  steps_per_loop: 10000
  summary_interval: 10000
  checkpoint_interval: 10000
  optimizer_config:
    optimizer:
      type: 'sgd'
      sgd:
        momentum: 0.9
    learning_rate:
      type: 'polynomial'
      polynomial:
        initial_learning_rate: 0.1
        end_learning_rate: 0.0001
        power: 4.0
        decay_steps: 799000
    warmup:
      type: 'linear'
      linear:
        warmup_steps: 1000  # lr rises from 0 to 0.1 over 1000 steps
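
The schedule warms up linearly to 0.1 over the first 1000 steps, then follows a power-4 polynomial decay to 0.0001 over the remaining 799000 steps. A sketch of the decay curve using the stock Keras schedule, which shares these field names (whether the Model Garden trainer maps onto exactly this class is an assumption):

import tensorflow as tf

lr = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1,
    decay_steps=799000,
    end_learning_rate=0.0001,
    power=4.0)

# lr(step) = 0.0001 + (0.1 - 0.0001) * (1 - step / 799000) ** 4
for step in [0, 100000, 400000, 799000]:
  print(step, float(lr(step)))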
official/vision/beta/projects/yolo/modeling/__init__.py  (new file, mode 100644)

# from .yolo_v3 import Yolov3
official/vision/beta/projects/yolo/modeling/backbones/Darknet.py  (new file, mode 100644)

import importlib
import collections
from typing import *

import tensorflow as tf
import tensorflow.keras as ks

import official.vision.beta.projects.yolo.modeling.building_blocks as nn_blocks
from official.vision.beta.modeling.backbones import factory


# builder required classes
class CSPBlockConfig(object):

  def __init__(self, layer, stack, reps, bottleneck, filters, kernel_size,
               strides, padding, activation, route, output_name, is_output):
    '''
    Get a layer config to make the code more readable.

    Args:
      layer: string layer name
      reps: integer for the number of times to repeat the block
      filters: integer for the filters of this layer, i.e. the output depth
      kernel_size: integer or None; if None, the building block handles this
        automatically (not a layer input)
      downsample: boolean, whether to downsample the input width and height
      output: boolean, True if the layer is required as an output
    '''
    self.layer = layer
    self.stack = stack
    self.repetitions = reps
    self.bottleneck = bottleneck
    self.filters = filters
    self.kernel_size = kernel_size
    self.strides = strides
    self.padding = padding
    self.activation = activation
    self.route = route
    self.output_name = output_name
    self.is_output = is_output
    return


def csp_build_block_specs(config):
  specs = []
  for layer in config:
    specs.append(CSPBlockConfig(*layer))
  return specs


class layer_registry(object):

  def __init__(self):
    self._layer_dict = {
        "DarkTiny": (nn_blocks.DarkTiny, darktiny_config_todict),
        "DarkConv": (nn_blocks.DarkConv, darkconv_config_todict),
        "MaxPool": (tf.keras.layers.MaxPool2D, maxpool_config_todict)
    }
    return

  def _get_layer(self, key):
    return self._layer_dict[key]

  def __call__(self, config, kwargs):
    layer, get_param_dict = self._get_layer(config.layer)
    param_dict = get_param_dict(config, kwargs)
    return layer(**param_dict)


def darkconv_config_todict(config, kwargs):
  dictvals = {
      "filters": config.filters,
      "kernel_size": config.kernel_size,
      "strides": config.strides,
      "padding": config.padding
  }
  dictvals.update(kwargs)
  return dictvals


def darktiny_config_todict(config, kwargs):
  dictvals = {"filters": config.filters, "strides": config.strides}
  dictvals.update(kwargs)
  return dictvals


def maxpool_config_todict(config, kwargs):
  return {
      "pool_size": config.kernel_size,
      "strides": config.strides,
      "padding": config.padding,
      "name": kwargs["name"]
  }


# model configs
LISTNAMES = [
    "default_layer_name", "level_type", "number_of_layers_in_level",
    "bottleneck", "filters", "kernel_size", "strides", "padding",
    "default_activation", "route", "level/name", "is_output"
]

CSPDARKNET53 = {
    "list_names": LISTNAMES,
    "splits": {"backbone_split": 106, "neck_split": 138},
    "backbone": [
        ["DarkConv", None, 1, False, 32, 3, 1, "same", "mish", -1, 0, False],  # 1
        ["DarkRes", "csp", 1, True, 64, None, None, None, "mish", -1, 1, False],  # 3
        ["DarkRes", "csp", 2, False, 128, None, None, None, "mish", -1, 2, False],  # 2
        ["DarkRes", "csp", 8, False, 256, None, None, None, "mish", -1, 3, True],
        ["DarkRes", "csp", 8, False, 512, None, None, None, "mish", -1, 4, True],  # 3
        ["DarkRes", "csp", 4, False, 1024, None, None, None, "mish", -1, 5, True],  # 6 #route
    ]
}

DARKNET53 = {
    "list_names": LISTNAMES,
    "splits": {"backbone_split": 76},
    "backbone": [
        ["DarkConv", None, 1, False, 32, 3, 1, "same", "leaky", -1, 0, False],  # 1
        ["DarkRes", "residual", 1, True, 64, None, None, None, "leaky", -1, 1, False],  # 3
        ["DarkRes", "residual", 2, False, 128, None, None, None, "leaky", -1, 2, False],  # 2
        ["DarkRes", "residual", 8, False, 256, None, None, None, "leaky", -1, 3, True],
        ["DarkRes", "residual", 8, False, 512, None, None, None, "leaky", -1, 4, True],  # 3
        ["DarkRes", "residual", 4, False, 1024, None, None, None, "leaky", -1, 5, True],  # 6
    ]
}

CSPDARKNETTINY = {
    "list_names": LISTNAMES,
    "splits": {"backbone_split": 28},
    "backbone": [
        ["DarkConv", None, 1, False, 32, 3, 2, "same", "leaky", -1, 0, False],  # 1
        ["DarkConv", None, 1, False, 64, 3, 2, "same", "leaky", -1, 1, False],  # 1
        ["CSPTiny", "csp_tiny", 1, False, 64, 3, 2, "same", "leaky", -1, 2, False],  # 3
        ["CSPTiny", "csp_tiny", 1, False, 128, 3, 2, "same", "leaky", -1, 3, False],  # 3
        ["CSPTiny", "csp_tiny", 1, False, 256, 3, 2, "same", "leaky", -1, 4, True],  # 3
        ["DarkConv", None, 1, False, 512, 3, 1, "same", "leaky", -1, 5, True],  # 1
    ]
}

DARKNETTINY = {
    "list_names": LISTNAMES,
    "splits": {"backbone_split": 14},
    "backbone": [
        ["DarkConv", None, 1, False, 16, 3, 1, "same", "leaky", -1, 0, False],  # 1
        ["DarkTiny", None, 1, True, 32, 3, 2, "same", "leaky", -1, 1, False],  # 3
        ["DarkTiny", None, 1, True, 64, 3, 2, "same", "leaky", -1, 2, False],  # 3
        ["DarkTiny", None, 1, False, 128, 3, 2, "same", "leaky", -1, 3, False],  # 2
        ["DarkTiny", None, 1, False, 256, 3, 2, "same", "leaky", -1, 4, True],
        ["DarkTiny", None, 1, False, 512, 3, 2, "same", "leaky", -1, 5, False],  # 3
        ["DarkTiny", None, 1, False, 1024, 3, 1, "same", "leaky", -1, 5, True],  # 6 #route
    ]
}

BACKBONES = {
    "darknettiny": DARKNETTINY,
    "darknet53": DARKNET53,
    "cspdarknet53": CSPDARKNET53,
    "cspdarknettiny": CSPDARKNETTINY
}


@ks.utils.register_keras_serializable(package='yolo')
class Darknet(ks.Model):

  def __init__(self,
               model_id="darknet53",
               input_shape=tf.keras.layers.InputSpec(shape=[None, None, None, 3]),
               min_size=None,
               max_size=5,
               activation=None,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               kernel_initializer='glorot_uniform',
               kernel_regularizer=None,
               bias_regularizer=None,
               config=None,
               **kwargs):
    layer_specs, splits = Darknet.get_model_config(model_id)

    self._model_name = model_id
    self._splits = splits
    self._input_shape = input_shape
    self._registry = layer_registry()  # default layer look up

    self._min_size = min_size
    self._max_size = max_size
    self._output_specs = None

    self._kernel_initializer = kernel_initializer
    self._bias_regularizer = bias_regularizer
    self._norm_momentum = norm_momentum
    self._norm_epislon = norm_epsilon
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._weight_decay = kernel_regularizer

    self._default_dict = {
        "kernel_initializer": self._kernel_initializer,
        "weight_decay": self._weight_decay,
        "bias_regularizer": self._bias_regularizer,
        "norm_momentum": self._norm_momentum,
        "norm_epsilon": self._norm_epislon,
        "use_sync_bn": self._use_sync_bn,
        "activation": self._activation,
        "name": None
    }

    inputs = ks.layers.Input(shape=self._input_shape.shape[1:])
    output = self._build_struct(layer_specs, inputs)
    super().__init__(inputs=inputs, outputs=output, name=self._model_name)
    return

  @property
  def input_specs(self):
    return self._input_shape

  @property
  def output_specs(self):
    return self._output_specs

  @property
  def splits(self):
    return self._splits

  def _build_struct(self, net, inputs):
    endpoints = collections.OrderedDict()
    stack_outputs = [inputs]
    for i, config in enumerate(net):
      if config.stack is None:
        x = self._build_block(stack_outputs[config.route],
                              config,
                              name=f"{config.layer}_{i}")
        stack_outputs.append(x)
      elif config.stack == "residual":
        x = self._residual_stack(stack_outputs[config.route],
                                 config,
                                 name=f"{config.layer}_{i}")
        stack_outputs.append(x)
      elif config.stack == "csp":
        x = self._csp_stack(stack_outputs[config.route],
                            config,
                            name=f"{config.layer}_{i}")
        stack_outputs.append(x)
      elif config.stack == "csp_tiny":
        x_pass, x = self._tiny_stack(stack_outputs[config.route],
                                     config,
                                     name=f"{config.layer}_{i}")
        stack_outputs.append(x_pass)
      if (config.is_output and self._min_size is None):
        # or isinstance(config.output_name, str):
        endpoints[config.output_name] = x
      elif (self._min_size is not None and
            config.output_name >= self._min_size and
            config.output_name <= self._max_size):
        endpoints[config.output_name] = x

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints.keys()}
    return endpoints

  def _get_activation(self, activation):
    if self._activation is None:
      return activation
    else:
      return self._activation

  def _csp_stack(self, inputs, config, name):
    if config.bottleneck:
      csp_filter_reduce = 1
      residual_filter_reduce = 2
      scale_filters = 1
    else:
      csp_filter_reduce = 2
      residual_filter_reduce = 1
      scale_filters = 2
    self._default_dict["activation"] = self._get_activation(config.activation)
    self._default_dict["name"] = f"{name}_csp_down"
    x, x_route = nn_blocks.CSPDownSample(filters=config.filters,
                                         filter_reduce=csp_filter_reduce,
                                         **self._default_dict)(inputs)
    for i in range(config.repetitions):
      self._default_dict["name"] = f"{name}_{i}"
      x = nn_blocks.DarkResidual(filters=config.filters // scale_filters,
                                 filter_scale=residual_filter_reduce,
                                 **self._default_dict)(x)
    self._default_dict["name"] = f"{name}_csp_connect"
    output = nn_blocks.CSPConnect(filters=config.filters,
                                  filter_reduce=csp_filter_reduce,
                                  **self._default_dict)([x, x_route])
    self._default_dict["activation"] = self._activation
    self._default_dict["name"] = None
    return output

  def _tiny_stack(self, inputs, config, name):
    self._default_dict["activation"] = self._get_activation(config.activation)
    self._default_dict["name"] = f"{name}_tiny"
    x, x_route = nn_blocks.CSPTiny(filters=config.filters,
                                   **self._default_dict)(inputs)
    self._default_dict["activation"] = self._activation
    self._default_dict["name"] = None
    return x, x_route

  def _residual_stack(self, inputs, config, name):
    self._default_dict["activation"] = self._get_activation(config.activation)
    self._default_dict["name"] = f"{name}_residual_down"
    x = nn_blocks.DarkResidual(filters=config.filters,
                               downsample=True,
                               **self._default_dict)(inputs)
    for i in range(config.repetitions - 1):
      self._default_dict["name"] = f"{name}_{i}"
      x = nn_blocks.DarkResidual(filters=config.filters,
                                 **self._default_dict)(x)
    self._default_dict["activation"] = self._activation
    self._default_dict["name"] = None
    return x

  def _build_block(self, inputs, config, name):
    x = inputs
    i = 0
    self._default_dict["activation"] = self._get_activation(config.activation)
    while i < config.repetitions:
      self._default_dict["name"] = f"{name}_{i}"
      layer = self._registry(config, self._default_dict)
      x = layer(x)
      i += 1
    self._default_dict["activation"] = self._activation
    self._default_dict["name"] = None
    return x

  @staticmethod
  def get_model_config(name):
    name = name.lower()
    backbone = BACKBONES[name]["backbone"]
    splits = BACKBONES[name]["splits"]
    return csp_build_block_specs(backbone), splits


@factory.register_backbone_builder('darknet')
def build_darknet(
    input_specs: tf.keras.layers.InputSpec,
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  backbone_type = model_config.backbone.type
  backbone_cfg = model_config.backbone.get()
  norm_activation_config = model_config.norm_activation
  return Darknet(model_id=backbone_cfg.model_id,
                 input_shape=input_specs,
                 activation=norm_activation_config.activation,
                 use_sync_bn=norm_activation_config.use_sync_bn,
                 norm_momentum=norm_activation_config.norm_momentum,
                 norm_epsilon=norm_activation_config.norm_epsilon,
                 kernel_regularizer=l2_regularizer)


# if __name__ == "__main__":
#   from yolo.configs import backbones
#   from official.core import registry
#   model = backbones.Backbone(type="darknet", darknet=backbones.DarkNet(model_id="darknet53"))
#   cfg = temp(model)
#   model = factory.build_backbone(tf.keras.layers.InputSpec(shape=[None, 416, 416, 3]), cfg, None)
#   print(model.output_specs)
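
A minimal direct-instantiation sketch of the backbone (the commented-out block above shows the factory path instead); the output levels assume the darknet53 spec, where levels 3 through 5 are marked is_output:

import tensorflow as tf
from official.vision.beta.projects.yolo.modeling.backbones.Darknet import Darknet

model = Darknet(model_id="darknet53",
                input_shape=tf.keras.layers.InputSpec(shape=[None, 416, 416, 3]))
# output_specs maps each output level name (3, 4, 5) to its feature-map
# shape; level 5 is the deepest, 1024-channel map.
print(model.output_specs)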
official/vision/beta/projects/yolo/modeling/backbones/__init__.py  (new file, mode 100644, empty)
official/vision/beta/projects/yolo/modeling/building_blocks/_CSPConnect.py  (new file, mode 100644)

import tensorflow as tf
import tensorflow.keras as ks

from ._DarkConv import DarkConv


@ks.utils.register_keras_serializable(package='yolo')
class CSPConnect(ks.layers.Layer):

  def __init__(self,
               filters,
               filter_reduce=2,
               activation="mish",
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               bias_regularizer=None,
               weight_decay=None,  # default: find where it is stated
               use_bn=True,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    super().__init__(**kwargs)
    # layer params
    self._filters = filters
    self._filter_reduce = filter_reduce
    self._activation = activation
    # convolution params
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._weight_decay = weight_decay
    self._bias_regularizer = bias_regularizer
    self._use_bn = use_bn
    self._use_sync_bn = use_sync_bn
    self._norm_moment = norm_momentum
    self._norm_epsilon = norm_epsilon
    return

  def build(self, input_shape):
    self._conv1 = DarkConv(filters=self._filters // self._filter_reduce,
                           kernel_size=(1, 1),
                           strides=(1, 1),
                           kernel_initializer=self._kernel_initializer,
                           bias_initializer=self._bias_initializer,
                           bias_regularizer=self._bias_regularizer,
                           weight_decay=self._weight_decay,
                           use_bn=self._use_bn,
                           use_sync_bn=self._use_sync_bn,
                           norm_momentum=self._norm_moment,
                           norm_epsilon=self._norm_epsilon,
                           activation=self._activation)
    self._concat = ks.layers.Concatenate(axis=-1)
    self._conv2 = DarkConv(filters=self._filters,
                           kernel_size=(1, 1),
                           strides=(1, 1),
                           kernel_initializer=self._kernel_initializer,
                           bias_initializer=self._bias_initializer,
                           bias_regularizer=self._bias_regularizer,
                           weight_decay=self._weight_decay,
                           use_bn=self._use_bn,
                           use_sync_bn=self._use_sync_bn,
                           norm_momentum=self._norm_moment,
                           norm_epsilon=self._norm_epsilon,
                           activation=self._activation)
    return

  def call(self, inputs):
    x_prev, x_csp = inputs
    x = self._conv1(x_prev)
    x = self._concat([x, x_csp])
    x = self._conv2(x)
    return x
official/vision/beta/projects/yolo/modeling/building_blocks/_CSPDownSample.py  (new file, mode 100644)

import tensorflow as tf
import tensorflow.keras as ks

from ._DarkConv import DarkConv


@ks.utils.register_keras_serializable(package='yolo')
class CSPDownSample(ks.layers.Layer):

  def __init__(self,
               filters,
               filter_reduce=2,
               activation="mish",
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               bias_regularizer=None,
               weight_decay=None,  # default: find where it is stated
               use_bn=True,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    super().__init__(**kwargs)
    # layer params
    self._filters = filters
    self._filter_reduce = filter_reduce
    self._activation = activation
    # convolution params
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._weight_decay = weight_decay
    self._bias_regularizer = bias_regularizer
    self._use_bn = use_bn
    self._use_sync_bn = use_sync_bn
    self._norm_moment = norm_momentum
    self._norm_epsilon = norm_epsilon
    return

  def build(self, input_shape):
    self._conv1 = DarkConv(filters=self._filters,
                           kernel_size=(3, 3),
                           strides=(2, 2),
                           kernel_initializer=self._kernel_initializer,
                           bias_initializer=self._bias_initializer,
                           bias_regularizer=self._bias_regularizer,
                           weight_decay=self._weight_decay,
                           use_bn=self._use_bn,
                           use_sync_bn=self._use_sync_bn,
                           norm_momentum=self._norm_moment,
                           norm_epsilon=self._norm_epsilon,
                           activation=self._activation)
    self._conv2 = DarkConv(filters=self._filters // self._filter_reduce,
                           kernel_size=(1, 1),
                           strides=(1, 1),
                           kernel_initializer=self._kernel_initializer,
                           bias_initializer=self._bias_initializer,
                           bias_regularizer=self._bias_regularizer,
                           weight_decay=self._weight_decay,
                           use_bn=self._use_bn,
                           use_sync_bn=self._use_sync_bn,
                           norm_momentum=self._norm_moment,
                           norm_epsilon=self._norm_epsilon,
                           activation=self._activation)
    self._conv3 = DarkConv(filters=self._filters // self._filter_reduce,
                           kernel_size=(1, 1),
                           strides=(1, 1),
                           kernel_initializer=self._kernel_initializer,
                           bias_initializer=self._bias_initializer,
                           bias_regularizer=self._bias_regularizer,
                           weight_decay=self._weight_decay,
                           use_bn=self._use_bn,
                           use_sync_bn=self._use_sync_bn,
                           norm_momentum=self._norm_moment,
                           norm_epsilon=self._norm_epsilon,
                           activation=self._activation)
    return

  def call(self, inputs):
    x = self._conv1(inputs)
    y = self._conv2(x)
    x = self._conv3(x)
    return (x, y)
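
CSPDownSample and CSPConnect are used as a pair (see _csp_stack in Darknet.py above): the downsample block emits a main path plus a cross-stage route, and the connect block concatenates the route back and fuses it. A small shape check, in the spirit of the unit tests added later in this commit:

import tensorflow.keras as ks
from official.vision.beta.projects.yolo.modeling.building_blocks import (
    CSPDownSample, CSPConnect)

inputs = ks.Input(shape=(224, 224, 64))
# stride-2 3x3 conv, then two 1x1 convs produce the main path and the route
x, x_route = CSPDownSample(filters=128, filter_reduce=2)(inputs)
# 1x1 conv on the main path, concat with the route, 1x1 fuse back to filters
out = CSPConnect(filters=128, filter_reduce=2)([x, x_route])
print(out.shape)  # (None, 112, 112, 128)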
official/vision/beta/projects/yolo/modeling/building_blocks/_CSPTiny.py  (new file, mode 100644)

"""Contains common building blocks for yolo neural networks."""
import tensorflow as tf
import tensorflow.keras as ks

from ._DarkConv import DarkConv


@ks.utils.register_keras_serializable(package='yolo')
class CSPTiny(ks.layers.Layer):

  def __init__(self,
               filters=1,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               bias_regularizer=None,
               weight_decay=None,  # default: find where it is stated
               use_bn=True,
               use_sync_bn=False,
               group_id=1,
               groups=2,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               activation='leaky',
               downsample=True,
               leaky_alpha=0.1,
               **kwargs):
    # darkconv params
    self._filters = filters
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._bias_regularizer = bias_regularizer
    self._use_bn = use_bn
    self._use_sync_bn = use_sync_bn
    self._weight_decay = weight_decay
    self._groups = groups
    self._group_id = group_id
    self._downsample = downsample

    # normal params
    self._norm_moment = norm_momentum
    self._norm_epsilon = norm_epsilon

    # activation params
    self._conv_activation = activation
    self._leaky_alpha = leaky_alpha

    super().__init__(**kwargs)
    return

  def build(self, input_shape):
    self._convlayer1 = DarkConv(filters=self._filters,
                                kernel_size=(3, 3),
                                strides=(1, 1),
                                padding='same',
                                use_bias=self._use_bias,
                                kernel_initializer=self._kernel_initializer,
                                bias_initializer=self._bias_initializer,
                                bias_regularizer=self._bias_regularizer,
                                weight_decay=self._weight_decay,
                                use_bn=self._use_bn,
                                use_sync_bn=self._use_sync_bn,
                                norm_momentum=self._norm_moment,
                                norm_epsilon=self._norm_epsilon,
                                activation=self._conv_activation,
                                leaky_alpha=self._leaky_alpha)
    self._convlayer2 = DarkConv(filters=self._filters // 2,
                                kernel_size=(3, 3),
                                strides=(1, 1),
                                padding='same',
                                use_bias=self._use_bias,
                                kernel_initializer=self._kernel_initializer,
                                bias_initializer=self._bias_initializer,
                                bias_regularizer=self._bias_regularizer,
                                weight_decay=self._weight_decay,
                                use_bn=self._use_bn,
                                use_sync_bn=self._use_sync_bn,
                                norm_momentum=self._norm_moment,
                                norm_epsilon=self._norm_epsilon,
                                activation=self._conv_activation,
                                leaky_alpha=self._leaky_alpha)
    self._convlayer3 = DarkConv(filters=self._filters // 2,
                                kernel_size=(3, 3),
                                strides=(1, 1),
                                padding='same',
                                use_bias=self._use_bias,
                                kernel_initializer=self._kernel_initializer,
                                bias_initializer=self._bias_initializer,
                                bias_regularizer=self._bias_regularizer,
                                weight_decay=self._weight_decay,
                                use_bn=self._use_bn,
                                use_sync_bn=self._use_sync_bn,
                                norm_momentum=self._norm_moment,
                                norm_epsilon=self._norm_epsilon,
                                activation=self._conv_activation,
                                leaky_alpha=self._leaky_alpha)
    self._convlayer4 = DarkConv(filters=self._filters,
                                kernel_size=(1, 1),
                                strides=(1, 1),
                                padding='same',
                                use_bias=self._use_bias,
                                kernel_initializer=self._kernel_initializer,
                                bias_initializer=self._bias_initializer,
                                bias_regularizer=self._bias_regularizer,
                                weight_decay=self._weight_decay,
                                use_bn=self._use_bn,
                                use_sync_bn=self._use_sync_bn,
                                norm_momentum=self._norm_moment,
                                norm_epsilon=self._norm_epsilon,
                                activation=self._conv_activation,
                                leaky_alpha=self._leaky_alpha)
    self._maxpool = tf.keras.layers.MaxPool2D(pool_size=2,
                                              strides=2,
                                              padding="same",
                                              data_format=None)
    super().build(input_shape)
    return

  def call(self, inputs):
    x1 = self._convlayer1(inputs)
    x2 = tf.split(x1, self._groups, axis=-1)
    x3 = self._convlayer2(x2[self._group_id])
    x4 = self._convlayer3(x3)
    x5 = tf.concat([x4, x3], axis=-1)
    x6 = self._convlayer4(x5)
    x = tf.concat([x1, x6], axis=-1)
    if self._downsample:
      x = self._maxpool(x)
    return x, x6

  def get_config(self):
    # used to store/share parameters to reconstruct the model
    layer_config = {
        "filters": self._filters,
        "use_bias": self._use_bias,
        "kernel_initializer": self._kernel_initializer,
        "bias_initializer": self._bias_initializer,
        "weight_decay": self._weight_decay,
        "use_bn": self._use_bn,
        "use_sync_bn": self._use_sync_bn,
        "group_id": self._group_id,
        "groups": self._groups,
        "norm_moment": self._norm_moment,
        "norm_epsilon": self._norm_epsilon,
        "activation": self._conv_activation,
        "downsample": self._downsample,
        "leaky_alpha": self._leaky_alpha,
    }
    layer_config.update(super().get_config())
    return layer_config
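
CSPTiny runs a 3x3 conv, splits its output into groups, pushes one group through two more 3x3 convs, re-concatenates twice, and optionally max-pools; it returns both the merged tensor and the inner branch used as the tiny CSP route. A shape sketch:

import tensorflow.keras as ks
from official.vision.beta.projects.yolo.modeling.building_blocks import CSPTiny

inputs = ks.Input(shape=(224, 224, 3))
x, x_route = CSPTiny(filters=64, downsample=True)(inputs)
# the merged path doubles the channels (64 -> 128) and is pooled to half
# resolution; the route keeps the pre-pool resolution
print(x.shape, x_route.shape)  # (None, 112, 112, 128) (None, 224, 224, 64)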
official/vision/beta/projects/yolo/modeling/building_blocks/_DarkConv.py  (new file, mode 100644)

"""Contains common building blocks for yolo neural networks."""
from functools import partial

import tensorflow as tf
import tensorflow.keras as ks
import tensorflow.keras.backend as K

from ._Identity import Identity
from yolo.modeling.functions.mish_activation import mish


@ks.utils.register_keras_serializable(package='yolo')
class DarkConv(ks.layers.Layer):

  def __init__(self,
               filters=1,
               kernel_size=(1, 1),
               strides=(1, 1),
               padding='same',
               dilation_rate=(1, 1),
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               bias_regularizer=None,
               weight_decay=None,  # default: find where it is stated
               use_bn=True,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               activation='leaky',
               leaky_alpha=0.1,
               **kwargs):
    '''
    Modified convolution layer to match that of the DarkNet library.

    Args:
      filters: integer for output depth, or the number of features to learn
      kernel_size: integer or tuple for the shape of the weight matrix or kernel to learn
      strides: integer or tuple for how much to move the kernel after each kernel use
      padding: string 'valid' or 'same'; if 'same', pad the image, else do not
      dilation_rate: tuple to indicate how much to modulate kernel weights and
        how many pixels in a feature map to skip
      use_bias: boolean to indicate whether to use bias in the convolution layer
      kernel_initializer: string to indicate which function to use to initialize weights
      bias_initializer: string to indicate which function to use to initialize bias
      l2_regularization: float to use as a constant for weight regularization
      use_bn: boolean for whether to use batch normalization
      use_sync_bn: boolean for whether to sync batch normalization statistics
        of all batch norm layers to the model's global statistics (across all input batches)
      norm_moment: float for the momentum to use for batch norm
      norm_epsilon: float for the batch norm epsilon
      activation: string or None for the activation function to use in the layer;
        if None, activation is replaced by linear
      leaky_alpha: float to use as alpha if the activation function is leaky
      **kwargs: Keyword Arguments
    '''
    # convolution params
    self._filters = filters
    self._kernel_size = kernel_size
    self._strides = strides
    self._padding = padding
    self._dilation_rate = dilation_rate
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._weight_decay = weight_decay
    self._bias_regularizer = bias_regularizer

    # batchnorm params
    self._use_bn = use_bn
    if self._use_bn:
      self._use_bias = False
    self._use_sync_bn = use_sync_bn
    self._norm_moment = norm_momentum
    self._norm_epsilon = norm_epsilon

    if tf.keras.backend.image_data_format() == 'channels_last':
      # format: (batch_size, height, width, channels)
      self._bn_axis = -1
    else:
      # format: (batch_size, channels, width, height)
      self._bn_axis = 1

    # activation params
    if activation is None:
      self._activation = 'linear'
    else:
      self._activation = activation
    self._leaky_alpha = leaky_alpha

    super(DarkConv, self).__init__(**kwargs)
    return

  def build(self, input_shape):
    kernel_size = self._kernel_size if type(
        self._kernel_size) == int else self._kernel_size[0]
    if self._padding == "same" and kernel_size != 1:
      self._zeropad = ks.layers.ZeroPadding2D(((1, 1), (1, 1)))  # symmetric padding
    else:
      self._zeropad = Identity()

    self.conv = ks.layers.Conv2D(filters=self._filters,
                                 kernel_size=self._kernel_size,
                                 strides=self._strides,
                                 padding="valid",  # self._padding,
                                 dilation_rate=self._dilation_rate,
                                 use_bias=self._use_bias,
                                 kernel_initializer=self._kernel_initializer,
                                 bias_initializer=self._bias_initializer,
                                 kernel_regularizer=self._weight_decay,
                                 bias_regularizer=self._bias_regularizer)
    # self.conv = tf.nn.convolution(filters=self._filters, strides=self._strides, padding=self._padding

    if self._use_bn:
      if self._use_sync_bn:
        self.bn = tf.keras.layers.experimental.SyncBatchNormalization(
            momentum=self._norm_moment,
            epsilon=self._norm_epsilon,
            axis=self._bn_axis)
      else:
        self.bn = ks.layers.BatchNormalization(momentum=self._norm_moment,
                                               epsilon=self._norm_epsilon,
                                               axis=self._bn_axis)
    else:
      self.bn = Identity()

    if self._activation == 'leaky':
      alpha = {"alpha": self._leaky_alpha}
      self._activation_fn = partial(tf.nn.leaky_relu, **alpha)
    elif self._activation == 'mish':
      self._activation_fn = mish()
    else:
      self._activation_fn = ks.layers.Activation(activation=self._activation)

    super(DarkConv, self).build(input_shape)
    return

  def call(self, inputs):
    x = self._zeropad(inputs)
    x = self.conv(x)
    x = self.bn(x)
    x = self._activation_fn(x)
    return x

  def get_config(self):
    # used to store/share parameters to reconstruct the model
    layer_config = {
        "filters": self._filters,
        "kernel_size": self._kernel_size,
        "strides": self._strides,
        "padding": self._padding,
        "dilation_rate": self._dilation_rate,
        "use_bias": self._use_bias,
        "kernel_initializer": self._kernel_initializer,
        "bias_initializer": self._bias_initializer,
        "bias_regularizer": self._bias_regularizer,
        "weight_decay": self._weight_decay,
        "use_bn": self._use_bn,
        "use_sync_bn": self._use_sync_bn,
        "norm_moment": self._norm_moment,
        "norm_epsilon": self._norm_epsilon,
        "activation": self._activation,
        "leaky_alpha": self._leaky_alpha
    }
    layer_config.update(super(DarkConv, self).get_config())
    return layer_config

  def __repr__(self):
    return repr(self.get_config())
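
DarkConv bundles pad -> conv -> (optionally synchronized) batch norm -> activation. With 'same' padding and a kernel larger than 1 it applies an explicit symmetric zero pad and then a 'valid' convolution, reproducing the DarkNet library's padding arithmetic. A quick check:

import tensorflow.keras as ks
from official.vision.beta.projects.yolo.modeling.building_blocks import DarkConv

x = ks.Input(shape=(224, 224, 3))
y = DarkConv(filters=64, kernel_size=(3, 3), strides=(2, 2), padding="same")(x)
# zero-pad 224 -> 226, then a valid 3x3 stride-2 conv: floor((226-3)/2)+1 = 112
print(y.shape)  # (None, 112, 112, 64)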
official/vision/beta/projects/yolo/modeling/building_blocks/_DarkResidual.py  (new file, mode 100644)

"""Contains common building blocks for yolo neural networks."""
import tensorflow as tf
import tensorflow.keras as ks

from ._DarkConv import DarkConv
from ._Identity import Identity


@ks.utils.register_keras_serializable(package='yolo')
class DarkResidual(ks.layers.Layer):

  def __init__(self,
               filters=1,
               filter_scale=2,
               use_bias=True,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               weight_decay=None,
               bias_regularizer=None,
               use_bn=True,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               activation='leaky',
               leaky_alpha=0.1,
               sc_activation='linear',
               downsample=False,
               **kwargs):
    '''
    DarkNet block with residual connection for the Yolo v3 backbone.

    Args:
      filters: integer for output depth, or the number of features to learn
      use_bias: boolean to indicate whether to use bias in the convolution layer
      kernel_initializer: string to indicate which function to use to initialize weights
      bias_initializer: string to indicate which function to use to initialize bias
      use_bn: boolean for whether to use batch normalization
      use_sync_bn: boolean for whether to sync batch normalization statistics
        of all batch norm layers to the model's global statistics (across all input batches)
      norm_moment: float for the momentum to use for batch norm
      norm_epsilon: float for the batch norm epsilon
      conv_activation: string or None for the activation function to use in the layer;
        if None, activation is replaced by linear
      leaky_alpha: float to use as alpha if the activation function is leaky
      sc_activation: string for the activation function to use on the shortcut
      downsample: boolean; if the image input is larger than the layer output, set
        downsample to True so the dimensions are forced to match
      **kwargs: Keyword Arguments
    '''
    # downsample
    self._downsample = downsample

    # darkconv params
    self._filters = filters
    self._filter_scale = filter_scale
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._bias_regularizer = bias_regularizer
    self._use_bn = use_bn
    self._use_sync_bn = use_sync_bn
    self._weight_decay = weight_decay

    # normal params
    self._norm_moment = norm_momentum
    self._norm_epsilon = norm_epsilon

    # activation params
    self._conv_activation = activation
    self._leaky_alpha = leaky_alpha
    self._sc_activation = sc_activation

    super().__init__(**kwargs)
    return

  def build(self, input_shape):
    if self._downsample:
      self._dconv = DarkConv(filters=self._filters,
                             kernel_size=(3, 3),
                             strides=(2, 2),
                             padding='same',
                             use_bias=self._use_bias,
                             kernel_initializer=self._kernel_initializer,
                             bias_initializer=self._bias_initializer,
                             bias_regularizer=self._bias_regularizer,
                             use_bn=self._use_bn,
                             use_sync_bn=self._use_sync_bn,
                             norm_momentum=self._norm_moment,
                             norm_epsilon=self._norm_epsilon,
                             activation=self._conv_activation,
                             weight_decay=self._weight_decay,
                             leaky_alpha=self._leaky_alpha)
    else:
      self._dconv = Identity()

    self._conv1 = DarkConv(filters=self._filters // self._filter_scale,
                           kernel_size=(1, 1),
                           strides=(1, 1),
                           padding='same',
                           use_bias=self._use_bias,
                           kernel_initializer=self._kernel_initializer,
                           bias_initializer=self._bias_initializer,
                           bias_regularizer=self._bias_regularizer,
                           use_bn=self._use_bn,
                           use_sync_bn=self._use_sync_bn,
                           norm_momentum=self._norm_moment,
                           norm_epsilon=self._norm_epsilon,
                           activation=self._conv_activation,
                           weight_decay=self._weight_decay,
                           leaky_alpha=self._leaky_alpha)
    self._conv2 = DarkConv(filters=self._filters,
                           kernel_size=(3, 3),
                           strides=(1, 1),
                           padding='same',
                           use_bias=self._use_bias,
                           kernel_initializer=self._kernel_initializer,
                           bias_initializer=self._bias_initializer,
                           bias_regularizer=self._bias_regularizer,
                           use_bn=self._use_bn,
                           use_sync_bn=self._use_sync_bn,
                           norm_momentum=self._norm_moment,
                           norm_epsilon=self._norm_epsilon,
                           activation=self._conv_activation,
                           weight_decay=self._weight_decay,
                           leaky_alpha=self._leaky_alpha)

    self._shortcut = ks.layers.Add()
    self._activation_fn = ks.layers.Activation(activation=self._sc_activation)

    super().build(input_shape)
    return

  def call(self, inputs):
    shortcut = self._dconv(inputs)
    x = self._conv1(shortcut)
    x = self._conv2(x)
    x = self._shortcut([x, shortcut])
    return self._activation_fn(x)

  def get_config(self):
    # used to store/share parameters to reconstruct the model
    layer_config = {
        "filters": self._filters,
        "use_bias": self._use_bias,
        "kernel_initializer": self._kernel_initializer,
        "bias_initializer": self._bias_initializer,
        "weight_decay": self._weight_decay,
        "use_bn": self._use_bn,
        "use_sync_bn": self._use_sync_bn,
        "norm_moment": self._norm_moment,
        "norm_epsilon": self._norm_epsilon,
        "activation": self._conv_activation,
        "leaky_alpha": self._leaky_alpha,
        "sc_activation": self._sc_activation,
        "downsample": self._downsample
    }
    layer_config.update(super().get_config())
    return layer_config
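
DarkResidual is the standard DarkNet bottleneck: an optional stride-2 downsampling conv, a 1x1 reduction, a 3x3 expansion, and an additive shortcut. For example:

import tensorflow.keras as ks
from official.vision.beta.projects.yolo.modeling.building_blocks import DarkResidual

x = ks.Input(shape=(224, 224, 64))
y = DarkResidual(filters=64)(x)                    # identity shortcut
z = DarkResidual(filters=128, downsample=True)(x)  # halves width and height first
print(y.shape, z.shape)  # (None, 224, 224, 64) (None, 112, 112, 128)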
official/vision/beta/projects/yolo/modeling/building_blocks/_DarkTiny.py  (new file, mode 100644)

"""Contains common building blocks for yolo neural networks."""
import tensorflow as tf
import tensorflow.keras as ks

from ._DarkConv import DarkConv


@ks.utils.register_keras_serializable(package='yolo')
class DarkTiny(ks.layers.Layer):

  def __init__(self,
               filters=1,
               use_bias=True,
               strides=2,
               kernel_initializer='glorot_uniform',
               bias_initializer='zeros',
               bias_regularizer=None,
               weight_decay=None,  # default: find where it is stated
               use_bn=True,
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               activation='leaky',
               leaky_alpha=0.1,
               sc_activation='linear',
               **kwargs):
    # darkconv params
    self._filters = filters
    self._use_bias = use_bias
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._bias_regularizer = bias_regularizer
    self._use_bn = use_bn
    self._use_sync_bn = use_sync_bn
    self._strides = strides
    self._weight_decay = weight_decay

    # normal params
    self._norm_moment = norm_momentum
    self._norm_epsilon = norm_epsilon

    # activation params
    self._conv_activation = activation
    self._leaky_alpha = leaky_alpha
    self._sc_activation = sc_activation

    super().__init__(**kwargs)
    return

  def build(self, input_shape):
    # if self._strides == 2:
    #   self._zeropad = ks.layers.ZeroPadding2D(((1,0), (1,0)))
    #   padding = "valid"
    # else:
    #   self._zeropad = ks.layers.ZeroPadding2D(((0,1), (0,1)))  #nn_blocks.Identity()#ks.layers.ZeroPadding2D(((1,0), (1,0)))
    #   padding = "valid"
    self._maxpool = tf.keras.layers.MaxPool2D(pool_size=2,
                                              strides=self._strides,
                                              padding="same",
                                              data_format=None)
    self._convlayer = DarkConv(filters=self._filters,
                               kernel_size=(3, 3),
                               strides=(1, 1),
                               padding='same',
                               use_bias=self._use_bias,
                               kernel_initializer=self._kernel_initializer,
                               bias_initializer=self._bias_initializer,
                               bias_regularizer=self._bias_regularizer,
                               weight_decay=self._weight_decay,
                               use_bn=self._use_bn,
                               use_sync_bn=self._use_sync_bn,
                               norm_momentum=self._norm_moment,
                               norm_epsilon=self._norm_epsilon,
                               activation=self._conv_activation,
                               leaky_alpha=self._leaky_alpha)
    super().build(input_shape)
    return

  def call(self, inputs):
    output = self._maxpool(inputs)
    output = self._convlayer(output)
    return output

  def get_config(self):
    # used to store/share parameters to reconstruct the model
    layer_config = {
        "filters": self._filters,
        "use_bias": self._use_bias,
        "strides": self._strides,
        "kernel_initializer": self._kernel_initializer,
        "bias_initializer": self._bias_initializer,
        "weight_decay": self._weight_decay,
        "use_bn": self._use_bn,
        "use_sync_bn": self._use_sync_bn,
        "norm_moment": self._norm_moment,
        "norm_epsilon": self._norm_epsilon,
        "activation": self._conv_activation,
        "leaky_alpha": self._leaky_alpha,
        "sc_activation": self._sc_activation,
    }
    layer_config.update(super().get_config())
    return layer_config
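
DarkTiny is the plain tiny-YOLO stage: a 2x2 max pool at the layer's stride followed by a single 3x3 DarkConv. For example:

import tensorflow.keras as ks
from official.vision.beta.projects.yolo.modeling.building_blocks import DarkTiny

x = ks.Input(shape=(224, 224, 16))
y = DarkTiny(filters=32, strides=2)(x)
print(y.shape)  # (None, 112, 112, 32)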
official/vision/beta/projects/yolo/modeling/building_blocks/__init__.py  (new file, mode 100644)

from ._DarkConv import DarkConv
from ._DarkResidual import DarkResidual
from ._DarkTiny import DarkTiny
from ._CSPConnect import CSPConnect
from ._CSPDownSample import CSPDownSample
from ._CSPTiny import CSPTiny
official/vision/beta/projects/yolo/modeling/tests/README.md  (new file, mode 100644)

# Unit Tests for Yolo and Darknet models

official/vision/beta/projects/yolo/modeling/tests/__init__.py  (new file, mode 100644, empty)
official/vision/beta/projects/yolo/modeling/tests/test_CSPConnect.py  (new file, mode 100644)

import tensorflow as tf
import tensorflow.keras as ks
import numpy as np

from absl.testing import parameterized

from official.vision.beta.projects.yolo.modeling.building_blocks import CSPDownSample as layer
from official.vision.beta.projects.yolo.modeling.building_blocks import CSPConnect as layer_companion


class CSPConnect(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(("same", 224, 224, 64, 1),
                                  ("downsample", 224, 224, 64, 2))
  def test_pass_through(self, width, height, filters, mod):
    x = ks.Input(shape=(width, height, filters))
    test_layer = layer(filters=filters, filter_reduce=mod)
    test_layer2 = layer_companion(filters=filters, filter_reduce=mod)
    outx, px = test_layer(x)
    outx = test_layer2([outx, px])
    print(outx)
    print(outx.shape.as_list())
    self.assertAllEqual(
        outx.shape.as_list(),
        [None, np.ceil(width // 2), np.ceil(height // 2), (filters)])
    return

  @parameterized.named_parameters(("same", 224, 224, 64, 1),
                                  ("downsample", 224, 224, 128, 2))
  def test_gradient_pass_though(self, filters, width, height, mod):
    loss = ks.losses.MeanSquaredError()
    optimizer = ks.optimizers.SGD()
    test_layer = layer(filters, filter_reduce=mod)
    path_layer = layer_companion(filters, filter_reduce=mod)

    init = tf.random_normal_initializer()
    x = tf.Variable(initial_value=init(shape=(1, width, height, filters),
                                       dtype=tf.float32))
    y = tf.Variable(initial_value=init(shape=(1, int(np.ceil(width // 2)),
                                              int(np.ceil(height // 2)),
                                              filters),
                                       dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat, x_prev = test_layer(x)
      x_hat = path_layer([x_hat, x_prev])
      grad_loss = loss(x_hat, y)
    grad = tape.gradient(grad_loss, test_layer.trainable_variables)
    optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))

    self.assertNotIn(None, grad)
    return


if __name__ == "__main__":
  tf.test.main()
official/vision/beta/projects/yolo/modeling/tests/test_CSPDownSample.py  (new file, mode 100644)

import tensorflow as tf
import tensorflow.keras as ks
import numpy as np

from absl.testing import parameterized

from official.vision.beta.projects.yolo.modeling.building_blocks import CSPDownSample as layer
from official.vision.beta.projects.yolo.modeling.building_blocks import CSPConnect as layer_companion


class CSPDownSample(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(("same", 224, 224, 64, 1),
                                  ("downsample", 224, 224, 64, 2))
  def test_pass_through(self, width, height, filters, mod):
    x = ks.Input(shape=(width, height, filters))
    test_layer = layer(filters=filters, filter_reduce=mod)
    outx, px = test_layer(x)
    print(outx)
    print(outx.shape.as_list())
    self.assertAllEqual(
        outx.shape.as_list(),
        [None, np.ceil(width // 2), np.ceil(height // 2), (filters / mod)])
    return

  @parameterized.named_parameters(("same", 224, 224, 64, 1),
                                  ("downsample", 224, 224, 128, 2))
  def test_gradient_pass_though(self, filters, width, height, mod):
    loss = ks.losses.MeanSquaredError()
    optimizer = ks.optimizers.SGD()
    test_layer = layer(filters, filter_reduce=mod)
    path_layer = layer_companion(filters, filter_reduce=mod)

    init = tf.random_normal_initializer()
    x = tf.Variable(initial_value=init(shape=(1, width, height, filters),
                                       dtype=tf.float32))
    y = tf.Variable(initial_value=init(shape=(1, int(np.ceil(width // 2)),
                                              int(np.ceil(height // 2)),
                                              filters),
                                       dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat, x_prev = test_layer(x)
      x_hat = path_layer([x_hat, x_prev])
      grad_loss = loss(x_hat, y)
    grad = tape.gradient(grad_loss, test_layer.trainable_variables)
    optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))

    self.assertNotIn(None, grad)
    return


if __name__ == "__main__":
  tf.test.main()
official/vision/beta/projects/yolo/modeling/tests/test_DarkConv.py  (new file, mode 100644)

import tensorflow as tf
import tensorflow.keras as ks
import tensorflow_datasets as tfds

from absl.testing import parameterized

from official.vision.beta.projects.yolo.modeling.building_blocks import DarkConv


class DarkConvTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      ("valid", (3, 3), "valid", (1, 1)),
      ("same", (3, 3), "same", (1, 1)),
      ("downsample", (3, 3), "same", (2, 2)),
      ("test", (1, 1), "valid", (1, 1)))
  def test_pass_through(self, kernel_size, padding, strides):
    if padding == "same":
      pad_const = 1
    else:
      pad_const = 0
    x = ks.Input(shape=(224, 224, 3))
    test_layer = DarkConv(filters=64,
                          kernel_size=kernel_size,
                          padding=padding,
                          strides=strides,
                          trainable=False)
    outx = test_layer(x)
    print(outx.shape.as_list())
    test = [
        None,
        int((224 - kernel_size[0] + (2 * pad_const)) / strides[0] + 1),
        int((224 - kernel_size[1] + (2 * pad_const)) / strides[1] + 1),
        64
    ]
    print(test)
    self.assertAllEqual(outx.shape.as_list(), test)
    return

  @parameterized.named_parameters(("filters", 3))
  def test_gradient_pass_though(self, filters):
    loss = ks.losses.MeanSquaredError()
    optimizer = ks.optimizers.SGD()
    with tf.device("/CPU:0"):
      test_layer = DarkConv(filters, kernel_size=(3, 3), padding="same")
      init = tf.random_normal_initializer()
      x = tf.Variable(initial_value=init(shape=(1, 224, 224, 3),
                                         dtype=tf.float32))
      y = tf.Variable(initial_value=init(shape=(1, 224, 224, filters),
                                         dtype=tf.float32))
      with tf.GradientTape() as tape:
        x_hat = test_layer(x)
        grad_loss = loss(x_hat, y)
      grad = tape.gradient(grad_loss, test_layer.trainable_variables)
      optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))
    self.assertNotIn(None, grad)
    return

  # @parameterized.named_parameters(("filters", 3), ("filters", 20), ("filters", 512))
  # def test_time(self, filters):
  #   # finish the test for time
  #   dataset = tfds.load("mnist")
  #   model = ks.Sequential([
  #       DarkConv(7, kernel_size=(3,3), strides = (2,2), activation='relu'),
  #       DarkConv(10, kernel_size=(3,3), strides = (2,2), activation='relu'),
  #       DarkConv(filters, kernel_size=(3,3), strides = (1,1), activation='relu'),
  #       DarkConv(9, kernel_size=(3,3), strides = (2,2), activation='relu'),
  #       ks.layers.GlobalAveragePooling2D(),
  #       ks.layers.Dense(10, activation='softmax')], name='test')
  #   return


if __name__ == "__main__":
  tf.test.main()
official/vision/beta/projects/yolo/modeling/tests/test_DarkResidual.py  (new file, mode 100644)

import tensorflow as tf
import tensorflow.keras as ks
import numpy as np

from absl.testing import parameterized

from official.vision.beta.projects.yolo.modeling.building_blocks import DarkResidual as layer


class DarkResidualTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(("same", 224, 224, 64, False),
                                  ("downsample", 223, 223, 32, True),
                                  ("oddball", 223, 223, 32, False))
  def test_pass_through(self, width, height, filters, downsample):
    mod = 1
    if downsample:
      mod = 2
    x = ks.Input(shape=(width, height, filters))
    test_layer = layer(filters=filters, downsample=downsample)
    outx = test_layer(x)
    print(outx)
    print(outx.shape.as_list())
    self.assertAllEqual(
        outx.shape.as_list(),
        [None, np.ceil(width / mod), np.ceil(height / mod), filters])
    return

  @parameterized.named_parameters(("same", 64, 224, 224, False),
                                  ("downsample", 32, 223, 223, True),
                                  ("oddball", 32, 223, 223, False))
  def test_gradient_pass_though(self, filters, width, height, downsample):
    loss = ks.losses.MeanSquaredError()
    optimizer = ks.optimizers.SGD()
    test_layer = layer(filters, downsample=downsample)
    if downsample:
      mod = 2
    else:
      mod = 1

    init = tf.random_normal_initializer()
    x = tf.Variable(initial_value=init(shape=(1, width, height, filters),
                                       dtype=tf.float32))
    y = tf.Variable(initial_value=init(shape=(1, int(np.ceil(width / mod)),
                                              int(np.ceil(height / mod)),
                                              filters),
                                       dtype=tf.float32))

    with tf.GradientTape() as tape:
      x_hat = test_layer(x)
      grad_loss = loss(x_hat, y)
    grad = tape.gradient(grad_loss, test_layer.trainable_variables)
    optimizer.apply_gradients(zip(grad, test_layer.trainable_variables))

    self.assertNotIn(None, grad)
    return


if __name__ == "__main__":
  tf.test.main()