Commit d4176a95 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 406422641
parent fa46f548
......@@ -33,7 +33,7 @@ def dummy_ade20k_dataset(image_width, image_height):
def dummy_data(_):
dummy_image = tf.zeros((1, image_width, image_height, 3), dtype=tf.float32)
dummy_masks = tf.zeros((1, image_width, image_height, 1), dtype=tf.float32)
dummy_valid_masks = dummy_masks
dummy_valid_masks = tf.cast(dummy_masks, dtype=tf.bool)
dummy_image_info = tf.zeros((1, 4, 2), dtype=tf.float32)
return (dummy_image, {
'masks': dummy_masks,
......@@ -99,20 +99,18 @@ class AutosegEdgeTPUTaskTest(tf.test.TestCase, parameterized.TestCase):
}
config = autoseg_cfg.autoseg_edgetpu_experiment_config(
config_to_backbone_mapping[config_name], init_backbone=False)
config.task.train_data.global_batch_size = 2
config.task.train_data.global_batch_size = 1
config.task.train_data.shuffle_buffer_size = 2
config.task.validation_data.shuffle_buffer_size = 2
config.task.validation_data.global_batch_size = 2
config.task.validation_data.global_batch_size = 1
config.task.train_data.output_size = [512, 512]
config.task.validation_data.output_size = [512, 512]
task = img_seg_task.AutosegEdgeTPUTask(config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
config.task.train_data)
dataset = dummy_ade20k_dataset(512, 512)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment