"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "a800db23b8a7f69a95f972dc94d6f3ced29631f0"
Commit ba772461 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Makes the params override strict.

PiperOrigin-RevId: 307473476
parent 0cc16aa5
...@@ -187,8 +187,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues): ...@@ -187,8 +187,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
for param in overriding_configs: for param in overriding_configs:
logging.info('Overriding params: %s', param) logging.info('Overriding params: %s', param)
# Set is_strict to false because we can have dynamic dict parameters. params = params_dict.override_params_dict(params, param, is_strict=True)
params = params_dict.override_params_dict(params, param, is_strict=False)
params.validate() params.validate()
params.lock() params.lock()
......
...@@ -207,7 +207,7 @@ class ModelConfig(base_config.Config): ...@@ -207,7 +207,7 @@ class ModelConfig(base_config.Config):
""" """
name: str = None name: str = None
model_params: Mapping[str, Any] = None model_params: base_config.Config = None
num_classes: int = None num_classes: int = None
loss: LossConfig = None loss: LossConfig = None
optimizer: OptimizerConfig = None optimizer: OptimizerConfig = None
......
...@@ -22,6 +22,7 @@ from typing import Any, Mapping ...@@ -22,6 +22,7 @@ from typing import Any, Mapping
import dataclasses import dataclasses
from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs from official.vision.image_classification.configs import base_configs
...@@ -43,23 +44,24 @@ class EfficientNetModelConfig(base_configs.ModelConfig): ...@@ -43,23 +44,24 @@ class EfficientNetModelConfig(base_configs.ModelConfig):
configuration. configuration.
learning_rate: The configuration for learning rate. Defaults to an learning_rate: The configuration for learning rate. Defaults to an
exponential configuration. exponential configuration.
""" """
name: str = 'EfficientNet' name: str = 'EfficientNet'
num_classes: int = 1000 num_classes: int = 1000
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: { model_params: base_config.Config = dataclasses.field(
'model_name': 'efficientnet-b0', default_factory=lambda: {
'model_weights_path': '', 'model_name': 'efficientnet-b0',
'weights_format': 'saved_model', 'model_weights_path': '',
'overrides': { 'weights_format': 'saved_model',
'batch_norm': 'default', 'overrides': {
'rescale_input': True, 'batch_norm': 'default',
'num_classes': 1000, 'rescale_input': True,
} 'num_classes': 1000,
}) 'activation': 'swish',
'dtype': 'float32',
}
})
loss: base_configs.LossConfig = base_configs.LossConfig( loss: base_configs.LossConfig = base_configs.LossConfig(
name='categorical_crossentropy', name='categorical_crossentropy', label_smoothing=0.1)
label_smoothing=0.1)
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig( optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
name='rmsprop', name='rmsprop',
decay=0.9, decay=0.9,
......
...@@ -22,6 +22,7 @@ from typing import Any, Mapping ...@@ -22,6 +22,7 @@ from typing import Any, Mapping
import dataclasses import dataclasses
from official.modeling.hyperparams import base_config
from official.vision.image_classification.configs import base_configs from official.vision.image_classification.configs import base_configs
...@@ -38,12 +39,13 @@ class ResNetModelConfig(base_configs.ModelConfig): ...@@ -38,12 +39,13 @@ class ResNetModelConfig(base_configs.ModelConfig):
"""Configuration for the ResNet model.""" """Configuration for the ResNet model."""
name: str = 'ResNet' name: str = 'ResNet'
num_classes: int = 1000 num_classes: int = 1000
model_params: Mapping[str, Any] = dataclasses.field(default_factory=lambda: { model_params: base_config.Config = dataclasses.field(
'num_classes': 1000, default_factory=lambda: {
'batch_size': None, 'num_classes': 1000,
'use_l2_regularizer': True, 'batch_size': None,
'rescale_inputs': False, 'use_l2_regularizer': True,
}) 'rescale_inputs': False,
})
loss: base_configs.LossConfig = base_configs.LossConfig( loss: base_configs.LossConfig = base_configs.LossConfig(
name='sparse_categorical_crossentropy') name='sparse_categorical_crossentropy')
optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig( optimizer: base_configs.OptimizerConfig = base_configs.OptimizerConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment