Commit 95e1fa6e authored by Tao Xu's avatar Tao Xu Committed by Facebook GitHub Bot
Browse files

add upperbound-model based evaluation dataset

Summary: Add upperbound-model based video evaluation dataset

Reviewed By: yc-fb

Differential Revision: D28122534

fbshipit-source-id: 37cf5ece1c6a5d537b91ae2f7db74a35b97c210d
parent d37031b4
...@@ -16,6 +16,8 @@ raw data with fields such as: ...@@ -16,6 +16,8 @@ raw data with fields such as:
import os import os
import json import json
import logging import logging
import tempfile
from pathlib import Path
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data import DatasetCatalog, MetadataCatalog
...@@ -106,12 +108,19 @@ def load_pix2pix_json( ...@@ -106,12 +108,19 @@ def load_pix2pix_json(
# for fname in filenames.keys(): # for fname in filenames.keys():
while cnt < total_len: while cnt < total_len:
fname = in_keys[cnt % in_len] fname = in_keys[cnt % in_len]
input_label = filenames[fname]
if isinstance(input_label, tuple) or isinstance(input_label, list):
assert len(input_label) == 2, (
"Save (real_name, label) as the value of the json dict for resampling"
)
fname, input_label = input_label
f = { f = {
"file_name": fname, "file_name": fname,
"input_folder": input_folder, "input_folder": input_folder,
"gt_folder": gt_folder, "gt_folder": gt_folder,
"mask_folder": mask_folder, "mask_folder": mask_folder,
"input_label": filenames[fname], "input_label": input_label,
"real_folder": real_folder "real_folder": real_folder
} }
if real_len > 0: if real_len > 0:
...@@ -195,34 +204,42 @@ def register_lmdb_dataset( ...@@ -195,34 +204,42 @@ def register_lmdb_dataset(
def inject_gan_datasets(cfg): def inject_gan_datasets(cfg):
if cfg.D2GO_DATA.DATASETS.GAN_INJECTION.ENABLE: if cfg.D2GO_DATA.DATASETS.GAN_INJECTION.ENABLE:
name = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.NAME name = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.NAME
cfg.merge_from_list(["DATASETS.TRAIN", [name + "_train"], "DATASETS.TEST", [name + "_test"]]) cfg.merge_from_list(
["DATASETS.TRAIN",
list(cfg.DATASETS.TRAIN) + [name + "_train"],
"DATASETS.TEST",
list(cfg.DATASETS.TEST) + [name + "_test"]
]
)
json_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.JSON_PATH json_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.JSON_PATH
assert PathManager.isfile(json_path), ( assert PathManager.isfile(json_path), (
"{} is not valid!".format(json_path)) "{} is not valid!".format(json_path))
image_dir = Path(tempfile.mkdtemp())
input_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.INPUT_SRC_DIR input_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.INPUT_SRC_DIR
assert PathManager.isfile(input_src_path), ( assert PathManager.isfile(input_src_path), (
"{} is not valid!".format(input_src_path)) "{} is not valid!".format(input_src_path))
input_folder = "/tmp/{}/input".format(name) input_folder = os.path.join(image_dir, name, "input")
gt_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.GT_SRC_DIR gt_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.GT_SRC_DIR
if PathManager.isfile(gt_src_path): if PathManager.isfile(gt_src_path):
gt_folder = "/tmp/{}/gt".format(name) gt_folder = os.path.join(image_dir, name, "gt")
else: else:
gt_src_path = None gt_src_path = None
gt_folder=None gt_folder=None
mask_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.MASK_SRC_DIR mask_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.MASK_SRC_DIR
if PathManager.isfile(mask_src_path): if PathManager.isfile(mask_src_path):
mask_folder = "/tmp/{}/mask".format(name) mask_folder = os.path.join(image_dir, name, "mask")
else: else:
mask_src_path = None mask_src_path = None
mask_folder=None mask_folder=None
real_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.REAL_SRC_DIR real_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.REAL_SRC_DIR
if PathManager.isfile(real_src_path): if PathManager.isfile(real_src_path):
real_folder = "/tmp/{}/mask".format(name) real_folder = os.path.join(image_dir, name, "real")
real_json_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.REAL_JSON_PATH real_json_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.REAL_JSON_PATH
assert PathManager.isfile(real_json_path), ( assert PathManager.isfile(real_json_path), (
"{} is not valid!".format(real_json_path)) "{} is not valid!".format(real_json_path))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment