Commit fe9964a2 authored by Vishnu Banna's avatar Vishnu Banna
Browse files

Merge branch 'exp_pr' of https://github.com/PurdueCAM2Project/tf-models into exp_pr

parents f569f619 b785982f
......@@ -34,7 +34,8 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
multi_task_model: Union[tf.keras.Model,
base_model.MultiTaskBaseModel],
optimizer: tf.optimizers.Optimizer,
trainer_options=None):
trainer_options=None,
train_datasets=None):
self._strategy = tf.distribute.get_strategy()
self._multi_task = multi_task
self._multi_task_model = multi_task_model
......@@ -55,10 +56,11 @@ class MultiTaskBaseTrainer(orbit.StandardTrainer):
global_step=self.global_step,
**checkpoint_items)
train_datasets = {}
for name, task in self.multi_task.tasks.items():
train_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.train_data)
if train_datasets is None:
train_datasets = {}
for name, task in self.multi_task.tasks.items():
train_datasets[name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.train_data)
super().__init__(
train_dataset=train_datasets,
......
......@@ -15,7 +15,7 @@
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import List, Optional
from typing import Any, List, Optional, Tuple
from absl import logging
import orbit
import tensorflow as tf
......@@ -36,11 +36,16 @@ TRAINERS = {
}
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel, mode: str,
params: configs.MultiTaskExperimentConfig,
model_dir: str) -> base_model.MultiTaskBaseModel:
def run_experiment(
*,
distribution_strategy: tf.distribute.Strategy,
task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel,
mode: str,
params: configs.MultiTaskExperimentConfig,
model_dir: str,
trainer: base_trainer.MultiTaskBaseTrainer = None
) -> base_model.MultiTaskBaseModel:
"""Runs train/eval configured by the experiment params.
Args:
......@@ -51,6 +56,8 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
trainer: (optional) A multi-task trainer to use. If none is provided, a
default one will be created based on `params`.
Returns:
model: `base_model.MultiTaskBaseModel` instance.
......@@ -66,8 +73,9 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
task.task_weights)
kwargs.update(dict(task_sampler=sampler))
trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None
if trainer is None:
trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None
if is_eval:
eval_steps = task.task_eval_steps
evaluator = evaluator_lib.MultiTaskEvaluator(
......@@ -145,7 +153,7 @@ def run_experiment_with_multitask_eval(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[core_lib.Trainer] = None) -> tf.keras.Model:
trainer: Optional[core_lib.Trainer] = None) -> Tuple[Any, Any]:
"""Runs train/eval configured by the experiment params.
Args:
......
......@@ -402,6 +402,40 @@ MNMultiAVG_BLOCK_SPECS = {
]
}
# Similar to MobileNetMultiAVG and used for segmentation task.
# Reduced the filters by a factor of 2 in the last block.
MNMultiAVG_SEG_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiAVGSeg',
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 32, 'relu', 2., True, False, True),
('invertedbottleneck', 5, 2, 64, 'relu', 5., True, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., True, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 3., True, False, True),
('invertedbottleneck', 5, 2, 128, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 160, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 160, 'relu', 4., True, False, True),
('invertedbottleneck', 3, 2, 192, 'relu', 6., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, True),
('convbn', 1, 1, 480, 'relu', None, True, False, False),
('gpooling', None, None, None, None, None, None, None, False),
# Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy.
('convbn', 1, 1, 1280, 'relu', None, True, False, False),
]
}
SUPPORTED_SPECS_MAP = {
'MobileNetV1': MNV1_BLOCK_SPECS,
'MobileNetV2': MNV2_BLOCK_SPECS,
......@@ -410,6 +444,7 @@ SUPPORTED_SPECS_MAP = {
'MobileNetV3EdgeTPU': MNV3EdgeTPU_BLOCK_SPECS,
'MobileNetMultiMAX': MNMultiMAX_BLOCK_SPECS,
'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS,
'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS,
}
......
......@@ -36,6 +36,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU',
'MobileNetMultiAVG',
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
)
def test_serialize_deserialize(self, model_id):
# Create a network object that sets all of its config options.
......@@ -80,6 +81,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU',
'MobileNetMultiAVG',
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
],
))
def test_input_specs(self, input_dim, model_id):
......@@ -102,6 +104,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU',
'MobileNetMultiAVG',
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
],
[32, 224],
))
......@@ -120,6 +123,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU': [32, 48, 96, 192],
'MobileNetMultiMAX': [32, 64, 128, 160],
'MobileNetMultiAVG': [32, 64, 160, 192],
'MobileNetMultiAVGSeg': [32, 64, 160, 96],
}
network = mobilenet.MobileNet(model_id=model_id,
......@@ -143,6 +147,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU',
'MobileNetMultiAVG',
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
],
[32, 224],
))
......@@ -161,6 +166,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU': [None, None, 384, 1280],
'MobileNetMultiMAX': [96, 128, 384, 640],
'MobileNetMultiAVG': [64, 192, 640, 768],
'MobileNetMultiAVGSeg': [64, 192, 640, 384],
}
network = mobilenet.MobileNet(model_id=model_id,
filter_size_scale=1.0,
......@@ -188,6 +194,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU',
'MobileNetMultiAVG',
'MobileNetMultiMAX',
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
],
[1.0, 0.75],
))
......@@ -209,6 +217,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
('MobileNetMultiAVG', 0.75): 2349704,
('MobileNetMultiMAX', 1.0): 3174560,
('MobileNetMultiMAX', 0.75): 2045816,
('MobileNetMultiAVGSeg', 1.0): 2284000,
('MobileNetMultiAVGSeg', 0.75): 1427816,
}
input_size = 224
......@@ -230,6 +240,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU',
'MobileNetMultiAVG',
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
],
[8, 16, 32],
))
......@@ -247,6 +258,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU': 192,
'MobileNetMultiMAX': 160,
'MobileNetMultiAVG': 192,
'MobileNetMultiAVGSeg': 96,
}
network = mobilenet.MobileNet(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment