Commit 01db7703 authored by mashun1's avatar mashun1

taming-transformer

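# ---- Model metadata (key/value properties file) ----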
# Unique model identifier
modelCode = 544
# Model name
modelName = taming-transformers_pytorch
# Model description
modelDescription = taming-transformers can be used to generate and complete images.
# Application scenarios
appScenario = training,inference,AIGC,media,research,education
# Framework type
frameType = pytorch
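# ---- File boundary (likely requirements.txt): Python package dependencies ----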
# torch==1.10.1
# torchvision==0.11.1
numpy>=1.19
albumentations==0.4.3
opencv-python==4.1.2.30
pudb==2019.2
imageio==2.9.0
imageio-ffmpeg==0.4.2
pytorch-lightning==1.0.8
omegaconf==2.0.0
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0
more-itertools>=8.0.0
transformers
packaging==21.3
pillow
timm
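# ---- File boundary (likely scripts/extract_depth.py): extracts MiDaS depth maps for the ImageNet splits ----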
import os
import torch
import numpy as np
from tqdm import trange
from PIL import Image
def get_state(gpu):
import torch
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
if gpu:
midas.cuda()
midas.eval()
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.small_transform
state = {"model": midas,
"transform": transform}
return state
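# The two helpers below losslessly pack a float32 depth map into a 4-channel
# uint8 image (and back) by reinterpreting the underlying bytes, so the depth
# predictions can be stored as ordinary PNG files.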
def depth_to_rgba(x):
assert x.dtype == np.float32
assert len(x.shape) == 2
y = x.copy()
y.dtype = np.uint8
y = y.reshape(x.shape+(4,))
return np.ascontiguousarray(y)
def rgba_to_depth(x):
assert x.dtype == np.uint8
assert len(x.shape) == 3 and x.shape[2] == 4
y = x.copy()
y.dtype = np.float32
y = y.reshape(x.shape[:2])
return np.ascontiguousarray(y)
def run(x, state):
model = state["model"]
transform = state["transform"]
hw = x.shape[:2]
with torch.no_grad():
prediction = model(transform((x + 1.0) * 127.5).cuda())
prediction = torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=hw,
mode="bicubic",
align_corners=False,
).squeeze()
output = prediction.cpu().numpy()
return output
def get_filename(relpath, level=-2):
# save class folder structure and filename:
fn = relpath.split(os.sep)[level:]
folder = fn[-2]
file = fn[-1].split('.')[0]
return folder, file
def save_depth(dataset, path, debug=False):
os.makedirs(path)
    N = len(dataset)
if debug:
N = 10
state = get_state(gpu=True)
for idx in trange(N, desc="Data"):
ex = dataset[idx]
image, relpath = ex["image"], ex["relpath"]
folder, filename = get_filename(relpath)
# prepare
folderabspath = os.path.join(path, folder)
os.makedirs(folderabspath, exist_ok=True)
savepath = os.path.join(folderabspath, filename)
# run model
xout = run(image, state)
I = depth_to_rgba(xout)
Image.fromarray(I).save("{}.png".format(savepath))
if __name__ == "__main__":
from taming.data.imagenet import ImageNetTrain, ImageNetValidation
out = "data/imagenet_depth"
if not os.path.exists(out):
print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
"(be prepared that the output size will be larger than ImageNet itself).")
exit(1)
# go
dset = ImageNetValidation()
abspath = os.path.join(out, "val")
if os.path.exists(abspath):
print("{} exists - not doing anything.".format(abspath))
else:
print("preparing {}".format(abspath))
save_depth(dset, abspath)
print("done with validation split")
dset = ImageNetTrain()
abspath = os.path.join(out, "train")
if os.path.exists(abspath):
print("{} exists - not doing anything.".format(abspath))
else:
print("preparing {}".format(abspath))
save_depth(dset, abspath)
print("done with train split")
print("done done.")
import sys, os
import numpy as np
import scipy
import torch
import torch.nn as nn
from scipy import ndimage
from tqdm import tqdm, trange
from PIL import Image
import torch.hub
import torchvision
import torch.nn.functional as F
# download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
# https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
# and put the path here
CKPT_PATH = "TODO"
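# A minimal sketch (not part of the original script) for fetching the checkpoint,
# assuming the release URL above is still reachable:
#   import torch.hub
#   torch.hub.download_url_to_file(
#       "https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/"
#       "deeplabv2_resnet101_msc-cocostuff164k-100000.pth",
#       "deeplabv2_resnet101_msc-cocostuff164k-100000.pth")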
rescale = lambda x: (x + 1.) / 2.
def rescale_bgr(x):
x = (x+1)*127.5
x = torch.flip(x, dims=[0])
return x
class COCOStuffSegmenter(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.n_labels = 182
model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
ckpt_path = CKPT_PATH
model.load_state_dict(torch.load(ckpt_path))
self.model = model
normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
self.image_transform = torchvision.transforms.Compose([
torchvision.transforms.Lambda(lambda image: torch.stack(
[normalize(rescale_bgr(x)) for x in image]))
])
def forward(self, x, upsample=None):
x = self._pre_process(x)
x = self.model(x)
if upsample is not None:
x = torch.nn.functional.upsample_bilinear(x, size=upsample)
return x
def _pre_process(self, x):
x = self.image_transform(x)
return x
@property
def mean(self):
# bgr
return [104.008, 116.669, 122.675]
@property
def std(self):
return [1.0, 1.0, 1.0]
@property
def input_size(self):
return [3, 224, 224]
def run_model(img, model):
model = model.eval()
with torch.no_grad():
segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
return segmentation.detach().cpu()
def get_input(batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
return x.float()
def save_segmentation(segmentation, path):
# --> class label to uint8, save as png
os.makedirs(os.path.dirname(path), exist_ok=True)
assert len(segmentation.shape)==4
assert segmentation.shape[0]==1
for seg in segmentation:
seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
seg = Image.fromarray(seg)
seg.save(path)
def iterate_dataset(dataloader, destpath, model):
os.makedirs(destpath, exist_ok=True)
num_processed = 0
for i, batch in tqdm(enumerate(dataloader), desc="Data"):
try:
img = get_input(batch, "image")
img = img.cuda()
seg = run_model(img, model)
path = batch["relative_file_path_"][0]
path = os.path.splitext(path)[0]
path = os.path.join(destpath, path + ".png")
save_segmentation(seg, path)
num_processed += 1
except Exception as e:
print(e)
print("but anyhow..")
print("Processed {} files. Bye.".format(num_processed))
from taming.data.sflckr import Examples
from torch.utils.data import DataLoader
if __name__ == "__main__":
dest = sys.argv[1]
batchsize = 1
print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
model = COCOStuffSegmenter({}).cuda()
print("Instantiated model.")
dataset = Examples()
dloader = DataLoader(dataset, batch_size=batchsize)
iterate_dataset(dataloader=dloader, destpath=dest, model=model)
print("done.")
import torch
import sys
if __name__ == "__main__":
inpath = sys.argv[1]
outpath = sys.argv[2]
submodel = "cond_stage_model"
if len(sys.argv) > 3:
submodel = sys.argv[3]
print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
sd = torch.load(inpath, map_location="cpu")
new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
for k,v in sd["state_dict"].items()
if k.startswith("cond_stage_model"))}
torch.save(new_sd, outpath)
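# ---- File boundary (likely scripts/make_samples.py): samples images from a trained conditional transformer ----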
import argparse, os, sys, glob, math, time
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange
def save_image(x, path):
c,h,w = x.shape
assert c==3
x = ((x.detach().cpu().numpy().transpose(1,2,0)+1.0)*127.5).clip(0,255).astype(np.uint8)
Image.fromarray(x).save(path)
@torch.no_grad()
def run_conditional(model, dsets, outdir, top_k, temperature, batch_size=1):
if len(dsets.datasets) > 1:
split = sorted(dsets.datasets.keys())[0]
dset = dsets.datasets[split]
else:
dset = next(iter(dsets.datasets.values()))
print("Dataset: ", dset.__class__.__name__)
for start_idx in trange(0,len(dset)-batch_size+1,batch_size):
indices = list(range(start_idx, start_idx+batch_size))
example = default_collate([dset[i] for i in indices])
x = model.get_input("image", example).to(model.device)
for i in range(x.shape[0]):
save_image(x[i], os.path.join(outdir, "originals",
"{:06}.png".format(indices[i])))
cond_key = model.cond_stage_key
c = model.get_input(cond_key, example).to(model.device)
scale_factor = 1.0
quant_z, z_indices = model.encode_to_z(x)
quant_c, c_indices = model.encode_to_c(c)
cshape = quant_z.shape
xrec = model.first_stage_model.decode(quant_z)
for i in range(xrec.shape[0]):
save_image(xrec[i], os.path.join(outdir, "reconstructions",
"{:06}.png".format(indices[i])))
if cond_key == "segmentation":
# get image from segmentation mask
num_classes = c.shape[1]
c = torch.argmax(c, dim=1, keepdim=True)
c = torch.nn.functional.one_hot(c, num_classes=num_classes)
c = c.squeeze(1).permute(0, 3, 1, 2).float()
c = model.cond_stage_model.to_rgb(c)
idx = z_indices
half_sample = False
if half_sample:
start = idx.shape[1]//2
else:
start = 0
idx[:,start:] = 0
idx = idx.reshape(cshape[0],cshape[2],cshape[3])
start_i = start//cshape[3]
start_j = start %cshape[3]
cidx = c_indices
cidx = cidx.reshape(quant_c.shape[0],quant_c.shape[2],quant_c.shape[3])
sample = True
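        # Sliding-window autoregressive sampling: for each target position (i, j),
        # pick a 16x16 window of codebook indices such that (i, j) sits at local
        # offset (8, 8) where possible (shifted near the image borders), prepend the
        # matching window of conditioning indices, run the transformer, and sample
        # the new index from the logits at (local_i, local_j).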
for i in range(start_i,cshape[2]-0):
if i <= 8:
local_i = i
elif cshape[2]-i < 8:
local_i = 16-(cshape[2]-i)
else:
local_i = 8
for j in range(start_j,cshape[3]-0):
if j <= 8:
local_j = j
elif cshape[3]-j < 8:
local_j = 16-(cshape[3]-j)
else:
local_j = 8
i_start = i-local_i
i_end = i_start+16
j_start = j-local_j
j_end = j_start+16
patch = idx[:,i_start:i_end,j_start:j_end]
patch = patch.reshape(patch.shape[0],-1)
cpatch = cidx[:, i_start:i_end, j_start:j_end]
cpatch = cpatch.reshape(cpatch.shape[0], -1)
patch = torch.cat((cpatch, patch), dim=1)
logits,_ = model.transformer(patch[:,:-1])
logits = logits[:, -256:, :]
logits = logits.reshape(cshape[0],16,16,-1)
logits = logits[:,local_i,local_j,:]
logits = logits/temperature
if top_k is not None:
logits = model.top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = torch.nn.functional.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
if sample:
ix = torch.multinomial(probs, num_samples=1)
else:
_, ix = torch.topk(probs, k=1, dim=-1)
idx[:,i,j] = ix
xsample = model.decode_to_img(idx[:,:cshape[2],:cshape[3]], cshape)
for i in range(xsample.shape[0]):
save_image(xsample[i], os.path.join(outdir, "samples",
"{:06}.png".format(indices[i])))
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"-r",
"--resume",
type=str,
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
"-c",
"--config",
nargs="?",
metavar="single_config.yaml",
help="path to single config. If specified, base configs will be ignored "
"(except for the last one if left unspecified).",
const=True,
default="",
)
parser.add_argument(
"--ignore_base_data",
action="store_true",
help="Ignore data specification from base configs. Useful if you want "
"to specify a custom datasets on the command line.",
)
parser.add_argument(
"--outdir",
required=True,
type=str,
help="Where to write outputs to.",
)
parser.add_argument(
"--top_k",
type=int,
default=100,
help="Sample from among top-k predictions.",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="Sampling temperature.",
)
return parser
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
if "ckpt_path" in config.params:
print("Deleting the restore-ckpt path from the config...")
config.params.ckpt_path = None
if "downsample_cond_size" in config.params:
print("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
config.params.downsample_cond_size = -1
config.params["downsample_cond_factor"] = 0.5
try:
if "ckpt_path" in config.params.first_stage_config.params:
config.params.first_stage_config.params.ckpt_path = None
print("Deleting the first-stage restore-ckpt path from the config...")
if "ckpt_path" in config.params.cond_stage_config.params:
config.params.cond_stage_config.params.ckpt_path = None
print("Deleting the cond-stage restore-ckpt path from the config...")
except:
pass
model = instantiate_from_config(config)
if sd is not None:
missing, unexpected = model.load_state_dict(sd, strict=False)
print(f"Missing Keys in State Dict: {missing}")
print(f"Unexpected Keys in State Dict: {unexpected}")
if gpu:
model.cuda()
if eval_mode:
model.eval()
return {"model": model}
def get_data(config):
# get data
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()
return data
def load_model_and_dset(config, ckpt, gpu, eval_mode):
# get data
dsets = get_data(config) # calls data.config ...
# now load the specified checkpoint
if ckpt:
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model,
pl_sd["state_dict"],
gpu=gpu,
eval_mode=eval_mode)["model"]
return dsets, model, global_step
if __name__ == "__main__":
sys.path.append(os.getcwd())
parser = get_parser()
opt, unknown = parser.parse_known_args()
ckpt = None
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
try:
idx = len(paths)-paths[::-1].index("logs")+1
except ValueError:
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
print(f"logdir:{logdir}")
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
opt.base = base_configs+opt.base
if opt.config:
if type(opt.config) == str:
opt.base = [opt.config]
else:
opt.base = [opt.base[-1]]
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
if opt.ignore_base_data:
for config in configs:
if hasattr(config, "data"): del config["data"]
config = OmegaConf.merge(*configs, cli)
gpu = True
eval_mode = True
show_config = False
if show_config:
print(OmegaConf.to_container(config))
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
print(f"Global step: {global_step}")
outdir = os.path.join(opt.outdir, "{:06}_{}_{}".format(global_step,
opt.top_k,
opt.temperature))
os.makedirs(outdir, exist_ok=True)
print("Writing samples to ", outdir)
for k in ["originals", "reconstructions", "samples"]:
os.makedirs(os.path.join(outdir, k), exist_ok=True)
run_conditional(model, dsets, outdir, opt.top_k, opt.temperature)
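# ---- File boundary (likely scripts/make_scene_samples.py): samples scene images of arbitrary resolution from object-layout conditionings ----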
import glob
import os
import sys
from itertools import product
from pathlib import Path
from typing import Literal, List, Optional, Tuple
import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision.utils import save_image
from tqdm import tqdm
from scripts.make_samples import get_parser, load_model_and_dset
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.models.cond_transformer import Net2NetTransformer
seed_everything(42424242)
device: Literal['cuda', 'cpu'] = 'cuda'
first_stage_factor = 16
trained_on_res = 256
def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
assert 0 <= coord < coord_max
coord_desired_center = (coord_window - 1) // 2
return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)
def get_crop_coordinates(x: int, y: int) -> BoundingBox:
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
w = first_stage_factor / WIDTH
h = first_stage_factor / HEIGHT
return x0, y0, w, h
def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
x0 = _helper(predict_x, WIDTH, first_stage_factor)
y0 = _helper(predict_y, HEIGHT, first_stage_factor)
no_images = z_indices.shape[0]
cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
cut_out_2 = z_indices[:, predict_y, x0:predict_x]
return torch.cat((cut_out_1, cut_out_2), dim=1)
@torch.no_grad()
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
conditional_builder: ObjectsCenterPointsConditionalBuilder, no_samples: int,
temperature: float, top_k: int) -> Tensor:
x_max, y_max = desired_z_shape[1], desired_z_shape[0]
annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]
recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
if not recompute_conditional:
crop_coordinates = get_crop_coordinates(0, 0)
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
z_indices = torch.zeros((no_samples, 0), device=device).long()
output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
sample=True, top_k=top_k)
else:
output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
crop_coordinates = get_crop_coordinates(predict_x, predict_y)
z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
conditional_indices = conditional_builder.build(annotations, crop_coordinates)
c_indices = conditional_indices.to(device).repeat(no_samples, 1)
new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
output_indices[:, predict_y, predict_x] = new_index[:, -1]
z_shape = (
no_samples,
model.first_stage_model.quantize.e_dim, # codebook embed_dim
desired_z_shape[0], # z_height
desired_z_shape[1] # z_width
)
x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
x_sample = x_sample.to('cpu')
plotter = conditional_builder.plot
figure_size = (x_sample.shape[2], x_sample.shape[3])
scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
return torch.cat((x_sample, plot.unsqueeze(0)))
def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
if not resolution_str.count(',') == 1:
raise ValueError("Give resolution as in 'height,width'")
res_h, res_w = resolution_str.split(',')
res_h = max(int(res_h), trained_on_res)
res_w = max(int(res_w), trained_on_res)
z_h = int(round(res_h/first_stage_factor))
z_w = int(round(res_w/first_stage_factor))
return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)
def add_arg_to_parser(parser):
parser.add_argument(
"-R",
"--resolution",
type=str,
default='256,256',
help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
)
parser.add_argument(
"-C",
"--conditional",
type=str,
default='objects_bbox',
help=f"objects_bbox or objects_center_points",
)
parser.add_argument(
"-N",
"--n_samples_per_layout",
type=int,
default=4,
help=f"how many samples to generate per layout",
)
return parser
if __name__ == "__main__":
sys.path.append(os.getcwd())
parser = get_parser()
parser = add_arg_to_parser(parser)
opt, unknown = parser.parse_known_args()
ckpt = None
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split("/")
try:
idx = len(paths)-paths[::-1].index("logs")+1
except ValueError:
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
logdir = "/".join(paths[:idx])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
print(f"logdir:{logdir}")
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
opt.base = base_configs+opt.base
if opt.config:
if type(opt.config) == str:
opt.base = [opt.config]
else:
opt.base = [opt.base[-1]]
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
if opt.ignore_base_data:
for config in configs:
if hasattr(config, "data"):
del config["data"]
config = OmegaConf.merge(*configs, cli)
desired_z_shape, desired_resolution = get_resolution(opt.resolution)
conditional = opt.conditional
print(ckpt)
gpu = True
eval_mode = True
show_config = False
if show_config:
print(OmegaConf.to_container(config))
dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
print(f"Global step: {global_step}")
data_loader = dsets.val_dataloader()
print(dsets.datasets["validation"].conditional_builders)
conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]
outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
outdir.mkdir(exist_ok=True, parents=True)
print("Writing samples to ", outdir)
p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
for batch_no, batch in p_bar_1:
save_img: Optional[Tensor] = None
for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
opt.n_samples_per_layout, opt.temperature, opt.top_k)
            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), nrow=opt.n_samples_per_layout + 1)
(Two additional files in this commit are too large to be shown in the diff view.)
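# ---- File boundary (likely setup.py): minimal package definition, typically installed with `pip install -e .` ----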
from setuptools import setup, find_packages
setup(
name='taming-transformers',
version='0.0.1',
description='Taming Transformers for High-Resolution Image Synthesis',
packages=find_packages(),
install_requires=[
'torch',
'numpy',
'tqdm',
],
)
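# ---- File boundary (likely taming/data/ade20k.py): ADE20K segmentation dataset wrappers ----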
import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset
from taming.data.sflckr import SegmentationBase # for examples included in repo
class Examples(SegmentationBase):
def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
super().__init__(data_csv="data/ade20k_examples.txt",
data_root="data/ade20k_images",
segmentation_root="data/ade20k_segmentations",
size=size, random_crop=random_crop,
interpolation=interpolation,
n_labels=151, shift_segmentation=False)
# With semantic map and scene label
class ADE20kBase(Dataset):
def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
self.split = self.get_split()
self.n_labels = 151 # unknown + 150
self.data_csv = {"train": "data/ade20k_train.txt",
"validation": "data/ade20k_test.txt"}[self.split]
self.data_root = "data/ade20k_root"
with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
self.scene_categories = f.read().splitlines()
self.scene_categories = dict(line.split() for line in self.scene_categories)
with open(self.data_csv, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, "images", l)
for l in self.image_paths],
"relative_segmentation_path_": [l.replace(".jpg", ".png")
for l in self.image_paths],
"segmentation_path_": [os.path.join(self.data_root, "annotations",
l.replace(".jpg", ".png"))
for l in self.image_paths],
"scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
for l in self.image_paths],
}
size = None if size is not None and size<=0 else size
self.size = size
if crop_size is None:
self.crop_size = size if size is not None else None
else:
self.crop_size = crop_size
if self.size is not None:
self.interpolation = interpolation
self.interpolation = {
"nearest": cv2.INTER_NEAREST,
"bilinear": cv2.INTER_LINEAR,
"bicubic": cv2.INTER_CUBIC,
"area": cv2.INTER_AREA,
"lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
interpolation=self.interpolation)
self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
interpolation=cv2.INTER_NEAREST)
if crop_size is not None:
self.center_crop = not random_crop
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
else:
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
self.preprocessor = self.cropper
def __len__(self):
return self._length
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
if self.size is not None:
image = self.image_rescaler(image=image)["image"]
segmentation = Image.open(example["segmentation_path_"])
segmentation = np.array(segmentation).astype(np.uint8)
if self.size is not None:
segmentation = self.segmentation_rescaler(image=segmentation)["image"]
if self.size is not None:
processed = self.preprocessor(image=image, mask=segmentation)
else:
processed = {"image": image, "mask": segmentation}
example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
segmentation = processed["mask"]
onehot = np.eye(self.n_labels)[segmentation]
example["segmentation"] = onehot
return example
class ADE20kTrain(ADE20kBase):
# default to random_crop=True
def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
super().__init__(config=config, size=size, random_crop=random_crop,
interpolation=interpolation, crop_size=crop_size)
def get_split(self):
return "train"
class ADE20kValidation(ADE20kBase):
def get_split(self):
return "validation"
if __name__ == "__main__":
dset = ADE20kValidation()
ex = dset[0]
for k in ["image", "scene_category", "segmentation"]:
print(type(ex[k]))
try:
print(ex[k].shape)
except:
print(ex[k])
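# ---- File boundary (likely taming/data/annotated_objects_coco.py): COCO things+stuff object annotations ----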
import json
from itertools import chain
from pathlib import Path
from typing import Iterable, Dict, List, Callable, Any
from collections import defaultdict
from tqdm import tqdm
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, ImageDescription, Category
COCO_PATH_STRUCTURE = {
'train': {
'top_level': '',
'instances_annotations': 'annotations/instances_train2017.json',
'stuff_annotations': 'annotations/stuff_train2017.json',
'files': 'train2017'
},
'validation': {
'top_level': '',
'instances_annotations': 'annotations/instances_val2017.json',
'stuff_annotations': 'annotations/stuff_val2017.json',
'files': 'val2017'
}
}
def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
return {
str(img['id']): ImageDescription(
id=img['id'],
license=img.get('license'),
file_name=img['file_name'],
coco_url=img['coco_url'],
original_size=(img['width'], img['height']),
date_captured=img.get('date_captured'),
flickr_url=img.get('flickr_url')
)
for img in description_json
}
def load_categories(category_json: Iterable) -> Dict[str, Category]:
return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
for cat in category_json if cat['name'] != 'other'}
def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
annotations = defaultdict(list)
total = sum(len(a) for a in annotations_json)
for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
image_id = str(ann['image_id'])
if image_id not in image_descriptions:
raise ValueError(f'image_id [{image_id}] has no image description.')
category_id = ann['category_id']
try:
category_no = category_no_for_id(str(category_id))
except KeyError:
continue
width, height = image_descriptions[image_id].original_size
bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
annotations[image_id].append(
Annotation(
id=ann['id'],
area=bbox[2]*bbox[3], # use bbox area
is_group_of=ann['iscrowd'],
image_id=ann['image_id'],
bbox=bbox,
category_id=str(category_id),
category_no=category_no
)
)
return dict(annotations)
class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
"""
@param data_path: is the path to the following folder structure:
coco/
├── annotations
│ ├── instances_train2017.json
│ ├── instances_val2017.json
│ ├── stuff_train2017.json
│ └── stuff_val2017.json
├── train2017
│ ├── 000000000009.jpg
│ ├── 000000000025.jpg
│ └── ...
├── val2017
│ ├── 000000000139.jpg
│ ├── 000000000285.jpg
│ └── ...
@param: split: one of 'train' or 'validation'
@param: desired image size (give square images)
"""
super().__init__(**kwargs)
self.use_things = use_things
self.use_stuff = use_stuff
with open(self.paths['instances_annotations']) as f:
inst_data_json = json.load(f)
with open(self.paths['stuff_annotations']) as f:
stuff_data_json = json.load(f)
category_jsons = []
annotation_jsons = []
if self.use_things:
category_jsons.append(inst_data_json['categories'])
annotation_jsons.append(inst_data_json['annotations'])
if self.use_stuff:
category_jsons.append(stuff_data_json['categories'])
annotation_jsons.append(stuff_data_json['annotations'])
self.categories = load_categories(chain(*category_jsons))
self.filter_categories()
self.setup_category_id_and_number()
self.image_descriptions = load_image_descriptions(inst_data_json['images'])
annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
self.annotations = self.filter_object_number(annotations, self.min_object_area,
self.min_objects_per_image, self.max_objects_per_image)
self.image_ids = list(self.annotations.keys())
self.clean_up_annotations_and_image_descriptions()
def get_path_structure(self) -> Dict[str, str]:
if self.split not in COCO_PATH_STRUCTURE:
            raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
return COCO_PATH_STRUCTURE[self.split]
def get_image_path(self, image_id: str) -> Path:
return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
def get_image_description(self, image_id: str) -> Dict[str, Any]:
# noinspection PyProtectedMember
return self.image_descriptions[image_id]._asdict()
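# ---- File boundary (likely taming/data/annotated_objects_dataset.py): shared base class for object-annotation datasets ----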
from pathlib import Path
from typing import Optional, List, Callable, Dict, Any, Union
import warnings
import PIL.Image as pil_image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import load_object_from_string
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
class AnnotatedObjectsDataset(Dataset):
def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
no_object_classes: Optional[int] = None):
self.data_path = data_path
self.split = split
self.keys = keys
self.target_image_size = target_image_size
self.min_object_area = min_object_area
self.min_objects_per_image = min_objects_per_image
self.max_objects_per_image = max_objects_per_image
self.crop_method = crop_method
self.random_flip = random_flip
self.no_tokens = no_tokens
self.use_group_parameter = use_group_parameter
self.encode_crop = encode_crop
self.annotations = None
self.image_descriptions = None
self.categories = None
self.category_ids = None
self.category_number = None
self.image_ids = None
self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
self.paths = self.build_paths(self.data_path)
self._conditional_builders = None
self.category_allow_list = None
if category_allow_list_target:
allow_list = load_object_from_string(category_allow_list_target)
self.category_allow_list = {name for name, _ in allow_list}
self.category_mapping = {}
if category_mapping_target:
self.category_mapping = load_object_from_string(category_mapping_target)
self.no_object_classes = no_object_classes
def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
top_level = Path(top_level)
sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
for path in sub_paths.values():
if not path.exists():
raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
return sub_paths
@staticmethod
def load_image_from_disk(path: Path) -> Image:
return pil_image.open(path).convert('RGB')
@staticmethod
def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
transform_functions = []
if crop_method == 'none':
transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
elif crop_method == 'center':
transform_functions.extend([
transforms.Resize(target_image_size),
CenterCropReturnCoordinates(target_image_size)
])
elif crop_method == 'random-1d':
transform_functions.extend([
transforms.Resize(target_image_size),
RandomCrop1dReturnCoordinates(target_image_size)
])
elif crop_method == 'random-2d':
transform_functions.extend([
Random2dCropReturnCoordinates(target_image_size),
transforms.Resize(target_image_size)
])
elif crop_method is None:
return None
else:
raise ValueError(f'Received invalid crop method [{crop_method}].')
if random_flip:
transform_functions.append(RandomHorizontalFlipReturn())
transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
return transform_functions
def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
crop_bbox = None
flipped = None
for t in self.transform_functions:
if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
crop_bbox, x = t(x)
elif isinstance(t, RandomHorizontalFlipReturn):
flipped, x = t(x)
else:
x = t(x)
return crop_bbox, flipped, x
@property
def no_classes(self) -> int:
return self.no_object_classes if self.no_object_classes else len(self.categories)
@property
def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
# cannot set this up in init because no_classes is only known after loading data in init of superclass
if self._conditional_builders is None:
self._conditional_builders = {
'objects_center_points': ObjectsCenterPointsConditionalBuilder(
self.no_classes,
self.max_objects_per_image,
self.no_tokens,
self.encode_crop,
self.use_group_parameter,
getattr(self, 'use_additional_parameters', False)
),
'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
self.no_classes,
self.max_objects_per_image,
self.no_tokens,
self.encode_crop,
self.use_group_parameter,
getattr(self, 'use_additional_parameters', False)
)
}
return self._conditional_builders
def filter_categories(self) -> None:
if self.category_allow_list:
self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
if self.category_mapping:
self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
def setup_category_id_and_number(self) -> None:
self.category_ids = list(self.categories.keys())
self.category_ids.sort()
if '/m/01s55n' in self.category_ids:
self.category_ids.remove('/m/01s55n')
self.category_ids.append('/m/01s55n')
self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
        if self.category_allow_list is not None and not self.category_mapping \
                and len(self.category_ids) != len(self.category_allow_list):
warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
'Make sure all names in category_allow_list exist.')
def clean_up_annotations_and_image_descriptions(self) -> None:
image_id_set = set(self.image_ids)
self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
@staticmethod
def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
filtered = {}
for image_id, annotations in all_annotations.items():
annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
filtered[image_id] = annotations_with_min_area
return filtered
def __len__(self):
return len(self.image_ids)
def __getitem__(self, n: int) -> Dict[str, Any]:
image_id = self.get_image_id(n)
sample = self.get_image_description(image_id)
sample['annotations'] = self.get_annotation(image_id)
if 'image' in self.keys:
sample['image_path'] = str(self.get_image_path(image_id))
sample['image'] = self.load_image_from_disk(sample['image_path'])
sample['image'] = convert_pil_to_tensor(sample['image'])
sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
sample['image'] = sample['image'].permute(1, 2, 0)
for conditional, builder in self.conditional_builders.items():
if conditional in self.keys:
sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
if self.keys:
# only return specified keys
sample = {key: sample[key] for key in self.keys}
return sample
def get_image_id(self, no: int) -> str:
return self.image_ids[no]
    def get_annotation(self, image_id: str) -> List[Annotation]:
return self.annotations[image_id]
def get_textual_label_for_category_id(self, category_id: str) -> str:
return self.categories[category_id].name
def get_textual_label_for_category_no(self, category_no: int) -> str:
return self.categories[self.get_category_id(category_no)].name
def get_category_number(self, category_id: str) -> int:
return self.category_number[category_id]
def get_category_id(self, category_no: int) -> str:
return self.category_ids[category_no]
def get_image_description(self, image_id: str) -> Dict[str, Any]:
raise NotImplementedError()
def get_path_structure(self):
raise NotImplementedError
def get_image_path(self, image_id: str) -> Path:
raise NotImplementedError
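# ---- File boundary (likely taming/data/annotated_objects_open_images.py): Open Images bounding-box annotations ----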
from collections import defaultdict
from csv import DictReader, reader as TupleReader
from pathlib import Path
from typing import Dict, List, Any
import warnings
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.data.helper_types import Annotation, Category
from tqdm import tqdm
OPEN_IMAGES_STRUCTURE = {
'train': {
'top_level': '',
'class_descriptions': 'class-descriptions-boxable.csv',
'annotations': 'oidv6-train-annotations-bbox.csv',
'file_list': 'train-images-boxable.csv',
'files': 'train'
},
'validation': {
'top_level': '',
'class_descriptions': 'class-descriptions-boxable.csv',
'annotations': 'validation-annotations-bbox.csv',
'file_list': 'validation-images.csv',
'files': 'validation'
},
'test': {
'top_level': '',
'class_descriptions': 'class-descriptions-boxable.csv',
'annotations': 'test-annotations-bbox.csv',
'file_list': 'test-images.csv',
'files': 'test'
}
}
def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
annotations: Dict[str, List[Annotation]] = defaultdict(list)
with open(descriptor_path) as file:
reader = DictReader(file)
for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
width = float(row['XMax']) - float(row['XMin'])
height = float(row['YMax']) - float(row['YMin'])
area = width * height
category_id = row['LabelName']
if category_id in category_mapping:
category_id = category_mapping[category_id]
if area >= min_object_area and category_id in category_no_for_id:
annotations[row['ImageID']].append(
Annotation(
id=i,
image_id=row['ImageID'],
source=row['Source'],
category_id=category_id,
category_no=category_no_for_id[category_id],
confidence=float(row['Confidence']),
bbox=(float(row['XMin']), float(row['YMin']), width, height),
area=area,
is_occluded=bool(int(row['IsOccluded'])),
is_truncated=bool(int(row['IsTruncated'])),
is_group_of=bool(int(row['IsGroupOf'])),
is_depiction=bool(int(row['IsDepiction'])),
is_inside=bool(int(row['IsInside']))
)
)
if 'train' in str(descriptor_path) and i < 14000000:
warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
return dict(annotations)
def load_image_ids(csv_path: Path) -> List[str]:
with open(csv_path) as file:
reader = DictReader(file)
return [row['image_name'] for row in reader]
def load_categories(csv_path: Path) -> Dict[str, Category]:
with open(csv_path) as file:
reader = TupleReader(file)
return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
def __init__(self, use_additional_parameters: bool, **kwargs):
"""
@param data_path: is the path to the following folder structure:
open_images/
│ oidv6-train-annotations-bbox.csv
├── class-descriptions-boxable.csv
├── oidv6-train-annotations-bbox.csv
├── test
│ ├── 000026e7ee790996.jpg
│ ├── 000062a39995e348.jpg
│ └── ...
├── test-annotations-bbox.csv
├── test-images.csv
├── train
│ ├── 000002b66c9c498e.jpg
│ ├── 000002b97e5471a0.jpg
│ └── ...
├── train-images-boxable.csv
├── validation
│ ├── 0001eeaf4aed83f9.jpg
│ ├── 0004886b7d043cfd.jpg
│ └── ...
├── validation-annotations-bbox.csv
└── validation-images.csv
@param: split: one of 'train', 'validation' or 'test'
@param: desired image size (returns square images)
"""
super().__init__(**kwargs)
self.use_additional_parameters = use_additional_parameters
self.categories = load_categories(self.paths['class_descriptions'])
self.filter_categories()
self.setup_category_id_and_number()
self.image_descriptions = {}
annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
self.category_number)
self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
self.max_objects_per_image)
self.image_ids = list(self.annotations.keys())
self.clean_up_annotations_and_image_descriptions()
def get_path_structure(self) -> Dict[str, str]:
if self.split not in OPEN_IMAGES_STRUCTURE:
            raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
return OPEN_IMAGES_STRUCTURE[self.split]
def get_image_path(self, image_id: str) -> Path:
return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
def get_image_description(self, image_id: str) -> Dict[str, Any]:
image_path = self.get_image_path(image_id)
return {'file_path': str(image_path), 'file_name': image_path.name}
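# ---- File boundary (likely taming/data/base.py): generic image-path datasets and concat helpers ----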
import bisect
import numpy as np
import albumentations
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset
class ConcatDatasetWithIndex(ConcatDataset):
"""Modified from original pytorch code to return dataset idx"""
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx], dataset_idx
class ImagePaths(Dataset):
def __init__(self, paths, size=None, random_crop=False, labels=None):
self.size = size
self.random_crop = random_crop
self.labels = dict() if labels is None else labels
self.labels["file_path_"] = paths
self._length = len(paths)
if self.size is not None and self.size > 0:
self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
if not self.random_crop:
self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
else:
self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
else:
self.preprocessor = lambda **kwargs: kwargs
def __len__(self):
return self._length
def preprocess_image(self, image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
image = self.preprocessor(image=image)["image"]
image = (image/127.5 - 1.0).astype(np.float32)
return image
def __getitem__(self, i):
example = dict()
example["image"] = self.preprocess_image(self.labels["file_path_"][i])
for k in self.labels:
example[k] = self.labels[k][i]
return example
class NumpyPaths(ImagePaths):
def preprocess_image(self, image_path):
image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
image = np.transpose(image, (1,2,0))
image = Image.fromarray(image, mode="RGB")
image = np.array(image).astype(np.uint8)
image = self.preprocessor(image=image)["image"]
image = (image/127.5 - 1.0).astype(np.float32)
return image
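# ---- File boundary (likely taming/data/coco.py): COCO images with captions and segmentation maps ----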
import os
import json
import albumentations
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from taming.data.sflckr import SegmentationBase # for examples included in repo
class Examples(SegmentationBase):
def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
super().__init__(data_csv="data/coco_examples.txt",
data_root="data/coco_images",
segmentation_root="data/coco_segmentations",
size=size, random_crop=random_crop,
interpolation=interpolation,
n_labels=183, shift_segmentation=True)
class CocoBase(Dataset):
"""needed for (image, caption, segmentation) pairs"""
def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
crop_size=None, force_no_crop=False, given_files=None):
self.split = self.get_split()
self.size = size
if crop_size is None:
self.crop_size = size
else:
self.crop_size = crop_size
self.onehot = onehot_segmentation # return segmentation as rgb or one hot
self.stuffthing = use_stuffthing # include thing in segmentation
if self.onehot and not self.stuffthing:
            raise NotImplementedError("One hot mode is only supported for the "
                                      "stuffthings version because labels are stored "
                                      "a bit differently.")
data_json = datajson
with open(data_json) as json_file:
self.json_data = json.load(json_file)
self.img_id_to_captions = dict()
self.img_id_to_filepath = dict()
self.img_id_to_segmentation_filepath = dict()
assert data_json.split("/")[-1] in ["captions_train2017.json",
"captions_val2017.json"]
if self.stuffthing:
self.segmentation_prefix = (
"data/cocostuffthings/val2017" if
data_json.endswith("captions_val2017.json") else
"data/cocostuffthings/train2017")
else:
self.segmentation_prefix = (
"data/coco/annotations/stuff_val2017_pixelmaps" if
data_json.endswith("captions_val2017.json") else
"data/coco/annotations/stuff_train2017_pixelmaps")
imagedirs = self.json_data["images"]
self.labels = {"image_ids": list()}
for imgdir in tqdm(imagedirs, desc="ImgToPath"):
self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
self.img_id_to_captions[imgdir["id"]] = list()
pngfilename = imgdir["file_name"].replace("jpg", "png")
self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
self.segmentation_prefix, pngfilename)
if given_files is not None:
if pngfilename in given_files:
self.labels["image_ids"].append(imgdir["id"])
else:
self.labels["image_ids"].append(imgdir["id"])
capdirs = self.json_data["annotations"]
for capdir in tqdm(capdirs, desc="ImgToCaptions"):
            # there are on average 5 captions per image
self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
if self.split=="validation":
self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
else:
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
self.preprocessor = albumentations.Compose(
[self.rescaler, self.cropper],
additional_targets={"segmentation": "image"})
if force_no_crop:
self.rescaler = albumentations.Resize(height=self.size, width=self.size)
self.preprocessor = albumentations.Compose(
[self.rescaler],
additional_targets={"segmentation": "image"})
def __len__(self):
return len(self.labels["image_ids"])
def preprocess_image(self, image_path, segmentation_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
segmentation = Image.open(segmentation_path)
if not self.onehot and not segmentation.mode == "RGB":
segmentation = segmentation.convert("RGB")
segmentation = np.array(segmentation).astype(np.uint8)
if self.onehot:
assert self.stuffthing
# stored in caffe format: unlabeled==255. stuff and thing from
# 0-181. to be compatible with the labels in
# https://github.com/nightrome/cocostuff/blob/master/labels.txt
# we shift stuffthing one to the right and put unlabeled in zero
# as long as segmentation is uint8 shifting to right handles the
# latter too
assert segmentation.dtype == np.uint8
segmentation = segmentation + 1
processed = self.preprocessor(image=image, segmentation=segmentation)
image, segmentation = processed["image"], processed["segmentation"]
image = (image / 127.5 - 1.0).astype(np.float32)
if self.onehot:
assert segmentation.dtype == np.uint8
# make it one hot
n_labels = 183
flatseg = np.ravel(segmentation)
            onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
onehot[np.arange(flatseg.size), flatseg] = True
onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
segmentation = onehot
else:
segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
return image, segmentation
def __getitem__(self, i):
img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
image, segmentation = self.preprocess_image(img_path, seg_path)
captions = self.img_id_to_captions[self.labels["image_ids"][i]]
# randomly draw one of all available captions per image
caption = captions[np.random.randint(0, len(captions))]
example = {"image": image,
"caption": [str(caption[0])],
"segmentation": segmentation,
"img_path": img_path,
"seg_path": seg_path,
"filename_": img_path.split(os.sep)[-1]
}
return example
class CocoImagesAndCaptionsTrain(CocoBase):
"""returns a pair of (image, caption)"""
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
super().__init__(size=size,
dataroot="data/coco/train2017",
datajson="data/coco/annotations/captions_train2017.json",
onehot_segmentation=onehot_segmentation,
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
def get_split(self):
return "train"
class CocoImagesAndCaptionsValidation(CocoBase):
"""returns a pair of (image, caption)"""
def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
given_files=None):
super().__init__(size=size,
dataroot="data/coco/val2017",
datajson="data/coco/annotations/captions_val2017.json",
onehot_segmentation=onehot_segmentation,
use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
given_files=given_files)
def get_split(self):
return "validation"