Commit d0b78926 authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 341443730
parent 49c66406
......@@ -24,6 +24,6 @@ from official.modeling import hyperparams
@dataclasses.dataclass
class NormActivation(hyperparams.Config):
activation: str = 'relu'
use_sync_bn: bool = False
use_sync_bn: bool = True
norm_momentum: float = 0.99
norm_epsilon: float = 0.001
......@@ -38,12 +38,14 @@ class DataConfig(cfg.DataConfig):
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
"""The model config."""
num_classes: int = 0
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet())
dropout_rate: float = 0.0
norm_activation: common.NormActivation = common.NormActivation()
norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
......@@ -57,7 +59,7 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass
class ImageClassificationTask(cfg.TaskConfig):
"""The model config."""
"""The task config."""
model: ImageClassificationModel = ImageClassificationModel()
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
......@@ -98,7 +100,7 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5)),
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
......@@ -168,7 +170,7 @@ def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='revnet', revnet=backbones.RevNet(model_id=56)),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5),
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False),
add_head_batch_norm=True),
losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig(
......@@ -236,7 +238,7 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
mobilenet=backbones.MobileNet(
model_id='MobileNetV2', filter_size_scale=1.0)),
norm_activation=common.NormActivation(
norm_momentum=0.997, norm_epsilon=1e-3)),
norm_momentum=0.997, norm_epsilon=1e-3, use_sync_bn=False)),
losses=Losses(l2_weight_decay=1e-5, label_smoothing=0.1),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
......
......@@ -41,7 +41,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type='resnet',
resnet=backbones_cfg.ResNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5)
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
......@@ -73,7 +73,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
efficientnet=backbones_cfg.EfficientNet(
model_id=model_id, se_ratio=se_ratio))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5)
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
......@@ -107,7 +107,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
mobilenet=backbones_cfg.MobileNet(
model_id=model_id, filter_size_scale=filter_size_scale))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5)
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
......@@ -140,7 +140,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type='spinenet',
spinenet=backbones_cfg.SpineNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5)
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
......@@ -165,7 +165,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type='revnet',
revnet=backbones_cfg.RevNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5)
norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment