"""Streamlit app for interactive sampling from a conditional transformer model.

Loads a trained model + dataset from a logdir (or explicit config/checkpoint),
shows the input image, its reconstruction and the conditioning, and then
autoregressively samples the latent code grid with a sliding 16x16 window.
"""
import argparse
import glob
import math
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
from omegaconf import OmegaConf

# Make the repository root importable when this script lives in a subfolder.
parent_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(parent_dir))

import streamlit as st
# from streamlit import caching
from PIL import Image

from main import instantiate_from_config, DataModuleFromConfig
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

# Map model-space values in [-1, 1] to display-space values in [0, 1].
rescale = lambda x: (x + 1.) / 2.


def bchw_to_st(x):
    """Convert a BCHW tensor in [-1, 1] to a BHWC numpy array in [0, 1] for st.image."""
    return rescale(x.detach().cpu().numpy().transpose(0, 2, 3, 1))


def save_img(xstart, fname):
    """Save the first image of a BHWC float batch (values in [0, 1]) to `fname` as 8-bit."""
    I = (xstart.clip(0, 1)[0] * 255).astype(np.uint8)
    Image.fromarray(I).save(fname)


def get_interactive_image(resize=False):
    """Streamlit upload widget; returns the uploaded image as an RGB uint8 array, or None."""
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is not None:
        image = Image.open(image)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        print("upload image shape: {}".format(image.shape))
        img = Image.fromarray(image)
        if resize:
            img = img.resize((256, 256))
        image = np.array(img)
        return image


def single_image_to_torch(x, permute=True):
    """Convert an HWC uint8 image to a batched float tensor in [-1, 1].

    Returns shape (1, C, H, W) when `permute` is True, else (1, H, W, C).
    """
    assert x is not None, "Please provide an image through the upload function"
    x = np.array(x)
    x = torch.FloatTensor(x / 255. * 2. - 1.)[None, ...]
    if permute:
        x = x.permute(0, 3, 1, 2)
    return x


def pad_to_M(x, M):
    """Zero-pad the spatial dims (bottom/right) of a BCHW tensor up to multiples of M."""
    hp = math.ceil(x.shape[2] / M) * M - x.shape[2]
    wp = math.ceil(x.shape[3] / M) * M - x.shape[3]
    # F.pad pads from the last dim backwards: (left, right, top, bottom).
    # The original also passed four trailing zeros (no-op padding on B and C).
    return torch.nn.functional.pad(x, (0, wp, 0, hp))


@torch.no_grad()
def run_conditional(model, dsets):
    """Render the interactive sampling UI and run sliding-window sampling.

    Args:
        model: conditional transformer with first-stage/cond-stage models
            (provides encode_to_z/encode_to_c/decode_to_img/top_k_logits).
        dsets: data module exposing a `datasets` dict of named splits.
    """
    # --- dataset / example selection -------------------------------------
    if len(dsets.datasets) > 1:
        split = st.sidebar.radio("Split", sorted(dsets.datasets.keys()))
        dset = dsets.datasets[split]
    else:
        dset = next(iter(dsets.datasets.values()))
    batch_size = 1
    start_index = st.sidebar.number_input(
        "Example Index (Size: {})".format(len(dset)),
        value=0, min_value=0, max_value=len(dset) - batch_size)
    indices = list(range(start_index, start_index + batch_size))

    example = default_collate([dset[i] for i in indices])

    x = model.get_input("image", example).to(model.device)

    cond_key = model.cond_stage_key
    c = model.get_input(cond_key, example).to(model.device)

    # Optional rescaling of both image and conditioning before encoding.
    scale_factor = st.sidebar.slider("Scale Factor", min_value=0.5,
                                     max_value=4.0, step=0.25, value=1.00)
    if scale_factor != 1.0:
        x = torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="bicubic")
        c = torch.nn.functional.interpolate(c, scale_factor=scale_factor, mode="bicubic")

    quant_z, z_indices = model.encode_to_z(x)
    quant_c, c_indices = model.encode_to_c(c)

    cshape = quant_z.shape

    # --- show input, reconstruction and conditioning ---------------------
    xrec = model.first_stage_model.decode(quant_z)
    st.write("image: {}".format(x.shape))
    st.image(bchw_to_st(x), clamp=True, output_format="PNG")
    st.write("image reconstruction: {}".format(xrec.shape))
    st.image(bchw_to_st(xrec), clamp=True, output_format="PNG")

    if cond_key == "segmentation":
        # Collapse soft class scores to a hard one-hot mask, then colorize.
        num_classes = c.shape[1]
        c = torch.argmax(c, dim=1, keepdim=True)
        c = torch.nn.functional.one_hot(c, num_classes=num_classes)
        c = c.squeeze(1).permute(0, 3, 1, 2).float()
        c = model.cond_stage_model.to_rgb(c)

    st.write(f"{cond_key}: {tuple(c.shape)}")
    st.image(bchw_to_st(c), clamp=True, output_format="PNG")

    # --- initialize the index grid to be sampled -------------------------
    idx = z_indices

    half_sample = st.sidebar.checkbox("Image Completion", value=False)
    if half_sample:
        # Keep the first half of the code sequence, sample the rest.
        start = idx.shape[1] // 2
    else:
        start = 0

    idx[:, start:] = 0
    idx = idx.reshape(cshape[0], cshape[2], cshape[3])
    start_i = start // cshape[3]
    start_j = start % cshape[3]

    if not half_sample and quant_z.shape == quant_c.shape:
        st.info("Setting idx to c_indices")
        idx = c_indices.clone().reshape(cshape[0], cshape[2], cshape[3])

    cidx = c_indices
    cidx = cidx.reshape(quant_c.shape[0], quant_c.shape[2], quant_c.shape[3])

    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
    st.image(bchw_to_st(xstart), clamp=True, output_format="PNG")

    # --- sampling controls -----------------------------------------------
    temperature = st.number_input("Temperature", value=1.0)
    top_k = st.number_input("Top k", value=100)
    sample = st.checkbox("Sample", value=True)
    update_every = st.number_input("Update every", value=75)

    st.text(f"Sampling shape ({cshape[2]},{cshape[3]})")

    animate = st.checkbox("animate")
    if animate:
        import imageio
        outvid = "sampling.mp4"
        writer = imageio.get_writer(outvid, fps=25)
    elapsed_t = st.empty()
    info = st.empty()
    st.text("Sampled")
    if st.button("Sample"):
        output = st.empty()
        start_t = time.time()
        for i in range(start_i, cshape[2]):
            # local_i/local_j: position of (i, j) inside the 16x16 context
            # window; the window is clamped at the grid borders.
            if i <= 8:
                local_i = i
            elif cshape[2] - i < 8:
                local_i = 16 - (cshape[2] - i)
            else:
                local_i = 8
            for j in range(start_j, cshape[3]):
                if j <= 8:
                    local_j = j
                elif cshape[3] - j < 8:
                    local_j = 16 - (cshape[3] - j)
                else:
                    local_j = 8

                i_start = i - local_i
                i_end = i_start + 16
                j_start = j - local_j
                j_end = j_start + 16
                elapsed_t.text(f"Time: {time.time() - start_t} seconds")
                info.text(f"Step: ({i},{j}) | Local: ({local_i},{local_j}) | Crop: ({i_start}:{i_end},{j_start}:{j_end})")

                # Flatten the conditioning + code crop into one token sequence.
                patch = idx[:, i_start:i_end, j_start:j_end]
                patch = patch.reshape(patch.shape[0], -1)
                cpatch = cidx[:, i_start:i_end, j_start:j_end]
                cpatch = cpatch.reshape(cpatch.shape[0], -1)
                patch = torch.cat((cpatch, patch), dim=1)
                logits, _ = model.transformer(patch[:, :-1])
                # Only the last 16*16 = 256 positions predict code tokens.
                logits = logits[:, -256:, :]
                logits = logits.reshape(cshape[0], 16, 16, -1)
                logits = logits[:, local_i, local_j, :]

                logits = logits / temperature

                if top_k is not None:
                    logits = model.top_k_logits(logits, top_k)
                # apply softmax to convert to probabilities
                probs = torch.nn.functional.softmax(logits, dim=-1)
                # sample from the distribution or take the most likely
                if sample:
                    ix = torch.multinomial(probs, num_samples=1)
                else:
                    _, ix = torch.topk(probs, k=1, dim=-1)
                idx[:, i, j] = ix

                # Periodically decode and refresh the preview image.
                if (i * cshape[3] + j) % update_every == 0:
                    xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape,)
                    xstart = bchw_to_st(xstart)
                    output.image(xstart, clamp=True, output_format="PNG")
                    if animate:
                        writer.append_data((xstart[0] * 255).clip(0, 255).astype(np.uint8))
        xstart = model.decode_to_img(idx[:, :cshape[2], :cshape[3]], cshape)
        xstart = bchw_to_st(xstart)
        output.image(xstart, clamp=True, output_format="PNG")
        #save_img(xstart, "full_res_sample.png")
        if animate:
            writer.close()
            st.video(outvid)


def get_parser():
    """Build the CLI parser (resume dir/checkpoint, config files, data override)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list(),
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to single config. If specified, base configs will be ignored "
             "(except for the last one if left unspecified).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--ignore_base_data",
        action="store_true",
        help="Ignore data specification from base configs. Useful if you want "
             "to specify a custom datasets on the command line.",
    )
    return parser


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate the model from `config`, optionally load a state dict.

    Restore-checkpoint paths inside the config are cleared so the model is
    restored exclusively from `sd` (the checkpoint chosen on the CLI).

    Returns a dict {"model": model} for caching friendliness.
    """
    if "ckpt_path" in config.params:
        st.warning("Deleting the restore-ckpt path from the config...")
        config.params.ckpt_path = None
    if "downsample_cond_size" in config.params:
        st.warning("Deleting downsample-cond-size from the config and setting factor=0.5 instead...")
        config.params.downsample_cond_size = -1
        config.params["downsample_cond_factor"] = 0.5
    try:
        if "ckpt_path" in config.params.first_stage_config.params:
            config.params.first_stage_config.params.ckpt_path = None
            st.warning("Deleting the first-stage restore-ckpt path from the config...")
        if "ckpt_path" in config.params.cond_stage_config.params:
            config.params.cond_stage_config.params.ckpt_path = None
            st.warning("Deleting the cond-stage restore-ckpt path from the config...")
    except Exception:
        # Best effort: sub-configs may be missing entirely for some models.
        pass

    model = instantiate_from_config(config)
    if sd is not None:
        missing, unexpected = model.load_state_dict(sd, strict=False)
        st.info(f"Missing Keys in State Dict: {missing}")
        st.info(f"Unexpected Keys in State Dict: {unexpected}")
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def get_data(config):
    """Instantiate and set up the data module declared in `config.data`."""
    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    return data


@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model_and_dset(config, ckpt, gpu, eval_mode):
    """Load datasets and model (cached across streamlit reruns).

    Returns (dsets, model, global_step); global_step is None without a checkpoint.
    """
    # get data
    dsets = get_data(config)  # calls data.config ...

    # now load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"],
                                   gpu=gpu,
                                   eval_mode=eval_mode)["model"]
    return dsets, model, global_step


if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()

    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # Resuming from a checkpoint file: infer the logdir from its path.
            paths = opt.resume.split("/")
            try:
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir:{logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        # `-c` without a value is stored as `const=True` (not a string):
        # then only the last base config is kept.
        if isinstance(opt.config, str):
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)

    st.sidebar.text(ckpt)
    gs = st.sidebar.empty()
    gs.text("Global step: ?")
    st.sidebar.text("Options")
    #gpu = st.sidebar.checkbox("GPU", value=True)
    gpu = True
    #eval_mode = st.sidebar.checkbox("Eval Mode", value=True)
    eval_mode = True
    #show_config = st.sidebar.checkbox("Show Config", value=False)
    show_config = False
    if show_config:
        st.info("Checkpoint: {}".format(ckpt))
        st.json(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    gs.text(f"Global step: {global_step}")

    run_conditional(model, dsets)