Commit 2e9bb539 authored by stephenwu's avatar stephenwu
Browse files

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 7bae5317 8fba84f8
...@@ -26,9 +26,12 @@ from official.vision.beta.configs import image_classification as exp_cfg ...@@ -26,9 +26,12 @@ from official.vision.beta.configs import image_classification as exp_cfg
class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase): class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('resnet_imagenet',), @parameterized.parameters(
('revnet_imagenet',), ('resnet_imagenet',),
('mobilenet_imagenet'),) ('resnet_rs_imagenet',),
('revnet_imagenet',),
('mobilenet_imagenet'),
)
def test_image_classification_configs(self, config_name): def test_image_classification_configs(self, config_name):
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig) self.assertIsInstance(config, cfg.ExperimentConfig)
......
...@@ -90,6 +90,12 @@ class Losses(hyperparams.Config): ...@@ -90,6 +90,12 @@ class Losses(hyperparams.Config):
top_k_percent_pixels: float = 1.0 top_k_percent_pixels: float = 1.0
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
report_per_class_iou: bool = True
report_train_mean_iou: bool = True # Turning this off can speed up training.
@dataclasses.dataclass @dataclasses.dataclass
class SemanticSegmentationTask(cfg.TaskConfig): class SemanticSegmentationTask(cfg.TaskConfig):
"""The model config.""" """The model config."""
...@@ -97,6 +103,7 @@ class SemanticSegmentationTask(cfg.TaskConfig): ...@@ -97,6 +103,7 @@ class SemanticSegmentationTask(cfg.TaskConfig):
train_data: DataConfig = DataConfig(is_training=True) train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False) validation_data: DataConfig = DataConfig(is_training=False)
losses: Losses = Losses() losses: Losses = Losses()
evaluation: Evaluation = Evaluation()
train_input_partition_dims: List[int] = dataclasses.field( train_input_partition_dims: List[int] = dataclasses.field(
default_factory=list) default_factory=list)
eval_input_partition_dims: List[int] = dataclasses.field( eval_input_partition_dims: List[int] = dataclasses.field(
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -182,6 +182,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -182,6 +182,7 @@ class SegmentationHead(tf.keras.layers.Layer):
x = tf.image.resize( x = tf.image.resize(
x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR) x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.cast(x, dtype=y.dtype)
x = tf.concat([x, y], axis=self._bn_axis) x = tf.concat([x, y], axis=self._bn_axis)
elif self._config_dict['feature_fusion'] == 'pyramid_fusion': elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
x = nn_layers.pyramid_feature_fusion(decoder_output, x = nn_layers.pyramid_feature_fusion(decoder_output,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment