Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0b7646e2
Commit
0b7646e2
authored
Nov 09, 2020
by
Pengchong Jin
Committed by
A. Unique TensorFlower
Nov 09, 2020
Browse files
Internal change
PiperOrigin-RevId: 341443730
parent
773ec44d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
11 deletions
+13
-11
official/vision/beta/configs/common.py
official/vision/beta/configs/common.py
+1
-1
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+7
-5
official/vision/beta/modeling/backbones/factory_test.py
official/vision/beta/modeling/backbones/factory_test.py
+5
-5
No files found.
official/vision/beta/configs/common.py
View file @
0b7646e2
...
@@ -24,6 +24,6 @@ from official.modeling import hyperparams
...
@@ -24,6 +24,6 @@ from official.modeling import hyperparams
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
NormActivation
(
hyperparams
.
Config
):
class
NormActivation
(
hyperparams
.
Config
):
activation
:
str
=
'relu'
activation
:
str
=
'relu'
use_sync_bn
:
bool
=
Fals
e
use_sync_bn
:
bool
=
Tru
e
norm_momentum
:
float
=
0.99
norm_momentum
:
float
=
0.99
norm_epsilon
:
float
=
0.001
norm_epsilon
:
float
=
0.001
official/vision/beta/configs/image_classification.py
View file @
0b7646e2
...
@@ -38,12 +38,14 @@ class DataConfig(cfg.DataConfig):
...
@@ -38,12 +38,14 @@ class DataConfig(cfg.DataConfig):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
ImageClassificationModel
(
hyperparams
.
Config
):
class
ImageClassificationModel
(
hyperparams
.
Config
):
"""The model config."""
num_classes
:
int
=
0
num_classes
:
int
=
0
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
())
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
())
dropout_rate
:
float
=
0.0
dropout_rate
:
float
=
0.0
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
()
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
(
use_sync_bn
=
False
)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm
:
bool
=
False
add_head_batch_norm
:
bool
=
False
...
@@ -57,7 +59,7 @@ class Losses(hyperparams.Config):
...
@@ -57,7 +59,7 @@ class Losses(hyperparams.Config):
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
ImageClassificationTask
(
cfg
.
TaskConfig
):
class
ImageClassificationTask
(
cfg
.
TaskConfig
):
"""The
model
config."""
"""The
task
config."""
model
:
ImageClassificationModel
=
ImageClassificationModel
()
model
:
ImageClassificationModel
=
ImageClassificationModel
()
train_data
:
DataConfig
=
DataConfig
(
is_training
=
True
)
train_data
:
DataConfig
=
DataConfig
(
is_training
=
True
)
validation_data
:
DataConfig
=
DataConfig
(
is_training
=
False
)
validation_data
:
DataConfig
=
DataConfig
(
is_training
=
False
)
...
@@ -98,7 +100,7 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
...
@@ -98,7 +100,7 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
norm_activation
=
common
.
NormActivation
(
norm_activation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
)),
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)),
losses
=
Losses
(
l2_weight_decay
=
1e-4
),
losses
=
Losses
(
l2_weight_decay
=
1e-4
),
train_data
=
DataConfig
(
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
...
@@ -168,7 +170,7 @@ def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
...
@@ -168,7 +170,7 @@ def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'revnet'
,
revnet
=
backbones
.
RevNet
(
model_id
=
56
)),
type
=
'revnet'
,
revnet
=
backbones
.
RevNet
(
model_id
=
56
)),
norm_activation
=
common
.
NormActivation
(
norm_activation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
),
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
),
add_head_batch_norm
=
True
),
add_head_batch_norm
=
True
),
losses
=
Losses
(
l2_weight_decay
=
1e-4
),
losses
=
Losses
(
l2_weight_decay
=
1e-4
),
train_data
=
DataConfig
(
train_data
=
DataConfig
(
...
@@ -236,7 +238,7 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
...
@@ -236,7 +238,7 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
mobilenet
=
backbones
.
MobileNet
(
mobilenet
=
backbones
.
MobileNet
(
model_id
=
'MobileNetV2'
,
filter_size_scale
=
1.0
)),
model_id
=
'MobileNetV2'
,
filter_size_scale
=
1.0
)),
norm_activation
=
common
.
NormActivation
(
norm_activation
=
common
.
NormActivation
(
norm_momentum
=
0.997
,
norm_epsilon
=
1e-3
)),
norm_momentum
=
0.997
,
norm_epsilon
=
1e-3
,
use_sync_bn
=
False
)),
losses
=
Losses
(
l2_weight_decay
=
1e-5
,
label_smoothing
=
0.1
),
losses
=
Losses
(
l2_weight_decay
=
1e-5
,
label_smoothing
=
0.1
),
train_data
=
DataConfig
(
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
...
...
official/vision/beta/modeling/backbones/factory_test.py
View file @
0b7646e2
...
@@ -41,7 +41,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -41,7 +41,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type
=
'resnet'
,
type
=
'resnet'
,
resnet
=
backbones_cfg
.
ResNet
(
model_id
=
model_id
))
resnet
=
backbones_cfg
.
ResNet
(
model_id
=
model_id
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
...
@@ -73,7 +73,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -73,7 +73,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
efficientnet
=
backbones_cfg
.
EfficientNet
(
efficientnet
=
backbones_cfg
.
EfficientNet
(
model_id
=
model_id
,
se_ratio
=
se_ratio
))
model_id
=
model_id
,
se_ratio
=
se_ratio
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
...
@@ -107,7 +107,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -107,7 +107,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
mobilenet
=
backbones_cfg
.
MobileNet
(
mobilenet
=
backbones_cfg
.
MobileNet
(
model_id
=
model_id
,
filter_size_scale
=
filter_size_scale
))
model_id
=
model_id
,
filter_size_scale
=
filter_size_scale
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
...
@@ -140,7 +140,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -140,7 +140,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type
=
'spinenet'
,
type
=
'spinenet'
,
spinenet
=
backbones_cfg
.
SpineNet
(
model_id
=
model_id
))
spinenet
=
backbones_cfg
.
SpineNet
(
model_id
=
model_id
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
...
@@ -165,7 +165,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -165,7 +165,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type
=
'revnet'
,
type
=
'revnet'
,
revnet
=
backbones_cfg
.
RevNet
(
model_id
=
model_id
))
revnet
=
backbones_cfg
.
RevNet
(
model_id
=
model_id
))
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_activation_config
=
common_cfg
.
NormActivation
(
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
)
norm_momentum
=
0.99
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
model_config
=
retinanet_cfg
.
RetinaNet
(
model_config
=
retinanet_cfg
.
RetinaNet
(
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
backbone
=
backbone_config
,
norm_activation
=
norm_activation_config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment