Commit b6c93a74 authored by Albert Pumarola's avatar Albert Pumarola Committed by Facebook GitHub Bot
Browse files

Extend Pix2Pix to allow for input extra data

Summary: Extended Pix2Pix to allow for input extra data

Reviewed By: tax313

Differential Revision: D31469054

fbshipit-source-id: 790543f214ea9fa0158e509acb27193916bf17ce
parent 980d614b
......@@ -81,6 +81,7 @@ def load_pix2pix_json(
input_folder,
gt_folder,
mask_folder,
input_extras_folder,
real_json_path=None,
real_folder=None,
max_num=1e10,
......@@ -91,6 +92,7 @@ def load_pix2pix_json(
input_folder (str): the directory for the input/source images
input_folder (str): the directory for the ground_truth/target images
mask_folder (str): the directory for the masks
input_extras_folder (str): the directory for the input extras
Returns:
list[dict]: a list of dicts
"""
......@@ -126,6 +128,7 @@ def load_pix2pix_json(
"input_folder": input_folder,
"gt_folder": gt_folder,
"mask_folder": mask_folder,
"input_extras_folder": input_extras_folder,
"input_label": input_label,
"real_folder": real_folder,
}
......@@ -148,9 +151,11 @@ def register_folder_dataset(
input_folder,
gt_folder=None,
mask_folder=None,
input_extras_folder=None,
input_src_path=None,
gt_src_path=None,
mask_src_path=None,
input_extras_src_path=None,
real_json_path=None,
real_folder=None,
real_src_path=None,
......@@ -163,6 +168,7 @@ def register_folder_dataset(
input_folder,
gt_folder,
mask_folder,
input_extras_folder,
real_json_path,
real_folder,
max_num,
......@@ -172,10 +178,12 @@ def register_folder_dataset(
"input_src_path": input_src_path,
"gt_src_path": gt_src_path,
"mask_src_path": mask_src_path,
"input_extras_src_path": input_extras_src_path,
"real_src_path": real_src_path,
"input_folder": input_folder,
"gt_folder": gt_folder,
"mask_folder": mask_folder,
"input_extras_folder": input_extras_folder,
"real_folder": real_folder,
}
MetadataCatalog.get(name).set(**metadata)
......@@ -251,6 +259,15 @@ def inject_gan_datasets(cfg):
mask_src_path = None
mask_folder = None
input_extras_src_path = (
cfg.D2GO_DATA.DATASETS.GAN_INJECTION.INPUT_EXTRAS_SRC_DIR
)
if PathManager.isfile(input_extras_src_path):
input_extras_folder = os.path.join(image_dir, name, "input_extras")
else:
input_extras_src_path = None
input_extras_folder = None
real_src_path = cfg.D2GO_DATA.DATASETS.GAN_INJECTION.REAL_SRC_DIR
if PathManager.isfile(real_src_path):
real_folder = os.path.join(image_dir, name, "real")
......@@ -269,9 +286,11 @@ def inject_gan_datasets(cfg):
input_folder,
gt_folder,
mask_folder,
input_extras_folder,
input_src_path,
gt_src_path,
mask_src_path,
input_extras_src_path,
real_json_path,
real_folder,
real_src_path,
......@@ -283,9 +302,11 @@ def inject_gan_datasets(cfg):
input_folder,
gt_folder,
mask_folder,
input_extras_folder,
input_src_path,
gt_src_path,
mask_src_path,
input_extras_src_path,
real_json_path,
real_folder,
real_src_path,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment