from official.vision.beta.projects.yolo.common import registry_imports  # pylint: disable=unused-import
from official.vision.beta.projects.yolo.tasks import image_classification as imc
from official.vision.beta.projects.yolo.configs import darknet_classification as dcfg

import os
import tensorflow as tf
from official.core import train_utils
from official.core import task_factory
from absl.testing import parameterized

# Directory holding the COCO record files. test_yolo_input_task joins this
# with the 'train*' and 'val*' glob patterns to locate the input shards.
# NOTE(review): this is an absolute, machine-specific path; it must be changed
# to a local dataset location before the tests in this file can run elsewhere.
PATH_TO_COCO = '/media/vbanna/DATA_SHARE/CV/datasets/COCO_raw/records/'

def test_yolo_input_task(scaled_pipeline=False, batch_size=1):
  """Builds the train and validation input datasets for a YOLO experiment.

  Args:
    scaled_pipeline: when True, the "scaled_yolo" experiment with the
      yolov4-csp 640 config is used; otherwise the "yolo_darknet" experiment
      with the yolov4 512 config.
    batch_size: global batch size applied to the training data config. The
      validation data config always uses a global batch size of 1.

  Returns:
    A tuple `(train_data, test_data, config)` holding the two objects returned
    by `task.build_inputs` and the task config (`params.task`) after the
    overrides below have been applied to it.
  """
  if scaled_pipeline:
    experiment = "scaled_yolo"
    config_path = [
      "official/vision/beta/projects/yolo/configs/experiments/yolov4-csp/tpu/640.yaml"]
  else:
    experiment = "yolo_darknet"
    config_path = [
      "official/vision/beta/projects/yolo/configs/experiments/yolov4/tpu/512.yaml"]

  options = train_utils.ParseConfigOptions(
      experiment=experiment, config_file=config_path)
  params = train_utils.parse_configuration(options)
  task_config = params.task
  task = task_factory.get_task(task_config)

  train_cfg = task_config.train_data
  eval_cfg = task_config.validation_data

  # Settings shared by both splits.
  for data_cfg in (train_cfg, eval_cfg):
    data_cfg.dtype = 'float32'
    data_cfg.shuffle_buffer_size = 1

  # Per-split settings.
  train_cfg.global_batch_size = batch_size
  eval_cfg.global_batch_size = 1
  train_cfg.input_path = os.path.join(PATH_TO_COCO, 'train*')
  eval_cfg.input_path = os.path.join(PATH_TO_COCO, 'val*')

  # Both input pipelines are constructed under the CPU device scope.
  with tf.device('/CPU:0'):
    train_data = task.build_inputs(train_cfg)
    test_data = task.build_inputs(eval_cfg)
  return train_data, test_data, task_config

def test_yolo_pipeline_visually(is_training=True, num=30):
  """Plots batches from the YOLO input pipeline for manual inspection.

  For each of the first `num` batches, one figure is shown with four panels:
  the image with the boxes from `label['bbox']` drawn on it, followed by the
  first channel of the `label['true_conf']` entries keyed '3', '4' and '5',
  each clipped to [0, 1]. Every figure blocks on `plt.show()`.

  Args:
    is_training: when True the training dataset is visualized, otherwise the
      validation dataset.
    num: number of batches taken from the chosen dataset.
  """
  # Local import: matplotlib is only needed by this visual check.
  import matplotlib.pyplot as plt
  dataset, testing, _ = test_yolo_input_task()

  data = dataset if is_training else testing
  for images, label in data.take(num):
    boxed = tf.image.draw_bounding_boxes(
        images, label['bbox'], [[1.0, 0.0, 1.0]])

    gt = label['true_conf']
    # One objectness map per key; presumably these are the output levels of
    # the detector -- TODO confirm against the label builder.
    obj_maps = [
        tf.clip_by_value(gt[key][..., 0], 0.0, 1.0) for key in ('3', '4', '5')
    ]

    # Only the first sample of each batch is shown. The batched tensors are
    # indexed without being overwritten, so this range can be widened safely.
    for shind in range(1):
      fig, axe = plt.subplots(1, 4)

      axe[0].imshow(boxed[shind])
      for panel, obj in zip(axe[1:], obj_maps):
        panel.imshow(obj[shind].numpy())

      fig.set_size_inches(18.5, 6.5, forward=True)
      plt.tight_layout()
      plt.show()

class YoloDetectionInputTest(tf.test.TestCase, parameterized.TestCase):
  """Shape and value-range checks for the YOLO detection input pipeline."""

  @parameterized.named_parameters(('scaled', True), ('darknet', False))
  def test_yolo_input(self, scaled_pipeline):
    """Builds the pipeline from the config and checks one training batch."""
    train_ds, _, task_cfg = test_yolo_input_task(
        scaled_pipeline=scaled_pipeline, batch_size=1)

    for images, _ in train_ds.take(1):
      # Batch of one image at the model's configured input size.
      expected_shape = [1] + task_cfg.model.input_size
      self.assertAllEqual(images.shape, expected_shape)

      # Every pixel value must lie in [0, 1].
      in_unit_range = tf.math.logical_and(images >= 0, images <= 1)
      self.assertTrue(tf.reduce_all(in_unit_range))


if __name__ == '__main__':
  # Run the automated input-pipeline tests defined in this file.
  tf.test.main()
  # For manual inspection, replace the call above with the visual check:
  # test_yolo_pipeline_visually(is_training=True, num=20)
