# --------------------------------------------------------
# Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499)
# Github source: https://github.com/baaivision/Painter
# Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI)
# Licensed under The MIT License [see LICENSE for details]
# By Xinlong Wang, Wen Wang
# Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases
# --------------------------------------------------------

import sys
import os
import argparse

import torch
import torch.nn.functional as F
import numpy as np
import tqdm
from PIL import Image
from torch.utils.data import DataLoader, DistributedSampler

sys.path.append('.')
import models_painter
from util.ddp_utils import DatasetTest
from util import ddp_utils


imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])


def get_args_parser():
    parser = argparse.ArgumentParser('COCO Pose Estimation', add_help=False)
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt', default='')
    parser.add_argument('--model', type=str, help='name of the model architecture',
                        default='painter_vit_large_patch16_input896x448_win_dec64_8glb_sl1')
    parser.add_argument('--prompt', type=str, help='prompt image in train set',
                        default='000000000165_box0')
    parser.add_argument('--input_size', type=int, default=448)
    parser.add_argument('--flip_test', action='store_true',
                        help='run inference on horizontally flipped test images')
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    return parser.parse_args()


def prepare_model(chkpt_dir, arch, args=None):
    # build the model and wrap it for distributed inference
    model = getattr(models_painter, arch)()
    model.to("cuda")
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module

    # load checkpoint weights into the unwrapped model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model


def run_one_image(img, tgt, size, model, out_path, device):
    x = torch.tensor(img)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    tgt = torch.tensor(tgt)
    tgt = tgt.unsqueeze(dim=0)
    tgt = torch.einsum('nhwc->nchw', tgt)

    # mask the bottom half of the target canvas; the model in-paints the pose
    # prediction there, conditioned on the prompt pair in the top half
    bool_masked_pos = torch.zeros(model.module.patch_embed.num_patches)
    bool_masked_pos[model.module.patch_embed.num_patches // 2:] = 1
    bool_masked_pos = bool_masked_pos.unsqueeze(dim=0)
    valid = torch.ones_like(tgt)

    loss, y, mask = model(x.float().to(device), tgt.float().to(device),
                          bool_masked_pos.to(device), valid.float().to(device))
    y = model.module.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # keep only the predicted (bottom) half, undo the ImageNet normalization,
    # and resize back to the original image resolution
    output = y[0, y.shape[1] // 2:, :, :]
    output = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255)
    output = F.interpolate(
        output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]],
        mode='nearest').permute(0, 2, 3, 1)[0]
    output = output.int()
    output = Image.fromarray(output.numpy().astype(np.uint8))
    output.save(out_path)


if __name__ == '__main__':
    dataset_dir = "datasets/"

    args = get_args_parser()
    args = ddp_utils.init_distributed_mode(args)
    device = torch.device("cuda")

    ckpt_path = args.ckpt_path
    model = args.model
    prompt = args.prompt
    input_size = args.input_size

    path_splits = ckpt_path.split('/')
    ckpt_dir, ckpt_file = path_splits[-2], path_splits[-1]
    dst_dir = os.path.join(
        'models_inference', ckpt_dir,
        "coco_pose_inference_{}_{}".format(ckpt_file, os.path.basename(prompt).split(".")[0]))
    if args.flip_test:
        dst_dir = dst_dir + "_flip"

    if ddp_utils.get_rank() == 0:
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)
        print("output_dir: {}".format(dst_dir))

    model_painter = prepare_model(ckpt_path, model, args)
    print('Model loaded.')

    img_src_dir = dataset_dir + "coco_pose/data_pair/test_256x192"
    if args.flip_test:
        img_src_dir += "_flip"
    dataset_val = DatasetTest(img_src_dir, input_size, ext_list=('*.png',))
    sampler_val = DistributedSampler(dataset_val, shuffle=False)
    data_loader_val = DataLoader(dataset_val, batch_size=1, sampler=sampler_val,
                                 drop_last=False, collate_fn=ddp_utils.collate_fn,
                                 num_workers=2)

    img2_path = dataset_dir + "coco_pose/data_pair/train_256x192_aug0/{}_image.png".format(prompt)
    tgt2_path = dataset_dir + "coco_pose/data_pair/train_256x192_aug0/{}_label.png".format(prompt)

    # load the shared prompt image pair
    img2 = Image.open(img2_path).convert("RGB")
    img2 = img2.resize((input_size, input_size))
    img2 = np.array(img2) / 255.

    tgt2 = Image.open(tgt2_path).convert("RGB")
    tgt2 = tgt2.resize((input_size, input_size))
    tgt2 = np.array(tgt2) / 255.

    model_painter.eval()
    for data in tqdm.tqdm(data_loader_val):
        # load one query image along with its original (width, height)
        assert len(data) == 1
        img, img_path, size = data[0]
        img_name = os.path.basename(img_path)
        out_path = os.path.join(dst_dir, img_name.replace('.jpg', '.png'))

        # stack the prompt image on top of the query image
        img = np.concatenate((img2, img), axis=0)
        assert img.shape == (input_size * 2, input_size, 3)
        # normalize by ImageNet mean and std
        img = img - imagenet_mean
        img = img / imagenet_std

        # the query's target is unavailable at inference time; reuse the
        # prompt target as a placeholder for the masked bottom half
        tgt = tgt2
        tgt = np.concatenate((tgt2, tgt), axis=0)
        assert tgt.shape == (input_size * 2, input_size, 3)
        # normalize by ImageNet mean and std
        tgt = tgt - imagenet_mean
        tgt = tgt / imagenet_std

        # make random mask reproducible (comment out to make it change)
        torch.manual_seed(2)
        run_one_image(img, tgt, size, model_painter, out_path, device)
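

# Example launch (a sketch, not part of the original script): the script name and
# checkpoint path below are hypothetical placeholders. ddp_utils.init_distributed_mode
# is assumed to read the RANK / WORLD_SIZE / LOCAL_RANK environment variables that
# torchrun sets, as in the MAE-style DDP helpers this codebase builds on.
#
#   torchrun --nproc_per_node=1 painter_inference_pose.py \
#       --ckpt_path models/painter_vit_large/checkpoint.pth \
#       --prompt 000000000165_box0 \
#       --input_size 448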