Commit 9cfc6603 authored by dongchy920

instruct first commit
import argparse, os, sys, glob, datetime, yaml
import torch
import time
import numpy as np
from tqdm import trange
from omegaconf import OmegaConf
from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
rescale = lambda x: (x + 1.) / 2.
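# Convert a single sample tensor in [-1, 1] with shape (C, H, W) into an 8-bit RGB PIL image.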
def custom_to_pil(x):
x = x.detach().cpu()
x = torch.clamp(x, -1., 1.)
x = (x + 1.) / 2.
x = x.permute(1, 2, 0).numpy()
x = (255 * x).astype(np.uint8)
x = Image.fromarray(x)
if not x.mode == "RGB":
x = x.convert("RGB")
return x
def custom_to_np(x):
# saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
sample = x.detach().cpu()
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
sample = sample.permute(0, 2, 3, 1)
sample = sample.contiguous()
return sample
def logs2pil(logs, keys=["sample"]):
imgs = dict()
for k in logs:
try:
if len(logs[k].shape) == 4:
img = custom_to_pil(logs[k][0, ...])
elif len(logs[k].shape) == 3:
img = custom_to_pil(logs[k])
else:
print(f"Unknown format for key {k}. ")
img = None
except Exception:
img = None
imgs[k] = img
return imgs
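# Plain ancestral (DDPM) sampling without conditioning; with make_prog_row=True the progressive denoising row is returned instead.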
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
verbose=True,
make_prog_row=False):
if not make_prog_row:
return model.p_sample_loop(None, shape,
return_intermediates=return_intermediates, verbose=verbose)
else:
return model.progressive_denoising(
None, shape, verbose=True
)
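# DDIM sampling with a reduced number of steps; `shape` is (batch, C, H, W) and is split into batch size and per-sample shape for the sampler.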
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
return samples, intermediates
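# Draw one batch in latent space (vanilla DDPM or DDIM), decode it with the first-stage model and log the samples, wall-clock time and throughput.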
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
log = dict()
shape = [batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size]
with model.ema_scope("Plotting"):
t0 = time.time()
if vanilla:
sample, progrow = convsample(model, shape,
make_prog_row=True)
else:
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
eta=eta)
t1 = time.time()
x_sample = model.decode_first_stage(sample)
log["sample"] = x_sample
log["time"] = t1 - t0
log['throughput'] = sample.shape[0] / (t1 - t0)
print(f'Throughput for this batch: {log["throughput"]}')
return log
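# Sample batches until n_samples images have been generated: individual PNGs are written to `logdir` and the stacked uint8 array is saved as an .npz under `nplog`. Only unconditional models are handled.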
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
if vanilla:
print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
else:
print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
tstart = time.time()
n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
# path = logdir
if model.cond_stage_model is None:
all_images = []
print(f"Running unconditional sampling for {n_samples} samples")
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
logs = make_convolutional_sample(model, batch_size=batch_size,
vanilla=vanilla, custom_steps=custom_steps,
eta=eta)
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
all_images.extend([custom_to_np(logs["sample"])])
if n_saved >= n_samples:
print(f'Finished after generating {n_saved} samples')
break
all_img = np.concatenate(all_images, axis=0)
all_img = all_img[:n_samples]
shape_str = "x".join([str(x) for x in all_img.shape])
nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
np.savez(nppath, all_img)
else:
raise NotImplementedError('Currently, only sampling for unconditional models is supported.')
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
for k in logs:
if k == key:
batch = logs[key]
if np_path is None:
for x in batch:
img = custom_to_pil(x)
imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
img.save(imgpath)
n_saved += 1
else:
npbatch = custom_to_np(batch)
shape_str = "x".join([str(x) for x in npbatch.shape])
nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
np.savez(nppath, npbatch)
n_saved += npbatch.shape[0]
return n_saved
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--resume",
type=str,
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-n",
"--n_samples",
type=int,
nargs="?",
help="number of samples to draw",
default=50000
)
parser.add_argument(
"-e",
"--eta",
type=float,
nargs="?",
help="eta for ddim sampling (0.0 yields deterministic sampling)",
default=1.0
)
parser.add_argument(
"-v",
"--vanilla_sample",
default=False,
action='store_true',
help="vanilla sampling (default option is DDIM sampling)?",
)
parser.add_argument(
"-l",
"--logdir",
type=str,
nargs="?",
help="extra logdir",
default="none"
)
parser.add_argument(
"-c",
"--custom_steps",
type=int,
nargs="?",
help="number of steps for ddim and fastdpm sampling",
default=50
)
parser.add_argument(
"--batch_size",
type=int,
nargs="?",
help="batch size",
default=10
)
return parser
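# Typical invocation of this sampling script (file name and checkpoint path are illustrative):
#   python sample_diffusion.py -r path/to/logdir/model.ckpt -n 50000 -c 50 -e 1.0 --batch_size 10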
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
model.load_state_dict(sd,strict=False)
model.cuda()
model.eval()
return model
def load_model(config, ckpt, gpu, eval_mode):
if ckpt:
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model,
pl_sd["state_dict"])
return model, global_step
if __name__ == "__main__":
now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
sys.path.append(os.getcwd())
command = " ".join(sys.argv)
parser = get_parser()
opt, unknown = parser.parse_known_args()
ckpt = None
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
# paths = opt.resume.split("/")
try:
logdir = '/'.join(opt.resume.split('/')[:-1])
# idx = len(paths)-paths[::-1].index("logs")+1
print(f'Logdir is {logdir}')
except ValueError:
paths = opt.resume.split("/")
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "model.ckpt")
base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
opt.base = base_configs
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
gpu = True
eval_mode = True
if opt.logdir != "none":
locallog = logdir.split(os.sep)[-1]
if locallog == "": locallog = logdir.split(os.sep)[-2]
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
logdir = os.path.join(opt.logdir, locallog)
print(config)
model, global_step = load_model(config, ckpt, gpu, eval_mode)
print(f"global step: {global_step}")
print(75 * "=")
print("logging to:")
logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
imglogdir = os.path.join(logdir, "img")
numpylogdir = os.path.join(logdir, "numpy")
os.makedirs(imglogdir)
os.makedirs(numpylogdir)
print(logdir)
print(75 * "=")
# write config out
sampling_file = os.path.join(logdir, "sampling_config.yaml")
sampling_conf = vars(opt)
with open(sampling_file, 'w') as f:
yaml.dump(sampling_conf, f, default_flow_style=False)
print(sampling_conf)
run(model, imglogdir, eta=opt.eta,
vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
batch_size=opt.batch_size, nplog=numpylogdir)
print("done.")
import cv2
import fire
from imwatermark import WatermarkDecoder
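# Read an image and try to decode the 136-bit invisible watermark (such as the 'StableDiffusionV1' mark embedded by the accompanying txt2img script); prints "null" if decoding fails.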
def testit(img_path):
bgr = cv2.imread(img_path)
decoder = WatermarkDecoder('bytes', 136)
watermark = decoder.decode(bgr, 'dwtDct')
try:
dec = watermark.decode('utf-8')
except Exception:
dec = "null"
print(dec)
if __name__ == "__main__":
fire.Fire(testit)
import os, sys
import numpy as np
import scann
import argparse
import glob
from multiprocessing import cpu_count
from tqdm import tqdm
from ldm.util import parallel_data_prefetch
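# The three helpers below configure a ScaNN searcher: brute-force scoring, asymmetric hashing (AH) with reordering, and tree-partitioned AH.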
def search_bruteforce(searcher):
return searcher.score_brute_force().build()
def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search):
return searcher.tree(num_leaves=num_leaves,
num_leaves_to_search=num_leaves_to_search,
training_sample_size=partioning_trainsize). \
score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
reorder_k).build()
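# Load precomputed patch embeddings from one or more .npz files under `dpath` into a single dict; the 'embedding' key holds the retrieval database.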
def load_datapool(dpath):
def load_single_file(saved_embeddings):
compressed = np.load(saved_embeddings)
database = {key: compressed[key] for key in compressed.files}
return database
def load_multi_files(data_archive):
database = {key: [] for key in data_archive[0].files}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for key in d.files:
database[key].append(d[key])
return database
print(f'Load saved patch embedding from "{dpath}"')
file_content = glob.glob(os.path.join(dpath, '*.npz'))
if len(file_content) == 1:
data_pool = load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
else:
raise ValueError(f'No npz files found in specified path "{dpath}". Does this directory exist?')
print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
return data_pool
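# Normalize the database embeddings, build a ScaNN searcher whose strategy (brute force / AH / partitioned AH) depends on the pool size, and serialize it to opt.target_path.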
def train_searcher(opt,
metric='dot_product',
partioning_trainsize=None,
reorder_k=None,
# todo tune
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
num_leaves_to_search=None,):
data_pool = load_datapool(opt.database)
k = opt.knn
if not reorder_k:
reorder_k = 2 * k
# normalize
# embeddings =
searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
pool_size = data_pool['embedding'].shape[0]
print(*(['#'] * 100))
print('Initializing scaNN searcher with the following values:')
print(f'k: {k}')
print(f'metric: {metric}')
print(f'reorder_k: {reorder_k}')
print(f'anisotropic_quantization_threshold: {aiq_thld}')
print(f'dims_per_block: {dims_per_block}')
print(*(['#'] * 100))
print('Start training searcher....')
print(f'N samples in pool is {pool_size}')
# this reflects the recommended design choices proposed at
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
if pool_size < 2e4:
print('Using brute force search.')
searcher = search_bruteforce(searcher)
elif 2e4 <= pool_size and pool_size < 1e5:
print('Using asymmetric hashing search and reordering.')
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
else:
print('Using partitioning, asymmetric hashing search and reordering.')
if not partioning_trainsize:
partioning_trainsize = data_pool['embedding'].shape[0] // 10
if not num_leaves:
num_leaves = int(np.sqrt(pool_size))
if not num_leaves_to_search:
num_leaves_to_search = max(num_leaves // 20, 1)
print('Partitioning params:')
print(f'num_leaves: {num_leaves}')
print(f'num_leaves_to_search: {num_leaves_to_search}')
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search)
print('Finished training searcher')
searcher_savedir = opt.target_path
os.makedirs(searcher_savedir, exist_ok=True)
searcher.serialize(searcher_savedir)
print(f'Saved trained searcher under "{searcher_savedir}"')
if __name__ == '__main__':
sys.path.append(os.getcwd())
parser = argparse.ArgumentParser()
parser.add_argument('--database',
'-d',
default='data/rdm/retrieval_databases/openimages',
type=str,
help='path to the folder containing the CLIP features of the database')
parser.add_argument('--target_path',
'-t',
default='data/rdm/searchers/openimages',
type=str,
help='path to the target folder where the searcher shall be stored.')
parser.add_argument('--knn',
'-k',
default=20,
type=int,
help='number of nearest neighbors for which the searcher shall be optimized')
opt, _ = parser.parse_known_args()
train_searcher(opt,)
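# Typical invocation (script name assumed; the paths are the argparse defaults above):
#   python train_searcher.py -d data/rdm/retrieval_databases/openimages -t data/rdm/searchers/openimages -k 20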
import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from imwatermark import WatermarkEncoder
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
# load safety model
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
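# Split an iterable into consecutive tuples of length `size` (the last chunk may be shorter).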
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return model
def put_watermark(img, wm_encoder=None):
if wm_encoder is not None:
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
img = wm_encoder.encode(img, 'dwtDct')
img = Image.fromarray(img[:, :, ::-1])
return img
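# Load the replacement image used when the safety checker flags a sample; falls back to returning the input array unchanged if the asset cannot be loaded.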
def load_replacement(x):
try:
hwc = x.shape
y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
y = (np.array(y)/255.0).astype(x.dtype)
assert y.shape == x.shape
return y
except Exception:
return x
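# Run the safety checker on a batch of HWC float images in [0, 1] and swap flagged samples for the replacement image.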
def check_safety(x_image):
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
assert x_checked_image.shape[0] == len(has_nsfw_concept)
for i in range(len(has_nsfw_concept)):
if has_nsfw_concept[i]:
x_checked_image[i] = load_replacement(x_checked_image[i])
return x_checked_image, has_nsfw_concept
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action='store_true',
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--plms",
action='store_true',
help="use plms sampling",
)
parser.add_argument(
"--dpm_solver",
action='store_true',
help="use dpm_solver sampling",
)
parser.add_argument(
"--laion400m",
action='store_true',
help="uses the LAION400M model",
)
parser.add_argument(
"--fixed_code",
action='store_true',
help="if enabled, uses the same starting code across samples ",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
)
parser.add_argument(
"--n_iter",
type=int,
default=2,
help="number of times to repeat sampling over all prompts",
)
parser.add_argument(
"--H",
type=int,
default=512,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=512,
help="image width, in pixel space",
)
parser.add_argument(
"--C",
type=int,
default=4,
help="latent channels",
)
parser.add_argument(
"--f",
type=int,
default=8,
help="downsampling factor",
)
parser.add_argument(
"--n_samples",
type=int,
default=3,
help="how many samples to produce for each given prompt. A.k.a. batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=7.5,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
parser.add_argument(
"--config",
type=str,
default="configs/stable-diffusion/v1-inference.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
default="models/ldm/stable-diffusion-v1/model.ckpt",
help="path to checkpoint of model",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
)
opt = parser.parse_args()
if opt.laion400m:
print("Falling back to LAION 400M model...")
opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
if opt.dpm_solver:
sampler = DPMSolverSampler(model)
elif opt.plms:
sampler = PLMSSampler(model)
else:
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
wm = "StableDiffusionV1"
wm_encoder = WatermarkEncoder()
wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = list(chunk(data, batch_size))
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
start_code = None
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
precision_scope = autocast if opt.precision=="autocast" else nullcontext
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
if not opt.skip_save:
for x_sample in x_checked_image_torch:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(x_sample.astype(np.uint8))
img = put_watermark(img, wm_encoder)
img.save(os.path.join(sample_path, f"{base_count:05}.png"))
base_count += 1
if not opt.skip_grid:
all_samples.append(x_checked_image_torch)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img = put_watermark(img, wm_encoder)
img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
if __name__ == "__main__":
main()
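# Example invocation of the txt2img script above (file name assumed; the prompt is the argparse default):
#   python txt2img.py --prompt "a painting of a virus monster playing guitar" --plms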
from setuptools import setup, find_packages
setup(
name='latent-diffusion',
version='0.0.1',
description='',
packages=find_packages(),
install_requires=[
'torch',
'numpy',
'tqdm',
],
)
import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset
from taming.data.sflckr import SegmentationBase # for examples included in repo
class Examples(SegmentationBase):
def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
super().__init__(data_csv="data/ade20k_examples.txt",
data_root="data/ade20k_images",
segmentation_root="data/ade20k_segmentations",
size=size, random_crop=random_crop,
interpolation=interpolation,
n_labels=151, shift_segmentation=False)
# With semantic map and scene label
class ADE20kBase(Dataset):
def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
self.split = self.get_split()
self.n_labels = 151 # unknown + 150
self.data_csv = {"train": "data/ade20k_train.txt",
"validation": "data/ade20k_test.txt"}[self.split]
self.data_root = "data/ade20k_root"
with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
self.scene_categories = f.read().splitlines()
self.scene_categories = dict(line.split() for line in self.scene_categories)
with open(self.data_csv, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, "images", l)
for l in self.image_paths],
"relative_segmentation_path_": [l.replace(".jpg", ".png")
for l in self.image_paths],
"segmentation_path_": [os.path.join(self.data_root, "annotations",
l.replace(".jpg", ".png"))
for l in self.image_paths],
"scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
for l in self.image_paths],
}
size = None if size is not None and size<=0 else size
self.size = size
if crop_size is None:
self.crop_size = size if size is not None else None
else:
self.crop_size = crop_size
if self.size is not None:
self.interpolation = interpolation
self.interpolation = {
"nearest": cv2.INTER_NEAREST,
"bilinear": cv2.INTER_LINEAR,
"bicubic": cv2.INTER_CUBIC,
"area": cv2.INTER_AREA,
"lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
interpolation=self.interpolation)
self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
interpolation=cv2.INTER_NEAREST)
if crop_size is not None:
self.center_crop = not random_crop
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
else:
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
self.preprocessor = self.cropper
def __len__(self):
return self._length
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
if self.size is not None:
image = self.image_rescaler(image=image)["image"]
segmentation = Image.open(example["segmentation_path_"])
segmentation = np.array(segmentation).astype(np.uint8)
if self.size is not None:
segmentation = self.segmentation_rescaler(image=segmentation)["image"]
if self.size is not None:
processed = self.preprocessor(image=image, mask=segmentation)
else:
processed = {"image": image, "mask": segmentation}
example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
segmentation = processed["mask"]
onehot = np.eye(self.n_labels)[segmentation]
example["segmentation"] = onehot
return example
class ADE20kTrain(ADE20kBase):
# default to random_crop=True
def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
super().__init__(config=config, size=size, random_crop=random_crop,
interpolation=interpolation, crop_size=crop_size)
def get_split(self):
return "train"
class ADE20kValidation(ADE20kBase):
def get_split(self):
return "validation"
if __name__ == "__main__":
dset = ADE20kValidation()
ex = dset[0]
for k in ["image", "scene_category", "segmentation"]:
print(type(ex[k]))
try:
print(ex[k].shape)
except Exception:
print(ex[k])
import json
from itertools import chain
from pathlib import Path
from typing import Iterable, Dict, List, Callable, Any
from collections import defaultdict
from tqdm import tqdm
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, ImageDescription, Category
COCO_PATH_STRUCTURE = {
'train': {
'top_level': '',
'instances_annotations': 'annotations/instances_train2017.json',
'stuff_annotations': 'annotations/stuff_train2017.json',
'files': 'train2017'
},
'validation': {
'top_level': '',
'instances_annotations': 'annotations/instances_val2017.json',
'stuff_annotations': 'annotations/stuff_val2017.json',
'files': 'val2017'
}
}
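# Map image id (as a string) to an ImageDescription built from the COCO 'images' entries.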
def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
return {
str(img['id']): ImageDescription(
id=img['id'],
license=img.get('license'),
file_name=img['file_name'],
coco_url=img['coco_url'],
original_size=(img['width'], img['height']),
date_captured=img.get('date_captured'),
flickr_url=img.get('flickr_url')
)
for img in description_json
}
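# Map category id (as a string) to Category; any category named 'other' is skipped.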
def load_categories(category_json: Iterable) -> Dict[str, Category]:
return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
for cat in category_json if cat['name'] != 'other'}
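# Group annotations by image id, converting each bbox to relative (x, y, w, h) coordinates; annotations whose category has no assigned number are skipped.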
def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
annotations = defaultdict(list)
total = sum(len(a) for a in annotations_json)
for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
image_id = str(ann['image_id'])
if image_id not in image_descriptions:
raise ValueError(f'image_id [{image_id}] has no image description.')
category_id = ann['category_id']
try:
category_no = category_no_for_id(str(category_id))
except KeyError:
continue
width, height = image_descriptions[image_id].original_size
bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
annotations[image_id].append(
Annotation(
id=ann['id'],
area=bbox[2]*bbox[3], # use bbox area
is_group_of=ann['iscrowd'],
image_id=ann['image_id'],
bbox=bbox,
category_id=str(category_id),
category_no=category_no
)
)
return dict(annotations)
class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
"""
@param data_path: is the path to the following folder structure:
coco/
├── annotations
│ ├── instances_train2017.json
│ ├── instances_val2017.json
│ ├── stuff_train2017.json
│ └── stuff_val2017.json
├── train2017
│ ├── 000000000009.jpg
│ ├── 000000000025.jpg
│ └── ...
├── val2017
│ ├── 000000000139.jpg
│ ├── 000000000285.jpg
│ └── ...
@param: split: one of 'train' or 'validation'
@param: desired image size (give square images)
"""
super().__init__(**kwargs)
self.use_things = use_things
self.use_stuff = use_stuff
with open(self.paths['instances_annotations']) as f:
inst_data_json = json.load(f)
with open(self.paths['stuff_annotations']) as f:
stuff_data_json = json.load(f)
category_jsons = []
annotation_jsons = []
if self.use_things:
category_jsons.append(inst_data_json['categories'])
annotation_jsons.append(inst_data_json['annotations'])
if self.use_stuff:
category_jsons.append(stuff_data_json['categories'])
annotation_jsons.append(stuff_data_json['annotations'])
self.categories = load_categories(chain(*category_jsons))
self.filter_categories()
self.setup_category_id_and_number()
self.image_descriptions = load_image_descriptions(inst_data_json['images'])
annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
self.annotations = self.filter_object_number(annotations, self.min_object_area,
self.min_objects_per_image, self.max_objects_per_image)
self.image_ids = list(self.annotations.keys())
self.clean_up_annotations_and_image_descriptions()
def get_path_structure(self) -> Dict[str, str]:
if self.split not in COCO_PATH_STRUCTURE:
raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
return COCO_PATH_STRUCTURE[self.split]
def get_image_path(self, image_id: str) -> Path:
return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
def get_image_description(self, image_id: str) -> Dict[str, Any]:
# noinspection PyProtectedMember
return self.image_descriptions[image_id]._asdict()
from pathlib import Path
from typing import Optional, List, Callable, Dict, Any, Union
import warnings
import PIL.Image as pil_image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import load_object_from_string
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
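# Shared base class for the COCO and Open Images object-annotation datasets: resolves the on-disk path structure, sets up image transforms, category bookkeeping and the conditional builders.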
class AnnotatedObjectsDataset(Dataset):
def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
no_object_classes: Optional[int] = None):
self.data_path = data_path
self.split = split
self.keys = keys
self.target_image_size = target_image_size
self.min_object_area = min_object_area
self.min_objects_per_image = min_objects_per_image
self.max_objects_per_image = max_objects_per_image
self.crop_method = crop_method
self.random_flip = random_flip
self.no_tokens = no_tokens
self.use_group_parameter = use_group_parameter
self.encode_crop = encode_crop
self.annotations = None
self.image_descriptions = None
self.categories = None
self.category_ids = None
self.category_number = None
self.image_ids = None
self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
self.paths = self.build_paths(self.data_path)
self._conditional_builders = None
self.category_allow_list = None
if category_allow_list_target:
allow_list = load_object_from_string(category_allow_list_target)
self.category_allow_list = {name for name, _ in allow_list}
self.category_mapping = {}
if category_mapping_target:
self.category_mapping = load_object_from_string(category_mapping_target)
self.no_object_classes = no_object_classes
def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
top_level = Path(top_level)
sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
for path in sub_paths.values():
if not path.exists():
raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
return sub_paths
@staticmethod
def load_image_from_disk(path: Path) -> Image:
return pil_image.open(path).convert('RGB')
@staticmethod
def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
transform_functions = []
if crop_method == 'none':
transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
elif crop_method == 'center':
transform_functions.extend([
transforms.Resize(target_image_size),
CenterCropReturnCoordinates(target_image_size)
])
elif crop_method == 'random-1d':
transform_functions.extend([
transforms.Resize(target_image_size),
RandomCrop1dReturnCoordinates(target_image_size)
])
elif crop_method == 'random-2d':
transform_functions.extend([
Random2dCropReturnCoordinates(target_image_size),
transforms.Resize(target_image_size)
])
elif crop_method is None:
return None
else:
raise ValueError(f'Received invalid crop method [{crop_method}].')
if random_flip:
transform_functions.append(RandomHorizontalFlipReturn())
transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
return transform_functions
def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
crop_bbox = None
flipped = None
for t in self.transform_functions:
if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
crop_bbox, x = t(x)
elif isinstance(t, RandomHorizontalFlipReturn):
flipped, x = t(x)
else:
x = t(x)
return crop_bbox, flipped, x
@property
def no_classes(self) -> int:
return self.no_object_classes if self.no_object_classes else len(self.categories)
@property
def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
# cannot set this up in init because no_classes is only known after loading data in init of superclass
if self._conditional_builders is None:
self._conditional_builders = {
'objects_center_points': ObjectsCenterPointsConditionalBuilder(
self.no_classes,
self.max_objects_per_image,
self.no_tokens,
self.encode_crop,
self.use_group_parameter,
getattr(self, 'use_additional_parameters', False)
),
'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
self.no_classes,
self.max_objects_per_image,
self.no_tokens,
self.encode_crop,
self.use_group_parameter,
getattr(self, 'use_additional_parameters', False)
)
}
return self._conditional_builders
def filter_categories(self) -> None:
if self.category_allow_list:
self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
if self.category_mapping:
self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
def setup_category_id_and_number(self) -> None:
self.category_ids = list(self.categories.keys())
self.category_ids.sort()
if '/m/01s55n' in self.category_ids:
self.category_ids.remove('/m/01s55n')
self.category_ids.append('/m/01s55n')
self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
if self.category_allow_list is not None and self.category_mapping is None \
and len(self.category_ids) != len(self.category_allow_list):
warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
'Make sure all names in category_allow_list exist.')
def clean_up_annotations_and_image_descriptions(self) -> None:
image_id_set = set(self.image_ids)
self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
@staticmethod
def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
filtered = {}
for image_id, annotations in all_annotations.items():
annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
filtered[image_id] = annotations_with_min_area
return filtered
def __len__(self):
return len(self.image_ids)
def __getitem__(self, n: int) -> Dict[str, Any]:
image_id = self.get_image_id(n)
sample = self.get_image_description(image_id)
sample['annotations'] = self.get_annotation(image_id)
if 'image' in self.keys:
sample['image_path'] = str(self.get_image_path(image_id))
sample['image'] = self.load_image_from_disk(sample['image_path'])
sample['image'] = convert_pil_to_tensor(sample['image'])
sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
sample['image'] = sample['image'].permute(1, 2, 0)
for conditional, builder in self.conditional_builders.items():
if conditional in self.keys:
sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
if self.keys:
# only return specified keys
sample = {key: sample[key] for key in self.keys}
return sample
def get_image_id(self, no: int) -> str:
return self.image_ids[no]
def get_annotation(self, image_id: str) -> str:
return self.annotations[image_id]
def get_textual_label_for_category_id(self, category_id: str) -> str:
return self.categories[category_id].name
def get_textual_label_for_category_no(self, category_no: int) -> str:
return self.categories[self.get_category_id(category_no)].name
def get_category_number(self, category_id: str) -> int:
return self.category_number[category_id]
def get_category_id(self, category_no: int) -> str:
return self.category_ids[category_no]
def get_image_description(self, image_id: str) -> Dict[str, Any]:
raise NotImplementedError()
def get_path_structure(self):
raise NotImplementedError
def get_image_path(self, image_id: str) -> Path:
raise NotImplementedError
from collections import defaultdict
from csv import DictReader, reader as TupleReader
from pathlib import Path
from typing import Dict, List, Any
import warnings
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, Category
from tqdm import tqdm
OPEN_IMAGES_STRUCTURE = {
'train': {
'top_level': '',
'class_descriptions': 'class-descriptions-boxable.csv',
'annotations': 'oidv6-train-annotations-bbox.csv',
'file_list': 'train-images-boxable.csv',
'files': 'train'
},
'validation': {
'top_level': '',
'class_descriptions': 'class-descriptions-boxable.csv',
'annotations': 'validation-annotations-bbox.csv',
'file_list': 'validation-images.csv',
'files': 'validation'
},
'test': {
'top_level': '',
'class_descriptions': 'class-descriptions-boxable.csv',
'annotations': 'test-annotations-bbox.csv',
'file_list': 'test-images.csv',
'files': 'test'
}
}
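# Stream the Open Images bbox CSV and collect annotations per image id; coordinates are already relative, so width and height come directly from XMax-XMin and YMax-YMin.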
def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
annotations: Dict[str, List[Annotation]] = defaultdict(list)
with open(descriptor_path) as file:
reader = DictReader(file)
for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
width = float(row['XMax']) - float(row['XMin'])
height = float(row['YMax']) - float(row['YMin'])
area = width * height
category_id = row['LabelName']
if category_id in category_mapping:
category_id = category_mapping[category_id]
if area >= min_object_area and category_id in category_no_for_id:
annotations[row['ImageID']].append(
Annotation(
id=i,
image_id=row['ImageID'],
source=row['Source'],
category_id=category_id,
category_no=category_no_for_id[category_id],
confidence=float(row['Confidence']),
bbox=(float(row['XMin']), float(row['YMin']), width, height),
area=area,
is_occluded=bool(int(row['IsOccluded'])),
is_truncated=bool(int(row['IsTruncated'])),
is_group_of=bool(int(row['IsGroupOf'])),
is_depiction=bool(int(row['IsDepiction'])),
is_inside=bool(int(row['IsInside']))
)
)
if 'train' in str(descriptor_path) and i < 14000000:
warnings.warn(f'Running with a subset of Open Images. Train dataset has length [{len(annotations)}].')
return dict(annotations)
def load_image_ids(csv_path: Path) -> List[str]:
with open(csv_path) as file:
reader = DictReader(file)
return [row['image_name'] for row in reader]
def load_categories(csv_path: Path) -> Dict[str, Category]:
with open(csv_path) as file:
reader = TupleReader(file)
return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
def __init__(self, use_additional_parameters: bool, **kwargs):
"""
@param data_path: is the path to the following folder structure:
open_images/
│ oidv6-train-annotations-bbox.csv
├── class-descriptions-boxable.csv
├── oidv6-train-annotations-bbox.csv
├── test
│ ├── 000026e7ee790996.jpg
│ ├── 000062a39995e348.jpg
│ └── ...
├── test-annotations-bbox.csv
├── test-images.csv
├── train
│ ├── 000002b66c9c498e.jpg
│ ├── 000002b97e5471a0.jpg
│ └── ...
├── train-images-boxable.csv
├── validation
│ ├── 0001eeaf4aed83f9.jpg
│ ├── 0004886b7d043cfd.jpg
│ └── ...
├── validation-annotations-bbox.csv
└── validation-images.csv
@param: split: one of 'train', 'validation' or 'test'
@param: desired image size (returns square images)
"""
super().__init__(**kwargs)
self.use_additional_parameters = use_additional_parameters
self.categories = load_categories(self.paths['class_descriptions'])
self.filter_categories()
self.setup_category_id_and_number()
self.image_descriptions = {}
annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
self.category_number)
self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
self.max_objects_per_image)
self.image_ids = list(self.annotations.keys())
self.clean_up_annotations_and_image_descriptions()
def get_path_structure(self) -> Dict[str, str]:
if self.split not in OPEN_IMAGES_STRUCTURE:
raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
return OPEN_IMAGES_STRUCTURE[self.split]
def get_image_path(self, image_id: str) -> Path:
return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
def get_image_description(self, image_id: str) -> Dict[str, Any]:
image_path = self.get_image_path(image_id)
return {'file_path': str(image_path), 'file_name': image_path.name}
import bisect
import numpy as np
import albumentations
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
class ConcatDatasetWithIndex(ConcatDataset):
"""Modified from original pytorch code to return dataset idx"""
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx], dataset_idx
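# Minimal image dataset over a list of file paths: rescales the smaller side to `size`, center- or random-crops to a square and maps pixels to [-1, 1].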
class ImagePaths(Dataset):
def __init__(self, paths, size=None, random_crop=False, labels=None):
self.size = size
self.random_crop = random_crop
self.labels = dict() if labels is None else labels
self.labels["file_path_"] = paths
self._length = len(paths)
if self.size is not None and self.size > 0:
self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
if not self.random_crop:
self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
else:
self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
else:
self.preprocessor = lambda **kwargs: kwargs
def __len__(self):
return self._length
def preprocess_image(self, image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
image = self.preprocessor(image=image)["image"]
image = (image/127.5 - 1.0).astype(np.float32)
return image
def __getitem__(self, i):
example = dict()
example["image"] = self.preprocess_image(self.labels["file_path_"][i])
for k in self.labels:
example[k] = self.labels[k][i]
return example
class NumpyPaths(ImagePaths):
def preprocess_image(self, image_path):
image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
image = np.transpose(image, (1,2,0))
image = Image.fromarray(image, mode="RGB")
image = np.array(image).astype(np.uint8)
image = self.preprocessor(image=image)["image"]
image = (image/127.5 - 1.0).astype(np.float32)
return image
import os
import json
import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from taming.data.sflckr import SegmentationBase # for examples included in repo
class Examples(SegmentationBase):
def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
super().__init__(data_csv="data/coco_examples.txt",
data_root="data/coco_images",
segmentation_root="data/coco_segmentations",
size=size, random_crop=random_crop,
interpolation=interpolation,
n_labels=183, shift_segmentation=True)
class CocoBase(Dataset):
"""needed for (image, caption, segmentation) pairs"""
def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
crop_size=None, force_no_crop=False, given_files=None):
self.split = self.get_split()
self.size = size
if crop_size is None:
self.crop_size = size
else:
self.crop_size = crop_size
self.onehot = onehot_segmentation # return segmentation as rgb or one hot
self.stuffthing = use_stuffthing # include thing in segmentation
if self.onehot and not self.stuffthing:
raise NotImplementedError("One hot mode is only supported for the "
"stuffthings version because labels are stored "
"a bit differently.")
data_json = datajson
with open(data_json) as json_file:
self.json_data = json.load(json_file)
self.img_id_to_captions = dict()
self.img_id_to_filepath = dict()
self.img_id_to_segmentation_filepath = dict()
assert data_json.split("/")[-1] in ["captions_train2017.json",
"captions_val2017.json"]
if self.stuffthing:
self.segmentation_prefix = (
"data/cocostuffthings/val2017" if
data_json.endswith("captions_val2017.json") else
"data/cocostuffthings/train2017")
else:
self.segmentation_prefix = (
"data/coco/annotations/stuff_val2017_pixelmaps" if
data_json.endswith("captions_val2017.json") else
"data/coco/annotations/stuff_train2017_pixelmaps")
imagedirs = self.json_data["images"]
self.labels = {"image_ids": list()}
for imgdir in tqdm(imagedirs, desc="ImgToPath"):
self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
self.img_id_to_captions[imgdir["id"]] = list()
pngfilename = imgdir["file_name"].replace("jpg", "png")
self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
self.segmentation_prefix, pngfilename)
if given_files is not None:
if pngfilename in given_files:
self.labels["image_ids"].append(imgdir["id"])
else:
self.labels["image_ids"].append(imgdir["id"])
capdirs = self.json_data["annotations"]
for capdir in tqdm(capdirs, desc="ImgToCaptions"):
# there are on average 5 captions per image
self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
if self.split=="validation":
self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
else:
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
self.preprocessor = albumentations.Compose(
[self.rescaler, self.cropper],
additional_targets={"segmentation": "image"})
if force_no_crop:
self.rescaler = albumentations.Resize(height=self.size, width=self.size)
self.preprocessor = albumentations.Compose(
[self.rescaler],
additional_targets={"segmentation": "image"})
def __len__(self):
return len(self.labels["image_ids"])
def preprocess_image(self, image_path, segmentation_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
segmentation = Image.open(segmentation_path)
if not self.onehot and not segmentation.mode == "RGB":
segmentation = segmentation.convert("RGB")
segmentation = np.array(segmentation).astype(np.uint8)
if self.onehot:
assert self.stuffthing
# stored in caffe format: unlabeled==255. stuff and thing from
# 0-181. to be compatible with the labels in
# https://github.com/nightrome/cocostuff/blob/master/labels.txt
# we shift stuffthing one to the right and put unlabeled in zero
# as long as segmentation is uint8 shifting to right handles the
# latter too
assert segmentation.dtype == np.uint8
segmentation = segmentation + 1
processed = self.preprocessor(image=image, segmentation=segmentation)
image, segmentation = processed["image"], processed["segmentation"]
image = (image / 127.5 - 1.0).astype(np.float32)
if self.onehot:
assert segmentation.dtype == np.uint8
# make it one hot
n_labels = 183
flatseg = np.ravel(segmentation)
onehot = np.zeros((flatseg.size, n_labels), dtype=bool)  # np.bool was removed in NumPy 1.24; the builtin bool is equivalent
onehot[np.arange(flatseg.size), flatseg] = True
onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
segmentation = onehot
else:
segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
return image, segmentation
def __getitem__(self, i):
img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
image, segmentation = self.preprocess_image(img_path, seg_path)
captions = self.img_id_to_captions[self.labels["image_ids"][i]]
# randomly draw one of all available captions per image
caption = captions[np.random.randint(0, len(captions))]
example = {"image": image,
"caption": [str(caption[0])],
"segmentation": segmentation,
"img_path": img_path,
"seg_path": seg_path,
"filename_": img_path.split(os.sep)[-1]
}
return example
class CocoImagesAndCaptionsTrain(CocoBase):
"""returns a pair of (image, caption)"""
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
super().__init__(size=size,
dataroot="data/coco/train2017",
datajson="data/coco/annotations/captions_train2017.json",
onehot_segmentation=onehot_segmentation,
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
def get_split(self):
return "train"
class CocoImagesAndCaptionsValidation(CocoBase):
"""returns a pair of (image, caption)"""
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
given_files=None):
super().__init__(size=size,
dataroot="data/coco/val2017",
datajson="data/coco/annotations/captions_val2017.json",
onehot_segmentation=onehot_segmentation,
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
given_files=given_files)
def get_split(self):
return "validation"
from itertools import cycle
from typing import List, Tuple, Callable, Optional
from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from more_itertools.recipes import grouper
from taming.data.image_transforms import convert_pil_to_tensor
from torch import LongTensor, Tensor
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
pad_list, get_plot_font_size, absolute_bbox
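# Conditional builder that encodes each object as a (class token, top-left token, bottom-right token) triple; coordinate tokenization is inherited from the center-points builder.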
class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
@property
def object_descriptor_length(self) -> int:
return 3
def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
object_triples = [
(self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
for ann in annotations
]
empty_triple = (self.none, self.none, self.none)
object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
return object_triples
def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
conditional_list = conditional.tolist()
crop_coordinates = None
if self.encode_crop:
crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
conditional_list = conditional_list[:-2]
object_triples = grouper(conditional_list, 3)
assert conditional.shape[0] == self.embedding_dim
return [
(object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
for object_triple in object_triples if object_triple[0] != self.none
], crop_coordinates
def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
plot = pil_image.new('RGB', figure_size, WHITE)
draw = pil_img_draw.Draw(plot)
font = ImageFont.truetype(
"/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
size=get_plot_font_size(font_size, figure_size)
)
width, height = plot.size
description, crop_coordinates = self.inverse_build(conditional)
for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
annotation = self.representation_to_annotation(representation)
class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
bbox = absolute_bbox(bbox, width, height)
draw.rectangle(bbox, outline=color, width=line_width)
draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
if crop_coordinates is not None:
draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
return convert_pil_to_tensor(plot) / 127.5 - 1.
import math
import random
import warnings
from itertools import cycle
from typing import List, Optional, Tuple, Callable
from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
from more_itertools.recipes import grouper
from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
absolute_bbox, rescale_annotations
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.image_transforms import convert_pil_to_tensor
from torch import LongTensor, Tensor
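# Encodes each annotation as a (class/modifier token, center-point token) pair on a sqrt(no_tokens) x sqrt(no_tokens) grid; two extra tokens describe the crop when encode_crop is set.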
class ObjectsCenterPointsConditionalBuilder:
def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
use_group_parameter: bool, use_additional_parameters: bool):
self.no_object_classes = no_object_classes
self.no_max_objects = no_max_objects
self.no_tokens = no_tokens
self.encode_crop = encode_crop
self.no_sections = int(math.sqrt(self.no_tokens))
self.use_group_parameter = use_group_parameter
self.use_additional_parameters = use_additional_parameters
@property
def none(self) -> int:
return self.no_tokens - 1
@property
def object_descriptor_length(self) -> int:
return 2
@property
def embedding_dim(self) -> int:
extra_length = 2 if self.encode_crop else 0
return self.no_max_objects * self.object_descriptor_length + extra_length
def tokenize_coordinates(self, x: float, y: float) -> int:
"""
Express 2d coordinates with one number.
Example: assume self.no_tokens = 16, then no_sections = 4:
0 0 0 0
0 0 # 0
0 0 0 0
0 0 0 x
Then the # position corresponds to token 6, the x position to token 15.
@param x: float in [0, 1]
@param y: float in [0, 1]
@return: discrete tokenized coordinate
"""
x_discrete = int(round(x * (self.no_sections - 1)))
y_discrete = int(round(y * (self.no_sections - 1)))
return y_discrete * self.no_sections + x_discrete
def coordinates_from_token(self, token: int) -> (float, float):
x = token % self.no_sections
y = token // self.no_sections
return x / (self.no_sections - 1), y / (self.no_sections - 1)
def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
x0, y0 = self.coordinates_from_token(token1)
x1, y1 = self.coordinates_from_token(token2)
return x0, y0, x1 - x0, y1 - y0
def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
return self.tokenize_coordinates(bbox[0], bbox[1]), \
self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
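# Note (illustrative): token_pair_from_bbox encodes the top-left and bottom-right corners of an
# (x0, y0, w, h) box as two grid tokens, and bbox_from_token_pair inverts this up to the
# 1 / (no_sections - 1) quantisation of the coordinate grid.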
def inverse_build(self, conditional: LongTensor) \
-> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
conditional_list = conditional.tolist()
crop_coordinates = None
if self.encode_crop:
crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
conditional_list = conditional_list[:-2]
table_of_content = grouper(conditional_list, self.object_descriptor_length)
assert conditional.shape[0] == self.embedding_dim
return [
(object_tuple[0], self.coordinates_from_token(object_tuple[1]))
for object_tuple in table_of_content if object_tuple[0] != self.none
], crop_coordinates
def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
plot = pil_image.new('RGB', figure_size, WHITE)
draw = pil_img_draw.Draw(plot)
circle_size = get_circle_size(figure_size)
font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
size=get_plot_font_size(font_size, figure_size))
width, height = plot.size
description, crop_coordinates = self.inverse_build(conditional)
for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
x_abs, y_abs = x * width, y * height
ann = self.representation_to_annotation(representation)
label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
draw.ellipse(ellipse_bbox, fill=color, width=0)
draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
if crop_coordinates is not None:
draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
return convert_pil_to_tensor(plot) / 127.5 - 1.
def object_representation(self, annotation: Annotation) -> int:
modifier = 0
if self.use_group_parameter:
modifier |= 1 * (annotation.is_group_of is True)
if self.use_additional_parameters:
modifier |= 2 * (annotation.is_occluded is True)
modifier |= 4 * (annotation.is_depiction is True)
modifier |= 8 * (annotation.is_inside is True)
return annotation.category_no + self.no_object_classes * modifier
def representation_to_annotation(self, representation: int) -> Annotation:
category_no = representation % self.no_object_classes
modifier = representation // self.no_object_classes
# noinspection PyTypeChecker
return Annotation(
area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
category_no=category_no,
is_group_of=bool((modifier & 1) * self.use_group_parameter),
is_occluded=bool((modifier & 2) * self.use_additional_parameters),
is_depiction=bool((modifier & 4) * self.use_additional_parameters),
is_inside=bool((modifier & 8) * self.use_additional_parameters)
)
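# Worked example (illustrative): with no_object_classes=10 and use_additional_parameters=True,
# an annotation of category 3 that is occluded and a depiction gets modifier 2 | 4 = 6 and
# representation 3 + 10 * 6 = 63; representation_to_annotation(63) recovers category_no=3
# with is_occluded and is_depiction set.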
def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
return list(self.token_pair_from_bbox(crop_coordinates))
def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
object_tuples = [
(self.object_representation(a),
self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
for a in annotations
]
empty_tuple = (self.none, self.none)
object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
return object_tuples
def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
-> LongTensor:
if len(annotations) == 0:
warnings.warn('Did not receive any annotations.')
if len(annotations) > self.no_max_objects:
warnings.warn('Received more annotations than allowed.')
annotations = annotations[:self.no_max_objects]
if not crop_coordinates:
crop_coordinates = FULL_CROP
random.shuffle(annotations)
annotations = filter_annotations(annotations, crop_coordinates)
if self.encode_crop:
annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
if horizontal_flip:
crop_coordinates = horizontally_flip_bbox(crop_coordinates)
extra = self._crop_encoder(crop_coordinates)
else:
annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
extra = []
object_tuples = self._make_object_descriptors(annotations)
flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
assert len(flattened) == self.embedding_dim
assert all(0 <= value < self.no_tokens for value in flattened)
return LongTensor(flattened)
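# Hedged usage sketch (not part of the upstream file): round-trips one toy annotation through
# build()/inverse_build() under assumed parameters; it presumes the taming package is importable.
if __name__ == "__main__":
    _builder = ObjectsCenterPointsConditionalBuilder(
        no_object_classes=10, no_max_objects=3, no_tokens=256,
        encode_crop=True, use_group_parameter=False, use_additional_parameters=False)
    _annotation = Annotation(area=0.05, image_id="img0", bbox=(0.1, 0.2, 0.3, 0.4),
                             category_no=2, category_id="cat2")
    _conditional = _builder.build([_annotation])
    print(_conditional.shape)                    # torch.Size([8]): 3 objects * 2 tokens + 2 crop tokens
    print(_builder.inverse_build(_conditional))  # -> [(2, (~0.27, 0.4))], crop (0.0, 0.0, 1.0, 1.0)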
import importlib
from typing import List, Any, Tuple, Optional
from taming.data.helper_types import BoundingBox, Annotation
# source: seaborn, color palette tab10
COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
(139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
BLACK = (0, 0, 0)
GRAY_75 = (63, 63, 63)
GRAY_50 = (127, 127, 127)
GRAY_25 = (191, 191, 191)
WHITE = (255, 255, 255)
FULL_CROP = (0., 0., 1., 1.)
def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
"""
Give intersection area of two rectangles.
@param rectangle1: (x0, y0, w, h) of first rectangle
@param rectangle2: (x0, y0, w, h) of second rectangle
"""
rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
return x_overlap * y_overlap
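# Worked example (illustrative): the unit-normalised boxes (0., 0., .5, .5) and (.25, .25, .5, .5)
# overlap on [.25, .5] along both axes, so intersection_area(...) == 0.25 * 0.25 == 0.0625.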
def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
bbox = relative_bbox
bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
List[Annotation]:
def clamp(x: float):
return max(min(x, 1.), 0.)
def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
if flip:
x0 = 1 - (x0 + w)
return x0, y0, w, h
return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
sl = slice(1) if short else slice(None)
string = ''
if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
return string
if annotation.is_group_of:
string += 'group'[sl] + ','
if annotation.is_occluded:
string += 'occluded'[sl] + ','
if annotation.is_depiction:
string += 'depiction'[sl] + ','
if annotation.is_inside:
string += 'inside'[sl]
return '(' + string.strip(",") + ')'
def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
if font_size is None:
font_size = 10
if max(figure_size) >= 256:
font_size = 12
if max(figure_size) >= 512:
font_size = 15
return font_size
def get_circle_size(figure_size: Tuple[int, int]) -> int:
circle_size = 2
if max(figure_size) >= 256:
circle_size = 3
if max(figure_size) >= 512:
circle_size = 4
return circle_size
def load_object_from_string(object_string: str) -> Any:
"""
Source: https://stackoverflow.com/a/10773699
"""
module_name, class_name = object_string.rsplit(".", 1)
return getattr(importlib.import_module(module_name), class_name)
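# Illustrative example (standard library only): load_object_from_string("collections.OrderedDict")
# imports the collections module and returns its OrderedDict attribute.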
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
class CustomBase(Dataset):
def __init__(self, *args, **kwargs):
super().__init__()
self.data = None
def __len__(self):
return len(self.data)
def __getitem__(self, i):
example = self.data[i]
return example
class CustomTrain(CustomBase):
def __init__(self, size, training_images_list_file):
super().__init__()
with open(training_images_list_file, "r") as f:
paths = f.read().splitlines()
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
class CustomTest(CustomBase):
def __init__(self, size, test_images_list_file):
super().__init__()
with open(test_images_list_file, "r") as f:
paths = f.read().splitlines()
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
class FacesBase(Dataset):
def __init__(self, *args, **kwargs):
super().__init__()
self.data = None
self.keys = None
def __len__(self):
return len(self.data)
def __getitem__(self, i):
example = self.data[i]
ex = {}
if self.keys is not None:
for k in self.keys:
ex[k] = example[k]
else:
ex = example
return ex
class CelebAHQTrain(FacesBase):
def __init__(self, size, keys=None):
super().__init__()
root = "data/celebahq"
with open("data/celebahqtrain.txt", "r") as f:
relpaths = f.read().splitlines()
paths = [os.path.join(root, relpath) for relpath in relpaths]
self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
self.keys = keys
class CelebAHQValidation(FacesBase):
def __init__(self, size, keys=None):
super().__init__()
root = "data/celebahq"
with open("data/celebahqvalidation.txt", "r") as f:
relpaths = f.read().splitlines()
paths = [os.path.join(root, relpath) for relpath in relpaths]
self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
self.keys = keys
class FFHQTrain(FacesBase):
def __init__(self, size, keys=None):
super().__init__()
root = "data/ffhq"
with open("data/ffhqtrain.txt", "r") as f:
relpaths = f.read().splitlines()
paths = [os.path.join(root, relpath) for relpath in relpaths]
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
self.keys = keys
class FFHQValidation(FacesBase):
def __init__(self, size, keys=None):
super().__init__()
root = "data/ffhq"
with open("data/ffhqvalidation.txt", "r") as f:
relpaths = f.read().splitlines()
paths = [os.path.join(root, relpath) for relpath in relpaths]
self.data = ImagePaths(paths=paths, size=size, random_crop=False)
self.keys = keys
class FacesHQTrain(Dataset):
# CelebAHQ [0] + FFHQ [1]
def __init__(self, size, keys=None, crop_size=None, coord=False):
d1 = CelebAHQTrain(size=size, keys=keys)
d2 = FFHQTrain(size=size, keys=keys)
self.data = ConcatDatasetWithIndex([d1, d2])
self.coord = coord
if crop_size is not None:
self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
if self.coord:
self.cropper = albumentations.Compose([self.cropper],
additional_targets={"coord": "image"})
def __len__(self):
return len(self.data)
def __getitem__(self, i):
ex, y = self.data[i]
if hasattr(self, "cropper"):
if not self.coord:
out = self.cropper(image=ex["image"])
ex["image"] = out["image"]
else:
h,w,_ = ex["image"].shape
coord = np.arange(h*w).reshape(h,w,1)/(h*w)
out = self.cropper(image=ex["image"], coord=coord)
ex["image"] = out["image"]
ex["coord"] = out["coord"]
ex["class"] = y
return ex
class FacesHQValidation(Dataset):
# CelebAHQ [0] + FFHQ [1]
def __init__(self, size, keys=None, crop_size=None, coord=False):
d1 = CelebAHQValidation(size=size, keys=keys)
d2 = FFHQValidation(size=size, keys=keys)
self.data = ConcatDatasetWithIndex([d1, d2])
self.coord = coord
if crop_size is not None:
self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
if self.coord:
self.cropper = albumentations.Compose([self.cropper],
additional_targets={"coord": "image"})
def __len__(self):
return len(self.data)
def __getitem__(self, i):
ex, y = self.data[i]
if hasattr(self, "cropper"):
if not self.coord:
out = self.cropper(image=ex["image"])
ex["image"] = out["image"]
else:
h,w,_ = ex["image"].shape
coord = np.arange(h*w).reshape(h,w,1)/(h*w)
out = self.cropper(image=ex["image"], coord=coord)
ex["image"] = out["image"]
ex["coord"] = out["coord"]
ex["class"] = y
return ex
from typing import Dict, Tuple, Optional, NamedTuple, Union
from PIL.Image import Image as pil_image
from torch import Tensor
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
Image = Union[Tensor, pil_image]
BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
CropMethodType = Literal['none', 'random', 'center', 'random-2d']
SplitType = Literal['train', 'validation', 'test']
class ImageDescription(NamedTuple):
id: int
file_name: str
original_size: Tuple[int, int] # w, h
url: Optional[str] = None
license: Optional[int] = None
coco_url: Optional[str] = None
date_captured: Optional[str] = None
flickr_url: Optional[str] = None
flickr_id: Optional[str] = None
coco_id: Optional[str] = None
class Category(NamedTuple):
id: str
super_category: Optional[str]
name: str
class Annotation(NamedTuple):
area: float
image_id: str
bbox: BoundingBox
category_no: int
category_id: str
id: Optional[int] = None
source: Optional[str] = None
confidence: Optional[float] = None
is_group_of: Optional[bool] = None
is_truncated: Optional[bool] = None
is_occluded: Optional[bool] = None
is_depiction: Optional[bool] = None
is_inside: Optional[bool] = None
segmentation: Optional[Dict] = None
import random
import warnings
from typing import Union
import torch
from torch import Tensor
from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
try:
    from torchvision.transforms.functional import get_image_size  # public helper in newer torchvision
except ImportError:
    # older torchvision (the version this code targets) only ships the private helper
    from torchvision.transforms.functional import _get_image_size as get_image_size
from taming.data.helper_types import BoundingBox, Image
pil_to_tensor = PILToTensor()
def convert_pil_to_tensor(image: Image) -> Tensor:
with warnings.catch_warnings():
# to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
warnings.simplefilter("ignore")
return pil_to_tensor(image)
class RandomCrop1dReturnCoordinates(RandomCrop):
def forward(self, img: Image) -> (BoundingBox, Image):
"""
Additionally to cropping, returns the relative coordinates of the crop bounding box.
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
Bounding box: x0, y0, w, h
PIL Image or Tensor: Cropped image.
Based on:
torchvision.transforms.RandomCrop, torchvision 1.7.0
"""
if self.padding is not None:
img = F.pad(img, self.padding, self.fill, self.padding_mode)
width, height = get_image_size(img)
# pad the width if needed
if self.pad_if_needed and width < self.size[1]:
padding = [self.size[1] - width, 0]
img = F.pad(img, padding, self.fill, self.padding_mode)
# pad the height if needed
if self.pad_if_needed and height < self.size[0]:
padding = [0, self.size[0] - height]
img = F.pad(img, padding, self.fill, self.padding_mode)
i, j, h, w = self.get_params(img, self.size)
bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
return bbox, F.crop(img, i, j, h, w)
class Random2dCropReturnCoordinates(torch.nn.Module):
"""
Additionally to cropping, returns the relative coordinates of the crop bounding box.
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
Bounding box: x0, y0, w, h
PIL Image or Tensor: Cropped image.
Based on:
torchvision.transforms.RandomCrop, torchvision 1.7.0
"""
def __init__(self, min_size: int):
super().__init__()
self.min_size = min_size
def forward(self, img: Image) -> (BoundingBox, Image):
width, height = get_image_size(img)
max_size = min(width, height)
if max_size <= self.min_size:
size = max_size
else:
size = random.randint(self.min_size, max_size)
top = random.randint(0, height - size)
left = random.randint(0, width - size)
bbox = left / width, top / height, size / width, size / height
return bbox, F.crop(img, top, left, size, size)
class CenterCropReturnCoordinates(CenterCrop):
@staticmethod
def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
if width > height:
w = height / width
h = 1.0
x0 = 0.5 - w / 2
y0 = 0.
else:
w = 1.0
h = width / height
x0 = 0.
y0 = 0.5 - h / 2
return x0, y0, w, h
def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
"""
Additionally to cropping, returns the relative coordinates of the crop bounding box.
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
Bounding box: x0, y0, w, h
PIL Image or Tensor: Cropped image.
Based on:
torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
"""
width, height = get_image_size(img)
return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
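# Worked example (illustrative): get_bbox_of_center_crop(width=400, height=300) returns
# (0.125, 0.0, 0.75, 1.0) -- the largest centred square of a 400x300 image spans 75% of the
# width and the full height.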
class RandomHorizontalFlipReturn(RandomHorizontalFlip):
def forward(self, img: Image) -> (bool, Image):
"""
Additionally to flipping, returns a boolean whether it was flipped or not.
Args:
img (PIL Image or Tensor): Image to be flipped.
Returns:
flipped: whether the image was flipped or not
PIL Image or Tensor: Randomly flipped image.
Based on:
torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
"""
if torch.rand(1) < self.p:
return True, F.hflip(img)
return False, img
import os, tarfile, glob, shutil
import yaml
import numpy as np
from tqdm import tqdm
from PIL import Image
import albumentations
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from taming.data.base import ImagePaths
from taming.util import download, retrieve
import taming.data.utils as bdu
def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
synsets = []
with open(path_to_yaml) as f:
di2s = yaml.load(f, Loader=yaml.FullLoader)  # pass an explicit Loader; recent PyYAML versions warn or error without one
for idx in indices:
synsets.append(str(di2s[idx]))
print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
return synsets
def str_to_indices(string):
"""Expects a string in the format '32-123, 256, 280-321'"""
assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
subs = string.split(",")
indices = []
for sub in subs:
subsubs = sub.split("-")
assert len(subsubs) > 0
if len(subsubs) == 1:
indices.append(int(subsubs[0]))
else:
rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
indices.extend(rang)
return sorted(indices)
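# Worked example (illustrative): str_to_indices("5-8, 12") returns [5, 6, 7, 12]; the upper bound
# of an "a-b" range is exclusive because it is expanded with range(a, b).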
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not isinstance(self.config, dict):
self.config = OmegaConf.to_container(self.config)
self._prepare()
self._prepare_synset_to_human()
self._prepare_idx_to_synset()
self._load()
def __len__(self):
return len(self.data)
def __getitem__(self, i):
return self.data[i]
def _prepare(self):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
files = []
for rpath in relpaths:
syn = rpath.split("/")[0]
if syn in synsets:
files.append(rpath)
return files
else:
return relpaths
def _prepare_synset_to_human(self):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
download(URL, self.idx2syn)
def _load(self):
with open(self.txt_filelist, "r") as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
self.class_labels = [class_dict[s] for s in self.synsets]
with open(self.human_dict, "r") as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets]
labels = {
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
}
self.data = ImagePaths(self.abspaths,
labels=labels,
size=retrieve(self.config, "size", default=0),
random_crop=self.random_crop)
class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
FILES = [
"ILSVRC2012_img_train.tar",
]
SIZES = [
147897477120,
]
def _prepare(self):
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 1281167
if not bdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
bdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
FILES = [
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
]
SIZES = [
6744924160,
1950000,
]
def _prepare(self):
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 50000
if not bdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders")
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
for k, v in synset_dict.items():
src = os.path.join(datadir, k)
dst = os.path.join(datadir, v)
shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
bdu.mark_prepared(self.root)
def get_preprocessor(size=None, random_crop=False, additional_targets=None,
crop_size=None):
if size is not None and size > 0:
transforms = list()
rescaler = albumentations.SmallestMaxSize(max_size = size)
transforms.append(rescaler)
if not random_crop:
cropper = albumentations.CenterCrop(height=size,width=size)
transforms.append(cropper)
else:
cropper = albumentations.RandomCrop(height=size,width=size)
transforms.append(cropper)
flipper = albumentations.HorizontalFlip()
transforms.append(flipper)
preprocessor = albumentations.Compose(transforms,
additional_targets=additional_targets)
elif crop_size is not None and crop_size > 0:
if not random_crop:
cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
else:
cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
transforms = [cropper]
preprocessor = albumentations.Compose(transforms,
additional_targets=additional_targets)
else:
preprocessor = lambda **kwargs: kwargs
return preprocessor
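# Hedged usage note: get_preprocessor(size=256, random_crop=True) composes
# SmallestMaxSize(256) -> RandomCrop(256) -> HorizontalFlip and is called as
# preprocessor(image=img)["image"] on an HWC array. With size=None and crop_size set, only the
# crop is applied; with neither, the identity lambda simply passes the kwargs through.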
def rgba_to_depth(x):
assert x.dtype == np.uint8
assert len(x.shape) == 3 and x.shape[2] == 4
y = x.copy()
y.dtype = np.float32
y = y.reshape(x.shape[:2])
return np.ascontiguousarray(y)
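# Note (inferred from the code and its use in preprocess_depth): the depth maps appear to be
# stored with the float32 bytes of each pixel packed into the four RGBA channels of a PNG,
# so reassigning y.dtype view-casts every group of 4 uint8 values back into one float32 depth
# value per pixel.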
class BaseWithDepth(Dataset):
DEFAULT_DEPTH_ROOT="data/imagenet_depth"
def __init__(self, config=None, size=None, random_crop=False,
crop_size=None, root=None):
self.config = config
self.base_dset = self.get_base_dset()
self.preprocessor = get_preprocessor(
size=size,
crop_size=crop_size,
random_crop=random_crop,
additional_targets={"depth": "image"})
self.crop_size = crop_size
if self.crop_size is not None:
self.rescaler = albumentations.Compose(
[albumentations.SmallestMaxSize(max_size = self.crop_size)],
additional_targets={"depth": "image"})
if root is not None:
self.DEFAULT_DEPTH_ROOT = root
def __len__(self):
return len(self.base_dset)
def preprocess_depth(self, path):
rgba = np.array(Image.open(path))
depth = rgba_to_depth(rgba)
depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
depth = 2.0*depth-1.0
return depth
def __getitem__(self, i):
e = self.base_dset[i]
e["depth"] = self.preprocess_depth(self.get_depth_path(e))
# up if necessary
h,w,c = e["image"].shape
if self.crop_size and min(h,w) < self.crop_size:
# have to upscale to be able to crop - this just uses bilinear
out = self.rescaler(image=e["image"], depth=e["depth"])
e["image"] = out["image"]
e["depth"] = out["depth"]
transformed = self.preprocessor(image=e["image"], depth=e["depth"])
e["image"] = transformed["image"]
e["depth"] = transformed["depth"]
return e
class ImageNetTrainWithDepth(BaseWithDepth):
# default to random_crop=True
def __init__(self, random_crop=True, sub_indices=None, **kwargs):
self.sub_indices = sub_indices
super().__init__(random_crop=random_crop, **kwargs)
def get_base_dset(self):
if self.sub_indices is None:
return ImageNetTrain()
else:
return ImageNetTrain({"sub_indices": self.sub_indices})
def get_depth_path(self, e):
fid = os.path.splitext(e["relpath"])[0]+".png"
fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
return fid
class ImageNetValidationWithDepth(BaseWithDepth):
def __init__(self, sub_indices=None, **kwargs):
self.sub_indices = sub_indices
super().__init__(**kwargs)
def get_base_dset(self):
if self.sub_indices is None:
return ImageNetValidation()
else:
return ImageNetValidation({"sub_indices": self.sub_indices})
def get_depth_path(self, e):
fid = os.path.splitext(e["relpath"])[0]+".png"
fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
return fid
class RINTrainWithDepth(ImageNetTrainWithDepth):
def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
super().__init__(config=config, size=size, random_crop=random_crop,
sub_indices=sub_indices, crop_size=crop_size)
class RINValidationWithDepth(ImageNetValidationWithDepth):
def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
super().__init__(config=config, size=size, random_crop=random_crop,
sub_indices=sub_indices, crop_size=crop_size)
class DRINExamples(Dataset):
def __init__(self):
self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
with open("data/drin_examples.txt", "r") as f:
relpaths = f.read().splitlines()
self.image_paths = [os.path.join("data/drin_images",
relpath) for relpath in relpaths]
self.depth_paths = [os.path.join("data/drin_depth",
relpath.replace(".JPEG", ".png")) for relpath in relpaths]
def __len__(self):
return len(self.image_paths)
def preprocess_image(self, image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
image = self.preprocessor(image=image)["image"]
image = (image/127.5 - 1.0).astype(np.float32)
return image
def preprocess_depth(self, path):
rgba = np.array(Image.open(path))
depth = rgba_to_depth(rgba)
depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
depth = 2.0*depth-1.0
return depth
def __getitem__(self, i):
e = dict()
e["image"] = self.preprocess_image(self.image_paths[i])
e["depth"] = self.preprocess_depth(self.depth_paths[i])
transformed = self.preprocessor(image=e["image"], depth=e["depth"])
e["image"] = transformed["image"]
e["depth"] = transformed["depth"]
return e
def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
if factor is None or factor==1:
return x
dtype = x.dtype
assert dtype in [np.float32, np.float64]
assert x.min() >= -1
assert x.max() <= 1
keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
"bicubic": Image.BICUBIC}[keepmode]
lr = (x+1.0)*127.5
lr = lr.clip(0,255).astype(np.uint8)
lr = Image.fromarray(lr)
h, w, _ = x.shape
nh = h//factor
nw = w//factor
assert nh > 0 and nw > 0, (nh, nw)
lr = lr.resize((nw,nh), Image.BICUBIC)
if keepshapes:
lr = lr.resize((w,h), keepmode)
lr = np.array(lr)/127.5-1.0
lr = lr.astype(dtype)
return lr
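# Hedged example: imscale(x, 4, keepshapes=True) on a 256x256 array in [-1, 1] downsamples it
# bicubically to 64x64 and resizes it back to 256x256 with keepmode, yielding a degraded image
# of the original shape that ImageNetScale uses as the low-resolution "lr" conditioning.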
class ImageNetScale(Dataset):
def __init__(self, size=None, crop_size=None, random_crop=False,
up_factor=None, hr_factor=None, keep_mode="bicubic"):
self.base = self.get_base()
self.size = size
self.crop_size = crop_size if crop_size is not None else self.size
self.random_crop = random_crop
self.up_factor = up_factor
self.hr_factor = hr_factor
self.keep_mode = keep_mode
transforms = list()
if self.size is not None and self.size > 0:
rescaler = albumentations.SmallestMaxSize(max_size = self.size)
self.rescaler = rescaler
transforms.append(rescaler)
if self.crop_size is not None and self.crop_size > 0:
if len(transforms) == 0:
self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
if not self.random_crop:
cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
else:
cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
transforms.append(cropper)
if len(transforms) > 0:
if self.up_factor is not None:
additional_targets = {"lr": "image"}
else:
additional_targets = None
self.preprocessor = albumentations.Compose(transforms,
additional_targets=additional_targets)
else:
self.preprocessor = lambda **kwargs: kwargs
def __len__(self):
return len(self.base)
def __getitem__(self, i):
example = self.base[i]
image = example["image"]
# adjust resolution
image = imscale(image, self.hr_factor, keepshapes=False)
h,w,c = image.shape
if self.crop_size and min(h,w) < self.crop_size:
# have to upscale to be able to crop - this just uses bilinear
image = self.rescaler(image=image)["image"]
if self.up_factor is None:
image = self.preprocessor(image=image)["image"]
example["image"] = image
else:
lr = imscale(image, self.up_factor, keepshapes=True,
keepmode=self.keep_mode)
out = self.preprocessor(image=image, lr=lr)
example["image"] = out["image"]
example["lr"] = out["lr"]
return example
class ImageNetScaleTrain(ImageNetScale):
def __init__(self, random_crop=True, **kwargs):
super().__init__(random_crop=random_crop, **kwargs)
def get_base(self):
return ImageNetTrain()
class ImageNetScaleValidation(ImageNetScale):
def get_base(self):
return ImageNetValidation()
from skimage.feature import canny
from skimage.color import rgb2gray
class ImageNetEdges(ImageNetScale):
def __init__(self, up_factor=1, **kwargs):
super().__init__(up_factor=1, **kwargs)
def __getitem__(self, i):
example = self.base[i]
image = example["image"]
h,w,c = image.shape
if self.crop_size and min(h,w) < self.crop_size:
# have to upscale to be able to crop - this just uses bilinear
image = self.rescaler(image=image)["image"]
lr = canny(rgb2gray(image), sigma=2)
lr = lr.astype(np.float32)
lr = lr[:,:,None][:,:,[0,0,0]]
out = self.preprocessor(image=image, lr=lr)
example["image"] = out["image"]
example["lr"] = out["lr"]
return example
class ImageNetEdgesTrain(ImageNetEdges):
def __init__(self, random_crop=True, **kwargs):
super().__init__(random_crop=random_crop, **kwargs)
def get_base(self):
return ImageNetTrain()
class ImageNetEdgesValidation(ImageNetEdges):
def get_base(self):
return ImageNetValidation()