Commit ba772461 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Makes the params override strict.

PiperOrigin-RevId: 307473476
parent 0cc16aa5
......@@ -187,8 +187,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
for param in overriding_configs:
logging.info('Overriding params: %s', param)
# Set is_strict to false because we can have dynamic dict parameters.
params = params_dict.override_params_dict(params, param, is_strict=False)
params = params_dict.override_params_dict(params, param, is_strict=True)
params.validate()
params.lock()
......
......@@ -207,7 +207,7 @@ class ModelConfig(base_config.Config):
"""
name: str = None
model_params: Mapping[str, Any] = None
model_params: base_config.Config = None
num_classes: int = None
loss: LossConfig = None
optimizer: OptimizerConfig = None
......
......@@ -22,6 +22,7 @@ from typing import Any, Mapping
import dataclasses
from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs
......@@ -43,23 +44,24 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
configuration.
learning_rate: The configuration for learning rate. Defaults to an
exponential configuration.
"""
name: str = 'EfficientNet'
num_classes: int = 1000
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {
'model_name': 'efficientnet-b0',
'model_weights_path': '',
'weights_format': 'saved_model',
'overrides': {
'batch_norm': 'default',
'rescale_input': True,
'num_classes': 1000,
}
})
model_params: base_config.Config = dataclasses.field(
default_factory=lambda: {
'model_name': 'efficientnet-b0',
'model_weights_path': '',
'weights_format': 'saved_model',
'overrides': {
'batch_norm': 'default',
'rescale_input': True,
'num_classes': 1000,
'activation': 'swish',
'dtype': 'float32',
}
})
loss: base_configs.LossConfig = base_configs.LossConfig(
name='categorical_crossentropy',
label_smoothing=0.1)
name='categorical_crossentropy', label_smoothing=0.1)
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
name='rmsprop',
decay=0.9,
......
......@@ -22,6 +22,7 @@ from typing import Any, Mapping
import dataclasses
from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs
......@@ -38,12 +39,13 @@ class ResNetModelConfig(base_configs.ModelConfig):
"""Configuration for the ResNet model."""
name: str = 'ResNet'
num_classes: int = 1000
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {
'num_classes': 1000,
'batch_size': None,
'use_l2_regularizer': True,
'rescale_inputs': False,
})
model_params: base_config.Config = dataclasses.field(
default_factory=lambda: {
'num_classes': 1000,
'batch_size': None,
'use_l2_regularizer': True,
'rescale_inputs': False,
})
loss: base_configs.LossConfig = base_configs.LossConfig(
name='sparse_categorical_crossentropy')
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment