Commit 315506c9 authored by ghpark's avatar ghpark
Browse files

Edit tests

parent 02f5599f
...@@ -27,14 +27,24 @@ from official.projects.detr.dataloaders import coco ...@@ -27,14 +27,24 @@ from official.projects.detr.dataloaders import coco
class DetrTest(tf.test.TestCase, parameterized.TestCase): class DetrTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('detr_coco',)) @parameterized.parameters(('detr_coco',))
def test_detr_configs(self, config_name): def test_detr_configs_tfds(self, config_name):
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig) self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetectionConfig) self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, coco.COCODataConfig) self.assertIsInstance(config.task.train_data, coco.COCODataConfig)
config.task.train_data.is_training = None config.task.train_data.is_training = None
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
config.validate() config.validate()
@parameterized.parameters(('detr_coco_tfrecord'),('detr_coco_tfds'))
def test_detr_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.DetrTask)
self.assertIsInstance(config.task.train_data, cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -58,9 +58,11 @@ def _as_dataset(self, *args, **kwargs): ...@@ -58,9 +58,11 @@ def _as_dataset(self, *args, **kwargs):
class DetectionTest(tf.test.TestCase): class DetectionTest(tf.test.TestCase):
def test_train_step(self): def test_train_step(self):
config = detr_cfg.DetectionConfig( config = detr_cfg.DetrTask(
num_encoder_layers=1, model=detr_cfg.Detr(
num_decoder_layers=1, input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,),
train_data=coco.COCODataConfig( train_data=coco.COCODataConfig(
tfds_name='coco/2017', tfds_name='coco/2017',
tfds_split='validation', tfds_split='validation',
...@@ -92,9 +94,11 @@ class DetectionTest(tf.test.TestCase): ...@@ -92,9 +94,11 @@ class DetectionTest(tf.test.TestCase):
task.train_step(next(iterator), model, optimizer) task.train_step(next(iterator), model, optimizer)
def test_validation_step(self): def test_validation_step(self):
config = detr_cfg.DetectionConfig( config = detr_cfg.DetrTask(
num_encoder_layers=1, model=detr_cfg.Detr(
num_decoder_layers=1, input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,),
validation_data=coco.COCODataConfig( validation_data=coco.COCODataConfig(
tfds_name='coco/2017', tfds_name='coco/2017',
tfds_split='validation', tfds_split='validation',
...@@ -112,5 +116,127 @@ class DetectionTest(tf.test.TestCase): ...@@ -112,5 +116,127 @@ class DetectionTest(tf.test.TestCase):
state = task.aggregate_logs(step_outputs=logs) state = task.aggregate_logs(step_outputs=logs)
task.reduce_aggregated_logs(state) task.reduce_aggregated_logs(state)
class DetectionTest_tfds(tf.test.TestCase):
def test_train_step(self):
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,),
train_data=detr_cfg.DataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=True,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
opt_cfg = optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [120000],
'values': [0.0001, 1.0e-05]
}
},
})
optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer)
def test_validation_step(self):
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,),
validation_data=detr_cfg.DataConfig(
tfds_name='coco/2017',
tfds_split='validation',
is_training=False,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data)
iterator = iter(dataset)
logs = task.validation_step(next(iterator), model, metrics)
state = task.aggregate_logs(step_outputs=logs)
task.reduce_aggregated_logs(state)
class DetectionTest_tfrecord(tf.test.TestCase):
def test_train_step(self):
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,),
train_data=detr_cfg.DataConfig(
input_path='/data/MS_COCO/tfrecords/train*',
tfds_name='',
is_training=True,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
opt_cfg = optimization.OptimizationConfig({
'optimizer': {
'type': 'detr_adamw',
'detr_adamw': {
'weight_decay_rate': 1e-4,
'global_clipnorm': 0.1,
}
},
'learning_rate': {
'type': 'stepwise',
'stepwise': {
'boundaries': [120000],
'values': [0.0001, 1.0e-05]
}
},
})
optimizer = detection.DectectionTask.create_optimizer(opt_cfg)
task.train_step(next(iterator), model, optimizer)
def test_validation_step(self):
config = detr_cfg.DetrTask(
model=detr_cfg.Detr(
input_size=[1333, 1333, 3],
num_encoder_layers=1,
num_decoder_layers=1,),
validation_data=detr_cfg.DataConfig(
input_path='/data/MS_COCO/tfrecords/val*',
tfds_name='',
is_training=False,
global_batch_size=2,
))
with tfds.testing.mock_data(as_dataset_fn=_as_dataset):
task = detection.DectectionTask(config)
model = task.build_model()
metrics = task.build_metrics(training=False)
dataset = task.build_inputs(config.validation_data)
iterator = iter(dataset)
logs = task.validation_step(next(iterator), model, metrics)
state = task.aggregate_logs(step_outputs=logs)
task.reduce_aggregated_logs(state)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment