Commit d0b78926 authored by Pengchong Jin's avatar Pengchong Jin Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 341443730
parent 49c66406
...@@ -24,6 +24,6 @@ from official.modeling import hyperparams ...@@ -24,6 +24,6 @@ from official.modeling import hyperparams
@dataclasses.dataclass @dataclasses.dataclass
class NormActivation(hyperparams.Config): class NormActivation(hyperparams.Config):
activation: str = 'relu' activation: str = 'relu'
use_sync_bn: bool = False use_sync_bn: bool = True
norm_momentum: float = 0.99 norm_momentum: float = 0.99
norm_epsilon: float = 0.001 norm_epsilon: float = 0.001
...@@ -38,12 +38,14 @@ class DataConfig(cfg.DataConfig): ...@@ -38,12 +38,14 @@ class DataConfig(cfg.DataConfig):
@dataclasses.dataclass @dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config): class ImageClassificationModel(hyperparams.Config):
"""The model config."""
num_classes: int = 0 num_classes: int = 0
input_size: List[int] = dataclasses.field(default_factory=list) input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone( backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet()) type='resnet', resnet=backbones.ResNet())
dropout_rate: float = 0.0 dropout_rate: float = 0.0
norm_activation: common.NormActivation = common.NormActivation() norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False add_head_batch_norm: bool = False
...@@ -57,7 +59,7 @@ class Losses(hyperparams.Config): ...@@ -57,7 +59,7 @@ class Losses(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class ImageClassificationTask(cfg.TaskConfig): class ImageClassificationTask(cfg.TaskConfig):
"""The model config.""" """The task config."""
model: ImageClassificationModel = ImageClassificationModel() model: ImageClassificationModel = ImageClassificationModel()
train_data: DataConfig = DataConfig(is_training=True) train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False) validation_data: DataConfig = DataConfig(is_training=False)
...@@ -98,7 +100,7 @@ def image_classification_imagenet() -> cfg.ExperimentConfig: ...@@ -98,7 +100,7 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)), type='resnet', resnet=backbones.ResNet(model_id=50)),
norm_activation=common.NormActivation( norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5)), norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
losses=Losses(l2_weight_decay=1e-4), losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig( train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'), input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
...@@ -168,7 +170,7 @@ def image_classification_imagenet_revnet() -> cfg.ExperimentConfig: ...@@ -168,7 +170,7 @@ def image_classification_imagenet_revnet() -> cfg.ExperimentConfig:
backbone=backbones.Backbone( backbone=backbones.Backbone(
type='revnet', revnet=backbones.RevNet(model_id=56)), type='revnet', revnet=backbones.RevNet(model_id=56)),
norm_activation=common.NormActivation( norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5), norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False),
add_head_batch_norm=True), add_head_batch_norm=True),
losses=Losses(l2_weight_decay=1e-4), losses=Losses(l2_weight_decay=1e-4),
train_data=DataConfig( train_data=DataConfig(
...@@ -236,7 +238,7 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig: ...@@ -236,7 +238,7 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
mobilenet=backbones.MobileNet( mobilenet=backbones.MobileNet(
model_id='MobileNetV2', filter_size_scale=1.0)), model_id='MobileNetV2', filter_size_scale=1.0)),
norm_activation=common.NormActivation( norm_activation=common.NormActivation(
norm_momentum=0.997, norm_epsilon=1e-3)), norm_momentum=0.997, norm_epsilon=1e-3, use_sync_bn=False)),
losses=Losses(l2_weight_decay=1e-5, label_smoothing=0.1), losses=Losses(l2_weight_decay=1e-5, label_smoothing=0.1),
train_data=DataConfig( train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'), input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
......
...@@ -41,7 +41,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -41,7 +41,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type='resnet', type='resnet',
resnet=backbones_cfg.ResNet(model_id=model_id)) resnet=backbones_cfg.ResNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet( model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config) backbone=backbone_config, norm_activation=norm_activation_config)
...@@ -73,7 +73,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -73,7 +73,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
efficientnet=backbones_cfg.EfficientNet( efficientnet=backbones_cfg.EfficientNet(
model_id=model_id, se_ratio=se_ratio)) model_id=model_id, se_ratio=se_ratio))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet( model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config) backbone=backbone_config, norm_activation=norm_activation_config)
...@@ -107,7 +107,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -107,7 +107,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
mobilenet=backbones_cfg.MobileNet( mobilenet=backbones_cfg.MobileNet(
model_id=model_id, filter_size_scale=filter_size_scale)) model_id=model_id, filter_size_scale=filter_size_scale))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet( model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config) backbone=backbone_config, norm_activation=norm_activation_config)
...@@ -140,7 +140,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -140,7 +140,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type='spinenet', type='spinenet',
spinenet=backbones_cfg.SpineNet(model_id=model_id)) spinenet=backbones_cfg.SpineNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet( model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config) backbone=backbone_config, norm_activation=norm_activation_config)
...@@ -165,7 +165,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -165,7 +165,7 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
type='revnet', type='revnet',
revnet=backbones_cfg.RevNet(model_id=model_id)) revnet=backbones_cfg.RevNet(model_id=model_id))
norm_activation_config = common_cfg.NormActivation( norm_activation_config = common_cfg.NormActivation(
norm_momentum=0.99, norm_epsilon=1e-5) norm_momentum=0.99, norm_epsilon=1e-5, use_sync_bn=False)
model_config = retinanet_cfg.RetinaNet( model_config = retinanet_cfg.RetinaNet(
backbone=backbone_config, norm_activation=norm_activation_config) backbone=backbone_config, norm_activation=norm_activation_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment