Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
091da63d
Commit
091da63d
authored
May 10, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 373084407
parent
3091fb64
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
130 additions
and
89 deletions
+130
-89
official/vision/beta/configs/backbones.py
official/vision/beta/configs/backbones.py
+4
-0
official/vision/beta/configs/maskrcnn.py
official/vision/beta/configs/maskrcnn.py
+9
-2
official/vision/beta/configs/retinanet.py
official/vision/beta/configs/retinanet.py
+14
-4
official/vision/beta/modeling/backbones/efficientnet.py
official/vision/beta/modeling/backbones/efficientnet.py
+4
-4
official/vision/beta/modeling/backbones/factory.py
official/vision/beta/modeling/backbones/factory.py
+20
-9
official/vision/beta/modeling/backbones/factory_test.py
official/vision/beta/modeling/backbones/factory_test.py
+10
-16
official/vision/beta/modeling/backbones/mobilenet.py
official/vision/beta/modeling/backbones/mobilenet.py
+4
-4
official/vision/beta/modeling/backbones/resnet.py
official/vision/beta/modeling/backbones/resnet.py
+4
-4
official/vision/beta/modeling/backbones/resnet_3d.py
official/vision/beta/modeling/backbones/resnet_3d.py
+6
-6
official/vision/beta/modeling/backbones/resnet_deeplab.py
official/vision/beta/modeling/backbones/resnet_deeplab.py
+5
-4
official/vision/beta/modeling/backbones/revnet.py
official/vision/beta/modeling/backbones/revnet.py
+5
-4
official/vision/beta/modeling/backbones/spinenet.py
official/vision/beta/modeling/backbones/spinenet.py
+7
-6
official/vision/beta/modeling/backbones/spinenet_mobile.py
official/vision/beta/modeling/backbones/spinenet_mobile.py
+7
-6
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+12
-8
official/vision/beta/modeling/factory_3d.py
official/vision/beta/modeling/factory_3d.py
+3
-1
official/vision/beta/projects/assemblenet/modeling/assemblenet.py
.../vision/beta/projects/assemblenet/modeling/assemblenet.py
+7
-5
official/vision/beta/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py
...ta/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py
+3
-2
official/vision/beta/projects/simclr/tasks/simclr.py
official/vision/beta/projects/simclr/tasks/simclr.py
+2
-1
official/vision/beta/projects/yolo/modeling/backbones/darknet.py
...l/vision/beta/projects/yolo/modeling/backbones/darknet.py
+4
-3
No files found.
official/vision/beta/configs/backbones.py
View file @
091da63d
...
@@ -67,6 +67,8 @@ class SpineNet(hyperparams.Config):
...
@@ -67,6 +67,8 @@ class SpineNet(hyperparams.Config):
"""SpineNet config."""
"""SpineNet config."""
model_id
:
str
=
'49'
model_id
:
str
=
'49'
stochastic_depth_drop_rate
:
float
=
0.0
stochastic_depth_drop_rate
:
float
=
0.0
min_level
:
int
=
3
max_level
:
int
=
7
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -76,6 +78,8 @@ class SpineNetMobile(hyperparams.Config):
...
@@ -76,6 +78,8 @@ class SpineNetMobile(hyperparams.Config):
stochastic_depth_drop_rate
:
float
=
0.0
stochastic_depth_drop_rate
:
float
=
0.0
se_ratio
:
float
=
0.2
se_ratio
:
float
=
0.2
expand_ratio
:
int
=
6
expand_ratio
:
int
=
6
min_level
:
int
=
3
max_level
:
int
=
7
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/configs/maskrcnn.py
View file @
091da63d
...
@@ -437,7 +437,12 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
...
@@ -437,7 +437,12 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
'instances_val2017.json'
),
'instances_val2017.json'
),
model
=
MaskRCNN
(
model
=
MaskRCNN
(
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'spinenet'
,
spinenet
=
backbones
.
SpineNet
(
model_id
=
'49'
)),
type
=
'spinenet'
,
spinenet
=
backbones
.
SpineNet
(
model_id
=
'49'
,
min_level
=
3
,
max_level
=
7
,
)),
decoder
=
decoders
.
Decoder
(
decoder
=
decoders
.
Decoder
(
type
=
'identity'
,
identity
=
decoders
.
Identity
()),
type
=
'identity'
,
identity
=
decoders
.
Identity
()),
anchor
=
Anchor
(
anchor_size
=
3
),
anchor
=
Anchor
(
anchor_size
=
3
),
...
@@ -491,6 +496,8 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
...
@@ -491,6 +496,8 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
})),
})),
restrictions
=
[
restrictions
=
[
'task.train_data.is_training != None'
,
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
'task.validation_data.is_training != None'
,
'task.model.min_level == task,model.backbone.spinenet.min_level'
,
'task.model.max_level == task,model.backbone.spinenet.max_level'
,
])
])
return
config
return
config
official/vision/beta/configs/retinanet.py
View file @
091da63d
...
@@ -248,7 +248,10 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
...
@@ -248,7 +248,10 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'spinenet'
,
type
=
'spinenet'
,
spinenet
=
backbones
.
SpineNet
(
spinenet
=
backbones
.
SpineNet
(
model_id
=
'49'
,
stochastic_depth_drop_rate
=
0.2
)),
model_id
=
'49'
,
stochastic_depth_drop_rate
=
0.2
,
min_level
=
3
,
max_level
=
7
)),
decoder
=
decoders
.
Decoder
(
decoder
=
decoders
.
Decoder
(
type
=
'identity'
,
identity
=
decoders
.
Identity
()),
type
=
'identity'
,
identity
=
decoders
.
Identity
()),
anchor
=
Anchor
(
anchor_size
=
3
),
anchor
=
Anchor
(
anchor_size
=
3
),
...
@@ -306,7 +309,9 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
...
@@ -306,7 +309,9 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
})),
})),
restrictions
=
[
restrictions
=
[
'task.train_data.is_training != None'
,
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
'task.validation_data.is_training != None'
,
'task.model.min_level == task,model.backbone.spinenet.min_level'
,
'task.model.max_level == task,model.backbone.spinenet.max_level'
,
])
])
return
config
return
config
...
@@ -329,7 +334,10 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
...
@@ -329,7 +334,10 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'spinenet_mobile'
,
type
=
'spinenet_mobile'
,
spinenet_mobile
=
backbones
.
SpineNetMobile
(
spinenet_mobile
=
backbones
.
SpineNetMobile
(
model_id
=
'49'
,
stochastic_depth_drop_rate
=
0.2
)),
model_id
=
'49'
,
stochastic_depth_drop_rate
=
0.2
,
min_level
=
3
,
max_level
=
7
)),
decoder
=
decoders
.
Decoder
(
decoder
=
decoders
.
Decoder
(
type
=
'identity'
,
identity
=
decoders
.
Identity
()),
type
=
'identity'
,
identity
=
decoders
.
Identity
()),
head
=
RetinaNetHead
(
num_filters
=
48
,
use_separable_conv
=
True
),
head
=
RetinaNetHead
(
num_filters
=
48
,
use_separable_conv
=
True
),
...
@@ -388,7 +396,9 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
...
@@ -388,7 +396,9 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
})),
})),
restrictions
=
[
restrictions
=
[
'task.train_data.is_training != None'
,
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
'task.validation_data.is_training != None'
,
'task.model.min_level == task,model.backbone.spinenet_mobile.min_level'
,
'task.model.max_level == task,model.backbone.spinenet_mobile.max_level'
,
])
])
return
config
return
config
official/vision/beta/modeling/backbones/efficientnet.py
View file @
091da63d
...
@@ -297,12 +297,12 @@ class EfficientNet(tf.keras.Model):
...
@@ -297,12 +297,12 @@ class EfficientNet(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'efficientnet'
)
@
factory
.
register_backbone_builder
(
'efficientnet'
)
def
build_efficientnet
(
def
build_efficientnet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
hyperparams
.
Config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds EfficientNet backbone from a config."""
"""Builds EfficientNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'efficientnet'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'efficientnet'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
...
...
official/vision/beta/modeling/backbones/factory.py
View file @
091da63d
...
@@ -42,6 +42,8 @@ in place that uses it.
...
@@ -42,6 +42,8 @@ in place that uses it.
"""
"""
from
typing
import
Sequence
,
Union
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -81,22 +83,31 @@ def register_backbone_builder(key: str):
...
@@ -81,22 +83,31 @@ def register_backbone_builder(key: str):
return
registry
.
register
(
_REGISTERED_BACKBONE_CLS
,
key
)
return
registry
.
register
(
_REGISTERED_BACKBONE_CLS
,
key
)
def
build_backbone
(
def
build_backbone
(
input_specs
:
Union
[
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
Sequence
[
tf
.
keras
.
layers
.
InputSpec
]],
model_config
:
hyperparams
.
Config
,
backbone_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
**
kwargs
)
->
tf
.
keras
.
Model
:
"""Builds backbone from a config.
"""Builds backbone from a config.
Args:
Args:
input_specs: A `tf.keras.layers.InputSpec` of input.
input_specs: A (sequence of) `tf.keras.layers.InputSpec` of input.
model_config: A `OneOfConfig` of model config.
backbone_config: A `OneOfConfig` of backbone config.
norm_activation_config: A config for normalization/activation layer.
l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
None.
None.
**kwargs: Additional keyword args to be passed to backbone builder.
Returns:
Returns:
A `tf.keras.Model` instance of the backbone.
A `tf.keras.Model` instance of the backbone.
"""
"""
backbone_builder
=
registry
.
lookup
(
_REGISTERED_BACKBONE_CLS
,
backbone_builder
=
registry
.
lookup
(
_REGISTERED_BACKBONE_CLS
,
model_config
.
backbone
.
type
)
backbone_config
.
type
)
return
backbone_builder
(
input_specs
,
model_config
,
l2_regularizer
)
return
backbone_builder
(
input_specs
=
input_specs
,
backbone_config
=
backbone_config
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
,
**
kwargs
)
official/vision/beta/modeling/backbones/factory_test.py
View file @
091da63d
...
@@ -22,7 +22,6 @@ from tensorflow.python.distribute import combinations
...
@@ -22,7 +22,6 @@ from tensorflow.python.distribute import combinations
from
official.vision.beta.configs
import
backbones
as
backbones_cfg
from
official.vision.beta.configs
import
backbones
as
backbones_cfg
from
official.vision.beta.configs
import
backbones_3d
as
backbones_3d_cfg
from
official.vision.beta.configs
import
backbones_3d
as
backbones_3d_cfg
from
official.vision.beta.configs
import
common
as
common_cfg
from
official.vision.beta.configs
import
common
as
common_cfg
from
official.vision.beta.configs
import
retinanet
as
retinanet_cfg
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
...
@@ -42,12 +41,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -42,12 +41,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
resnet
=
backbones_cfg
.
ResNet
(
model_id
=
model_id
,
se_ratio
=
0.0
))
resnet
=
backbones_cfg
.
ResNet
(
model_id
=
model_id
,
se_ratio
=
0.0
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
factory_network
=
factory
.
build_backbone
(
factory_network
=
factory
.
build_backbone
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
model_config
=
model_config
)
backbone_config
=
backbone_config
,
norm_activation_config
=
norm_activation_config
)
network_config
=
network
.
get_config
()
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
...
@@ -74,12 +72,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -74,12 +72,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
model_id
=
model_id
,
se_ratio
=
se_ratio
))
model_id
=
model_id
,
se_ratio
=
se_ratio
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
factory_network
=
factory
.
build_backbone
(
factory_network
=
factory
.
build_backbone
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
model_config
=
model_config
)
backbone_config
=
backbone_config
,
norm_activation_config
=
norm_activation_config
)
network_config
=
network
.
get_config
()
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
...
@@ -108,12 +105,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -108,12 +105,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
model_id
=
model_id
,
filter_size_scale
=
filter_size_scale
))
model_id
=
model_id
,
filter_size_scale
=
filter_size_scale
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
factory_network
=
factory
.
build_backbone
(
factory_network
=
factory
.
build_backbone
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
model_config
=
model_config
)
backbone_config
=
backbone_config
,
norm_activation_config
=
norm_activation_config
)
network_config
=
network
.
get_config
()
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
...
@@ -141,13 +137,12 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -141,13 +137,12 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
spinenet
=
backbones_cfg
.
SpineNet
(
model_id
=
model_id
))
spinenet
=
backbones_cfg
.
SpineNet
(
model_id
=
model_id
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
factory_network
=
factory
.
build_backbone
(
factory_network
=
factory
.
build_backbone
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
input_size
,
input_size
,
3
]),
shape
=
[
None
,
input_size
,
input_size
,
3
]),
model_config
=
model_config
)
backbone_config
=
backbone_config
,
norm_activation_config
=
norm_activation_config
)
network_config
=
network
.
get_config
()
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
...
@@ -166,12 +161,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -166,12 +161,11 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
revnet
=
backbones_cfg
.
RevNet
(
model_id
=
model_id
))
revnet
=
backbones_cfg
.
RevNet
(
model_id
=
model_id
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
factory_network
=
factory
.
build_backbone
(
factory_network
=
factory
.
build_backbone
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
model_config
=
model_config
)
backbone_config
=
backbone_config
,
norm_activation_config
=
norm_activation_config
)
network_config
=
network
.
get_config
()
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
...
...
official/vision/beta/modeling/backbones/mobilenet.py
View file @
091da63d
...
@@ -766,12 +766,12 @@ class MobileNet(tf.keras.Model):
...
@@ -766,12 +766,12 @@ class MobileNet(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'mobilenet'
)
@
factory
.
register_backbone_builder
(
'mobilenet'
)
def
build_mobilenet
(
def
build_mobilenet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
hyperparams
.
Config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds MobileNet backbone from a config."""
"""Builds MobileNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'mobilenet'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'mobilenet'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
...
...
official/vision/beta/modeling/backbones/resnet.py
View file @
091da63d
...
@@ -372,12 +372,12 @@ class ResNet(tf.keras.Model):
...
@@ -372,12 +372,12 @@ class ResNet(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'resnet'
)
@
factory
.
register_backbone_builder
(
'resnet'
)
def
build_resnet
(
def
build_resnet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
hyperparams
.
Config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds ResNet backbone from a config."""
"""Builds ResNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'resnet'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'resnet'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
...
...
official/vision/beta/modeling/backbones/resnet_3d.py
View file @
091da63d
...
@@ -378,11 +378,11 @@ class ResNet3D(tf.keras.Model):
...
@@ -378,11 +378,11 @@ class ResNet3D(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'resnet_3d'
)
@
factory
.
register_backbone_builder
(
'resnet_3d'
)
def
build_resnet3d
(
def
build_resnet3d
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds ResNet 3d backbone from a config."""
"""Builds ResNet 3d backbone from a config."""
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
# Flatten configs before passing to the backbone.
# Flatten configs before passing to the backbone.
temporal_strides
=
[]
temporal_strides
=
[]
...
@@ -416,11 +416,11 @@ def build_resnet3d(
...
@@ -416,11 +416,11 @@ def build_resnet3d(
@
factory
.
register_backbone_builder
(
'resnet_3d_rs'
)
@
factory
.
register_backbone_builder
(
'resnet_3d_rs'
)
def
build_resnet3d_rs
(
def
build_resnet3d_rs
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
hyperparams
.
Config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds ResNet-3D-RS backbone from a config."""
"""Builds ResNet-3D-RS backbone from a config."""
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
# Flatten configs before passing to the backbone.
# Flatten configs before passing to the backbone.
temporal_strides
=
[]
temporal_strides
=
[]
...
...
official/vision/beta/modeling/backbones/resnet_deeplab.py
View file @
091da63d
...
@@ -18,6 +18,7 @@ from typing import Callable, Optional, Tuple, List
...
@@ -18,6 +18,7 @@ from typing import Callable, Optional, Tuple, List
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_blocks
...
@@ -340,12 +341,12 @@ class DilatedResNet(tf.keras.Model):
...
@@ -340,12 +341,12 @@ class DilatedResNet(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'dilated_resnet'
)
@
factory
.
register_backbone_builder
(
'dilated_resnet'
)
def
build_dilated_resnet
(
def
build_dilated_resnet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds ResNet backbone from a config."""
"""Builds ResNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'dilated_resnet'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'dilated_resnet'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
...
...
official/vision/beta/modeling/backbones/revnet.py
View file @
091da63d
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
from
typing
import
Any
,
Callable
,
Dict
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
Optional
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_blocks
...
@@ -213,12 +214,12 @@ class RevNet(tf.keras.Model):
...
@@ -213,12 +214,12 @@ class RevNet(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'revnet'
)
@
factory
.
register_backbone_builder
(
'revnet'
)
def
build_revnet
(
def
build_revnet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds RevNet backbone from a config."""
"""Builds RevNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'revnet'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'revnet'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
...
...
official/vision/beta/modeling/backbones/spinenet.py
View file @
091da63d
...
@@ -22,6 +22,7 @@ from typing import Any, List, Optional, Tuple
...
@@ -22,6 +22,7 @@ from typing import Any, List, Optional, Tuple
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_blocks
...
@@ -527,12 +528,12 @@ class SpineNet(tf.keras.Model):
...
@@ -527,12 +528,12 @@ class SpineNet(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'spinenet'
)
@
factory
.
register_backbone_builder
(
'spinenet'
)
def
build_spinenet
(
def
build_spinenet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds SpineNet backbone from a config."""
"""Builds SpineNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'spinenet'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'spinenet'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
...
@@ -544,8 +545,8 @@ def build_spinenet(
...
@@ -544,8 +545,8 @@ def build_spinenet(
return
SpineNet
(
return
SpineNet
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
min_level
=
model_confi
g
.
min_level
,
min_level
=
backbone_cf
g
.
min_level
,
max_level
=
model_confi
g
.
max_level
,
max_level
=
backbone_cf
g
.
max_level
,
endpoints_num_filters
=
scaling_params
[
'endpoints_num_filters'
],
endpoints_num_filters
=
scaling_params
[
'endpoints_num_filters'
],
resample_alpha
=
scaling_params
[
'resample_alpha'
],
resample_alpha
=
scaling_params
[
'resample_alpha'
],
block_repeats
=
scaling_params
[
'block_repeats'
],
block_repeats
=
scaling_params
[
'block_repeats'
],
...
...
official/vision/beta/modeling/backbones/spinenet_mobile.py
View file @
091da63d
...
@@ -36,6 +36,7 @@ from typing import Any, List, Optional, Tuple
...
@@ -36,6 +36,7 @@ from typing import Any, List, Optional, Tuple
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_blocks
from
official.vision.beta.modeling.layers
import
nn_blocks
...
@@ -501,12 +502,12 @@ class SpineNetMobile(tf.keras.Model):
...
@@ -501,12 +502,12 @@ class SpineNetMobile(tf.keras.Model):
@
factory
.
register_backbone_builder
(
'spinenet_mobile'
)
@
factory
.
register_backbone_builder
(
'spinenet_mobile'
)
def
build_spinenet_mobile
(
def
build_spinenet_mobile
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds Mobile SpineNet backbone from a config."""
"""Builds Mobile SpineNet backbone from a config."""
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'spinenet_mobile'
,
(
f
'Inconsistent backbone type '
assert
backbone_type
==
'spinenet_mobile'
,
(
f
'Inconsistent backbone type '
f
'
{
backbone_type
}
'
)
f
'
{
backbone_type
}
'
)
...
@@ -518,8 +519,8 @@ def build_spinenet_mobile(
...
@@ -518,8 +519,8 @@ def build_spinenet_mobile(
return
SpineNetMobile
(
return
SpineNetMobile
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
min_level
=
model_confi
g
.
min_level
,
min_level
=
backbone_cf
g
.
min_level
,
max_level
=
model_confi
g
.
max_level
,
max_level
=
backbone_cf
g
.
max_level
,
endpoints_num_filters
=
scaling_params
[
'endpoints_num_filters'
],
endpoints_num_filters
=
scaling_params
[
'endpoints_num_filters'
],
block_repeats
=
scaling_params
[
'block_repeats'
],
block_repeats
=
scaling_params
[
'block_repeats'
],
filter_size_scale
=
scaling_params
[
'filter_size_scale'
],
filter_size_scale
=
scaling_params
[
'filter_size_scale'
],
...
...
official/vision/beta/modeling/factory.py
View file @
091da63d
...
@@ -44,12 +44,13 @@ def build_classification_model(
...
@@ -44,12 +44,13 @@ def build_classification_model(
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
skip_logits_layer
:
bool
=
False
)
->
tf
.
keras
.
Model
:
skip_logits_layer
:
bool
=
False
)
->
tf
.
keras
.
Model
:
"""Builds the classification model."""
"""Builds the classification model."""
norm_activation_config
=
model_config
.
norm_activation
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
norm_activation_config
=
model_config
.
norm_activation
model
=
classification_model
.
ClassificationModel
(
model
=
classification_model
.
ClassificationModel
(
backbone
=
backbone
,
backbone
=
backbone
,
num_classes
=
model_config
.
num_classes
,
num_classes
=
model_config
.
num_classes
,
...
@@ -69,9 +70,11 @@ def build_maskrcnn(
...
@@ -69,9 +70,11 @@ def build_maskrcnn(
model_config
:
maskrcnn_cfg
.
MaskRCNN
,
model_config
:
maskrcnn_cfg
.
MaskRCNN
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds Mask R-CNN model."""
"""Builds Mask R-CNN model."""
norm_activation_config
=
model_config
.
norm_activation
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
decoder
=
decoder_factory
.
build_decoder
(
decoder
=
decoder_factory
.
build_decoder
(
...
@@ -85,7 +88,6 @@ def build_maskrcnn(
...
@@ -85,7 +88,6 @@ def build_maskrcnn(
roi_aligner_config
=
model_config
.
roi_aligner
roi_aligner_config
=
model_config
.
roi_aligner
detection_head_config
=
model_config
.
detection_head
detection_head_config
=
model_config
.
detection_head
generator_config
=
model_config
.
detection_generator
generator_config
=
model_config
.
detection_generator
norm_activation_config
=
model_config
.
norm_activation
num_anchors_per_location
=
(
num_anchors_per_location
=
(
len
(
model_config
.
anchor
.
aspect_ratios
)
*
model_config
.
anchor
.
num_scales
)
len
(
model_config
.
anchor
.
aspect_ratios
)
*
model_config
.
anchor
.
num_scales
)
...
@@ -242,9 +244,11 @@ def build_retinanet(
...
@@ -242,9 +244,11 @@ def build_retinanet(
model_config
:
retinanet_cfg
.
RetinaNet
,
model_config
:
retinanet_cfg
.
RetinaNet
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds RetinaNet model."""
"""Builds RetinaNet model."""
norm_activation_config
=
model_config
.
norm_activation
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
decoder
=
decoder_factory
.
build_decoder
(
decoder
=
decoder_factory
.
build_decoder
(
...
@@ -254,7 +258,6 @@ def build_retinanet(
...
@@ -254,7 +258,6 @@ def build_retinanet(
head_config
=
model_config
.
head
head_config
=
model_config
.
head
generator_config
=
model_config
.
detection_generator
generator_config
=
model_config
.
detection_generator
norm_activation_config
=
model_config
.
norm_activation
num_anchors_per_location
=
(
num_anchors_per_location
=
(
len
(
model_config
.
anchor
.
aspect_ratios
)
*
model_config
.
anchor
.
num_scales
)
len
(
model_config
.
anchor
.
aspect_ratios
)
*
model_config
.
anchor
.
num_scales
)
...
@@ -301,9 +304,11 @@ def build_segmentation_model(
...
@@ -301,9 +304,11 @@ def build_segmentation_model(
model_config
:
segmentation_cfg
.
SemanticSegmentationModel
,
model_config
:
segmentation_cfg
.
SemanticSegmentationModel
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds Segmentation model."""
"""Builds Segmentation model."""
norm_activation_config
=
model_config
.
norm_activation
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
decoder
=
decoder_factory
.
build_decoder
(
decoder
=
decoder_factory
.
build_decoder
(
...
@@ -312,7 +317,6 @@ def build_segmentation_model(
...
@@ -312,7 +317,6 @@ def build_segmentation_model(
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
head_config
=
model_config
.
head
head_config
=
model_config
.
head
norm_activation_config
=
model_config
.
norm_activation
head
=
segmentation_heads
.
SegmentationHead
(
head
=
segmentation_heads
.
SegmentationHead
(
num_classes
=
model_config
.
num_classes
,
num_classes
=
model_config
.
num_classes
,
...
...
official/vision/beta/modeling/factory_3d.py
View file @
091da63d
...
@@ -85,9 +85,11 @@ def build_video_classification_model(
...
@@ -85,9 +85,11 @@ def build_video_classification_model(
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds the video classification model."""
"""Builds the video classification model."""
input_specs_dict
=
{
'image'
:
input_specs
}
input_specs_dict
=
{
'image'
:
input_specs
}
norm_activation_config
=
model_config
.
norm_activation
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
model
=
video_classification_model
.
VideoClassificationModel
(
model
=
video_classification_model
.
VideoClassificationModel
(
...
...
official/vision/beta/projects/assemblenet/modeling/assemblenet.py
View file @
091da63d
...
@@ -54,6 +54,7 @@ from absl import logging
...
@@ -54,6 +54,7 @@ from absl import logging
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.vision.beta.modeling
import
factory_3d
as
model_factory
from
official.vision.beta.modeling
import
factory_3d
as
model_factory
from
official.vision.beta.modeling.backbones
import
factory
as
backbone_factory
from
official.vision.beta.modeling.backbones
import
factory
as
backbone_factory
from
official.vision.beta.projects.assemblenet.configs
import
assemblenet
as
cfg
from
official.vision.beta.projects.assemblenet.configs
import
assemblenet
as
cfg
...
@@ -1015,14 +1016,14 @@ def assemblenet_v1(assemblenet_depth: int,
...
@@ -1015,14 +1016,14 @@ def assemblenet_v1(assemblenet_depth: int,
@
backbone_factory
.
register_backbone_builder
(
'assemblenet'
)
@
backbone_factory
.
register_backbone_builder
(
'assemblenet'
)
def
build_assemblenet_v1
(
def
build_assemblenet_v1
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
:
cfg
.
Backbone3D
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds assemblenet backbone."""
"""Builds assemblenet backbone."""
del
l2_regularizer
del
l2_regularizer
backbone_type
=
model_config
.
backbone
.
type
backbone_type
=
backbone_config
.
type
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
assert
backbone_type
==
'assemblenet'
assert
backbone_type
==
'assemblenet'
assemblenet_depth
=
int
(
backbone_cfg
.
model_id
)
assemblenet_depth
=
int
(
backbone_cfg
.
model_id
)
...
@@ -1060,7 +1061,8 @@ def build_assemblenet_model(
...
@@ -1060,7 +1061,8 @@ def build_assemblenet_model(
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
"""Builds assemblenet model."""
"""Builds assemblenet model."""
input_specs_dict
=
{
'image'
:
input_specs
}
input_specs_dict
=
{
'image'
:
input_specs
}
backbone
=
build_assemblenet_v1
(
input_specs
,
model_config
,
l2_regularizer
)
backbone
=
build_assemblenet_v1
(
input_specs
,
model_config
.
backbone
,
model_config
.
norm_activation
,
l2_regularizer
)
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
model_config
.
backbone
.
get
()
model_structure
,
_
=
cfg
.
blocks_to_flat_lists
(
backbone_cfg
.
blocks
)
model_structure
,
_
=
cfg
.
blocks_to_flat_lists
(
backbone_cfg
.
blocks
)
model
=
AssembleNetModel
(
model
=
AssembleNetModel
(
...
...
official/vision/beta/projects/deepmac_maskrcnn/tasks/deep_mask_head_rcnn.py
View file @
091da63d
...
@@ -37,9 +37,11 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
...
@@ -37,9 +37,11 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
model_config
:
deep_mask_head_rcnn_config
.
DeepMaskHeadRCNN
,
model_config
:
deep_mask_head_rcnn_config
.
DeepMaskHeadRCNN
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
):
"""Builds Mask R-CNN model."""
"""Builds Mask R-CNN model."""
norm_activation_config
=
model_config
.
norm_activation
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
decoder
=
decoder_factory
.
build_decoder
(
decoder
=
decoder_factory
.
build_decoder
(
...
@@ -53,7 +55,6 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
...
@@ -53,7 +55,6 @@ def build_maskrcnn(input_specs: tf.keras.layers.InputSpec,
roi_aligner_config
=
model_config
.
roi_aligner
roi_aligner_config
=
model_config
.
roi_aligner
detection_head_config
=
model_config
.
detection_head
detection_head_config
=
model_config
.
detection_head
generator_config
=
model_config
.
detection_generator
generator_config
=
model_config
.
detection_generator
norm_activation_config
=
model_config
.
norm_activation
num_anchors_per_location
=
(
num_anchors_per_location
=
(
len
(
model_config
.
anchor
.
aspect_ratios
)
*
model_config
.
anchor
.
num_scales
)
len
(
model_config
.
anchor
.
aspect_ratios
)
*
model_config
.
anchor
.
num_scales
)
...
...
official/vision/beta/projects/simclr/tasks/simclr.py
View file @
091da63d
...
@@ -110,7 +110,8 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -110,7 +110,8 @@ class SimCLRPretrainTask(base_task.Task):
# Build backbone
# Build backbone
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
model_config
=
model_config
,
backbone_config
=
model_config
.
backbone
,
norm_activation_config
=
model_config
.
norm_activation
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
# Build projection head
# Build projection head
...
...
official/vision/beta/projects/yolo/modeling/backbones/darknet.py
View file @
091da63d
...
@@ -40,6 +40,7 @@ import collections
...
@@ -40,6 +40,7 @@ import collections
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.projects.yolo.modeling.layers
import
nn_blocks
from
official.vision.beta.projects.yolo.modeling.layers
import
nn_blocks
...
@@ -428,12 +429,12 @@ class Darknet(tf.keras.Model):
...
@@ -428,12 +429,12 @@ class Darknet(tf.keras.Model):
@
factory
.
register_backbone_builder
(
"darknet"
)
@
factory
.
register_backbone_builder
(
"darknet"
)
def
build_darknet
(
def
build_darknet
(
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
input_specs
:
tf
.
keras
.
layers
.
InputSpec
,
model_config
,
backbone_config
:
hyperparams
.
Config
,
norm_activation_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
"""Builds darknet backbone."""
"""Builds darknet backbone."""
backbone_cfg
=
model_config
.
backbone
.
get
()
backbone_cfg
=
backbone_config
.
get
()
norm_activation_config
=
model_config
.
norm_activation
model
=
Darknet
(
model
=
Darknet
(
model_id
=
backbone_cfg
.
model_id
,
model_id
=
backbone_cfg
.
model_id
,
input_shape
=
input_specs
,
input_shape
=
input_specs
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment