Commit 1ad55bb4 authored by mashun1

i2vgen-xl

import os
import sys
import json
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from thop import profile
from ptflops import get_model_complexity_info
import artist.data as data
from tools.modules.config import cfg
from tools.modules.unet.util import *
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, MODEL
def save_temporal_key():
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
model = MODEL.build(cfg.UNet)
temp_name = ''
temp_key_list = []
spth = 'workspace/module_list/UNetSD_I2V_vs_Text_temporal_key_list.json'
for name, module in model.named_modules():
if isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
temp_name = name
print(f'Model: {name}')
elif isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
temp_name = ''
if hasattr(module, 'weight'):
if temp_name != '' and (temp_name in name):
temp_key_list.append(name)
print(f'{name}')
# print(name)
save_module_list = []
for k, p in model.named_parameters():
for item in temp_key_list:
if item in k:
print(f'{item} --> {k}')
save_module_list.append(k)
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
# spth = 'workspace/module_list/{}'
json.dump(save_module_list, open(spth, 'w'))
a = 0
def save_spatial_key():
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
model = MODEL.build(cfg.UNet)
temp_name = ''
temp_key_list = []
spth = 'workspace/module_list/UNetSD_I2V_HQ_P_spatial_key_list.json'
for name, module in model.named_modules():
if isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
temp_name = name
print(f'Model: {name}')
elif isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
temp_name = ''
if hasattr(module, 'weight'):
if temp_name != '' and (temp_name in name):
temp_key_list.append(name)
print(f'{name}')
# print(name)
save_module_list = []
for k, p in model.named_parameters():
for item in temp_key_list:
if item in k:
print(f'{item} --> {k}')
save_module_list.append(k)
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
# spth = 'workspace/module_list/{}'
json.dump(save_module_list, open(spth, 'w'))
a = 0
if __name__ == '__main__':
# save_temporal_key()
save_spatial_key()
# print([k for (k, _) in self.input_blocks.named_parameters()])
import os
import sys
import torch
import cv2
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from PIL import Image, ImageDraw, ImageFont
from einops import rearrange
from tools import *
import utils.transforms as data
from utils.seed import setup_seed
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER
def test_enc_dec(gpu=0):
setup_seed(0)
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type'])
os.system('rm -rf %s' % (save_dir))
os.makedirs(save_dir, exist_ok=True)
train_trans = data.Compose([
data.CenterCropWide(size=cfg.resolution),
data.ToTensor(),
data.Normalize(mean=cfg.mean, std=cfg.std)])
vit_trans = data.Compose([
data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
data.Resize(cfg.vit_resolution),
data.ToTensor(),
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
txt_size = cfg.resolution[1]
nc = int(38 * (txt_size / 256))
font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans)
print('There are %d videos' % (len(dataset)))
autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
autoencoder.eval() # freeze
for param in autoencoder.parameters():
param.requires_grad = False
autoencoder.to(gpu)
for idx, item in enumerate(dataset):
local_path = os.path.join(save_dir, '%04d.mp4' % idx)
# ref_frame, video_data, caption = item
ref_frame, vit_frame, video_data = item[:3]
video_data = video_data.to(gpu)
image_list = []
        video_data_list = torch.chunk(video_data, video_data.shape[0] // cfg.chunk_size, dim=0)
with torch.no_grad():
decode_data = []
for chunk_data in video_data_list:
latent_z = autoencoder.encode_firsr_stage(chunk_data).detach()
# latent_z = get_first_stage_encoding(encoder_posterior).detach()
kwargs = {"timesteps": chunk_data.shape[0]}
recons_data = autoencoder.decode(latent_z, **kwargs)
vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu()
vis_data = vis_data.mul_(video_std).add_(video_mean) # 8x3x16x256x384
vis_data = vis_data.cpu()
vis_data.clamp_(0, 1)
vis_data = vis_data.permute(0, 2, 3, 1)
vis_data = [(image.numpy() * 255).astype('uint8') for image in vis_data]
image_list.extend(vis_data)
num_image = len(image_list)
frame_dir = os.path.join(save_dir, 'temp')
os.makedirs(frame_dir, exist_ok=True)
        for frame_idx in range(num_image):
            tpth = os.path.join(frame_dir, '%04d.png' % (frame_idx + 1))
            cv2.imwrite(tpth, image_list[frame_idx][:, :, ::-1])  # RGB -> BGR for OpenCV
        cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
        os.system(cmd)
        os.system(f'rm -rf {frame_dir}')
if __name__ == '__main__':
test_enc_dec()
import os
import sys
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as T
import utils.transforms as data
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS
from tools import *
def test_video_dataset():
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name)
os.system('rm -rf %s' % (save_dir))
os.makedirs(save_dir, exist_ok=True)
train_trans = data.Compose([
data.CenterCropWide(size=cfg.resolution),
data.ToTensor(),
data.Normalize(mean=cfg.mean, std=cfg.std)])
vit_trans = T.Compose([
data.CenterCropWide(cfg.vit_resolution),
T.ToTensor(),
T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
txt_size = cfg.resolution[1]
nc = int(38 * (txt_size / 256))
font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0], transforms=train_trans, vit_transforms=vit_trans)
print('There are %d videos' % (len(dataset)))
for idx, item in enumerate(dataset):
ref_frame, vit_frame, video_data, caption, video_key = item
video_data = video_data.mul_(video_std).add_(video_mean)
video_data.clamp_(0, 1)
video_data = video_data.permute(0, 2, 3, 1)
video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
# Single Image
        ref_frame = ref_frame.mul_(img_std).add_(img_mean)
ref_frame.clamp_(0, 1)
ref_frame = ref_frame.permute(1, 2, 0)
ref_frame = (ref_frame.numpy() * 255).astype('uint8')
# Text image
txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
draw = ImageDraw.Draw(txt_img)
lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
draw.text((0, 0), lines, fill="black", font=font)
txt_img = np.array(txt_img)
video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data]
spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)
# if idx > 100: break
def test_vit_image(test_video_flag=True):
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name)
os.system('rm -rf %s' % (save_dir))
os.makedirs(save_dir, exist_ok=True)
train_trans = data.Compose([
data.CenterCropWide(size=cfg.resolution),
data.ToTensor(),
data.Normalize(mean=cfg.mean, std=cfg.std)])
vit_trans = data.Compose([
data.CenterCropWide(cfg.resolution),
data.Resize(cfg.vit_resolution),
data.ToTensor(),
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
txt_size = cfg.resolution[1]
nc = int(38 * (txt_size / 256))
font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13)
dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans)
print('There are %d videos' % (len(dataset)))
for idx, item in enumerate(dataset):
ref_frame, vit_frame, video_data, caption, video_key = item
video_data = video_data.mul_(img_std).add_(img_mean)
video_data.clamp_(0, 1)
video_data = video_data.permute(0, 2, 3, 1)
video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
# Single Image
vit_frame = vit_frame.mul_(vit_std).add_(vit_mean)
vit_frame.clamp_(0, 1)
vit_frame = vit_frame.permute(1, 2, 0)
vit_frame = (vit_frame.numpy() * 255).astype('uint8')
zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8)
zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame
# Text image
txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
draw = ImageDraw.Draw(txt_img)
lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
draw.text((0, 0), lines, fill="black", font=font)
txt_img = np.array(txt_img)
video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data]
spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)
# if idx > 100: break
if __name__ == '__main__':
# test_video_dataset()
test_vit_image()
import os
import sys
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from thop import profile
from ptflops import get_model_complexity_info
import artist.data as data
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, MODEL
def test_model():
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
model = MODEL.build(cfg.UNet)
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
# state_dict = torch.load('cache/pretrain_model/jiuniu_0600000.pth', map_location='cpu')
# model.load_state_dict(state_dict, strict=False)
model = model.cuda()
x = torch.Tensor(1, 4, 16, 32, 56).cuda()
t = torch.Tensor(1).cuda()
sims = torch.Tensor(1, 32).cuda()
fps = torch.Tensor([8]).cuda()
y = torch.Tensor(1, 1, 1024).cuda()
image = torch.Tensor(1, 3, 256, 448).cuda()
ret = model(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps)
    print('Out shape is {}'.format(ret.shape))
# flops, params = profile(model=model, inputs=(x, t, y, image, sims, fps))
# print('Model: {:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))
    def prepare_input(resolution):
        # ptflops passes the dict returned here as keyword arguments to model(...),
        # so mirror the forward call above
        return dict(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps)
    flops, params = get_model_complexity_info(model, (1, 4, 16, 32, 56),
                                              input_constructor=prepare_input,
                                              as_strings=True, print_per_layer_stat=True)
print(' - Flops: ' + flops)
print(' - Params: ' + params)
if __name__ == '__main__':
test_model()
import numpy as np
import cv2
cap = cv2.VideoCapture('workspace/img_dir/tst.mp4')
fourcc = cv2.VideoWriter_fourcc(*'H264')
ret, frame = cap.read()
vid_size = frame.shape[:2][::-1]
out = cv2.VideoWriter('workspace/img_dir/testwrite.mp4',fourcc, 8, vid_size)
out.write(frame)
while cap.isOpened():
ret, frame = cap.read()
if not ret: break
out.write(frame)
cap.release()
out.release()
from .annotator import *
from .datasets import *
from .modules import *
from .train import *
from .hooks import *
from .inferences import *
# from .prior import *
from .basic_funcs import *
import cv2
import torch
import numpy as np
from tools.annotator.util import HWC3
# import gradio as gr
class CannyDetector:
    def __call__(self, img, low_threshold=None, high_threshold=None, random_threshold=True):
### GPT-4 suggestions
# In the cv2.Canny() function, the low threshold and high threshold are used to determine the edges based on the gradient values in the image.
# There isn't a one-size-fits-all solution for these threshold values, as the optimal values depend on the specific image and the application.
# However, there are some general guidelines and empirical values you can use as a starting point:
# 1. Ratio: A common recommendation is to use a ratio of 1:2 or 1:3 between the low threshold and the high threshold.
# This means if your low threshold is 50, the high threshold should be around 100 or 150.
# 2. Empirical values: As a starting point, you can use low threshold values in the range of 50-100 and high threshold values in the range of 100-200.
# You may need to fine-tune these values based on the specific image and desired edge detection results.
# 3. Automatic threshold calculation: To automatically calculate the threshold values, you can use the median or mean value of the image's pixel intensities as the low threshold,
# and the high threshold can be set as twice or three times the low threshold.
### Convert to numpy
if isinstance(img, torch.Tensor): # (h, w, c)
img = img.cpu().numpy()
img_np = cv2.convertScaleAbs((img * 255.))
elif isinstance(img, np.ndarray): # (h, w, c)
img_np = img # we assume values are in the range from 0 to 255.
else:
assert False
### Select the threshold
if (low_threshold is None) and (high_threshold is None):
median_intensity = np.median(img_np)
if random_threshold is False:
low_threshold = int(max(0, (1 - 0.33) * median_intensity))
high_threshold = int(min(255, (1 + 0.33) * median_intensity))
else:
random_canny = np.random.uniform(0.1, 0.4)
# Might try other values
low_threshold = int(max(0, (1 - random_canny) * median_intensity))
high_threshold = 2 * low_threshold
### Detect canny edge
canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)
### Convert to 3 channels
# canny_edge = HWC3(canny_edge)
canny_condition = torch.from_numpy(canny_edge.copy()).unsqueeze(dim = -1).float().cuda() / 255.0
# canny_condition = torch.stack([canny_condition for _ in range(num_samples)], dim=0)
# canny_condition = einops.rearrange(canny_condition, 'h w c -> b c h w').clone()
# return cv2.Canny(img, low_threshold, high_threshold)
return canny_condition
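# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of driving CannyDetector with the median-based
# automatic thresholds described in the comments above. The input below is a
# synthetic placeholder image; note that __call__ moves its result to CUDA,
# so this sketch assumes a GPU is available.
if __name__ == '__main__':
    detector = CannyDetector()
    img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)   # stand-in for a real (h, w, c) frame
    edges = detector(img, random_threshold=False)                # (h, w, 1) float tensor in [0, 1] on GPU
    print(edges.shape, float(edges.min()), float(edges.max()))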
from .palette import *
r"""Modified from ``https://github.com/sergeyk/rayleigh''.
"""
import os
import os.path as osp
import numpy as np
from skimage.color import hsv2rgb, rgb2lab, lab2rgb
from skimage.io import imsave
from sklearn.metrics import euclidean_distances
__all__ = ['Palette']
def rgb2hex(rgb):
return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb])
def hex2rgb(hex):
rgb = hex.strip('#')
fn = lambda u: round(int(u, 16) / 255.0, 5)
return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6])
class Palette(object):
r"""Create a color palette (codebook) in the form of a 2D grid of colors.
Further, the rightmost column has num_hues gradations from black to white.
Parameters:
num_hues: number of colors with full lightness and saturation, in the middle.
num_sat: number of rows above middle row that show the same hues with decreasing saturation.
"""
def __init__(self, num_hues=11, num_sat=5, num_light=4):
n = num_sat + 2 * num_light
# hues
if num_hues == 8:
hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (n, 1))
elif num_hues == 9:
hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (n, 1))
elif num_hues == 10:
hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (n, 1))
elif num_hues == 11:
hues = np.tile(np.array([0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, 0.803, 0.916]), (n, 1))
else:
hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1))
# saturations
sats = np.hstack((
np.linspace(0, 1, num_sat + 2)[1:-1],
1,
[1] * num_light,
[0.4] * (num_light - 1)))
sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
# lights
lights = np.hstack((
[1] * num_sat,
1,
np.linspace(1, 0.2, num_light + 2)[1:-1],
np.linspace(1, 0.2, num_light + 2)[1:-2]))
lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
# colors
rgb = hsv2rgb(np.dstack([hues, sats, lights]))
gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
self.thumbnail = np.hstack([rgb, gray])
# flatten
rgb = rgb.T.reshape(3, -1).T
gray = gray.T.reshape(3, -1).T
self.rgb = np.vstack((rgb, gray))
self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
self.hex = [rgb2hex(u) for u in self.rgb]
self.lab_dists = euclidean_distances(self.lab, squared=True)
def histogram(self, rgb_img, sigma=20):
# compute histogram
lab = rgb2lab(rgb_img).reshape((-1, 3))
min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
hist = 1.0 * np.bincount(min_ind, minlength=self.lab.shape[0]) / lab.shape[0]
# smooth histogram
if sigma > 0:
weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2))
weight = weight / weight.sum(1)[:, np.newaxis]
hist = (weight * hist).sum(1)
hist[hist < 1e-5] = 0
return hist
def get_palette_image(self, hist, percentile=90, width=200, height=50):
# curate histogram
ind = np.argsort(-hist)
ind = ind[hist[ind] > np.percentile(hist, percentile)]
hist = hist[ind] / hist[ind].sum()
# draw palette
nums = np.array(hist * width, dtype=int)
array = np.vstack([np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)])
array = np.tile(array[np.newaxis, :, :], (height, 1, 1))
if array.shape[1] < width:
array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1)
return array
def quantize_image(self, rgb_img):
lab = rgb2lab(rgb_img).reshape((-1, 3))
min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
quantized_lab = self.lab[min_ind]
img = lab2rgb(quantized_lab.reshape(rgb_img.shape))
return img
def export(self, dirname):
if not osp.exists(dirname):
os.makedirs(dirname)
# save thumbnail
imsave(osp.join(dirname, 'palette.png'), self.thumbnail)
# save html
with open(osp.join(dirname, 'palette.html'), 'w') as f:
html = '''
<style>
span {
width: 20px;
height: 20px;
margin: 2px;
padding: 0px;
display: inline-block;
}
</style>
'''
for row in self.thumbnail:
for col in row:
html += '<a id="{0}"><span style="background-color: {0}" /></a>\n'.format(rgb2hex(col))
html += '<br />\n'
f.write(html)
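# --- Usage sketch (not part of the original file) ---
# A hedged example of the Palette workflow: build the codebook, compute a
# smoothed color histogram for an image, and quantize the image to the palette.
# np.random.rand(...) stands in for a real RGB image in [0, 1].
if __name__ == '__main__':
    palette = Palette(num_hues=11, num_sat=5, num_light=4)
    rgb_img = np.random.rand(64, 64, 3)           # placeholder image, float RGB in [0, 1]
    hist = palette.histogram(rgb_img, sigma=20)   # per-palette-color histogram, sums to ~1
    strip = palette.get_palette_image(hist, percentile=90)
    quantized = palette.quantize_image(rgb_img)
    print(hist.shape, strip.shape, quantized.shape)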
from .pidinet import *
from .sketch_simplification import *
r"""Modified from ``https://github.com/zhuoinoulu/pidinet''.
Image augmentation: T.Compose([
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# from canvas import DOWNLOAD_TO_CACHE
from artist import DOWNLOAD_TO_CACHE
__all__ = ['PiDiNet', 'pidinet_bsd_tiny', 'pidinet_bsd_small', 'pidinet_bsd',
'pidinet_nyud', 'pidinet_multicue']
CONFIGS = {
'baseline': {
'layer0': 'cv',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'c-v15': {
'layer0': 'cd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'a-v15': {
'layer0': 'ad',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'r-v15': {
'layer0': 'rd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cv',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cv',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cv',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'cvvv4': {
'layer0': 'cd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'cd',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'cd',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'cd',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'avvv4': {
'layer0': 'ad',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'ad',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'ad',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'ad',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'rvvv4': {
'layer0': 'rd',
'layer1': 'cv',
'layer2': 'cv',
'layer3': 'cv',
'layer4': 'rd',
'layer5': 'cv',
'layer6': 'cv',
'layer7': 'cv',
'layer8': 'rd',
'layer9': 'cv',
'layer10': 'cv',
'layer11': 'cv',
'layer12': 'rd',
'layer13': 'cv',
'layer14': 'cv',
'layer15': 'cv',
},
'cccv4': {
'layer0': 'cd',
'layer1': 'cd',
'layer2': 'cd',
'layer3': 'cv',
'layer4': 'cd',
'layer5': 'cd',
'layer6': 'cd',
'layer7': 'cv',
'layer8': 'cd',
'layer9': 'cd',
'layer10': 'cd',
'layer11': 'cv',
'layer12': 'cd',
'layer13': 'cd',
'layer14': 'cd',
'layer15': 'cv',
},
'aaav4': {
'layer0': 'ad',
'layer1': 'ad',
'layer2': 'ad',
'layer3': 'cv',
'layer4': 'ad',
'layer5': 'ad',
'layer6': 'ad',
'layer7': 'cv',
'layer8': 'ad',
'layer9': 'ad',
'layer10': 'ad',
'layer11': 'cv',
'layer12': 'ad',
'layer13': 'ad',
'layer14': 'ad',
'layer15': 'cv',
},
'rrrv4': {
'layer0': 'rd',
'layer1': 'rd',
'layer2': 'rd',
'layer3': 'cv',
'layer4': 'rd',
'layer5': 'rd',
'layer6': 'rd',
'layer7': 'cv',
'layer8': 'rd',
'layer9': 'rd',
'layer10': 'rd',
'layer11': 'cv',
'layer12': 'rd',
'layer13': 'rd',
'layer14': 'rd',
'layer15': 'cv',
},
'c16': {
'layer0': 'cd',
'layer1': 'cd',
'layer2': 'cd',
'layer3': 'cd',
'layer4': 'cd',
'layer5': 'cd',
'layer6': 'cd',
'layer7': 'cd',
'layer8': 'cd',
'layer9': 'cd',
'layer10': 'cd',
'layer11': 'cd',
'layer12': 'cd',
'layer13': 'cd',
'layer14': 'cd',
'layer15': 'cd',
},
'a16': {
'layer0': 'ad',
'layer1': 'ad',
'layer2': 'ad',
'layer3': 'ad',
'layer4': 'ad',
'layer5': 'ad',
'layer6': 'ad',
'layer7': 'ad',
'layer8': 'ad',
'layer9': 'ad',
'layer10': 'ad',
'layer11': 'ad',
'layer12': 'ad',
'layer13': 'ad',
'layer14': 'ad',
'layer15': 'ad',
},
'r16': {
'layer0': 'rd',
'layer1': 'rd',
'layer2': 'rd',
'layer3': 'rd',
'layer4': 'rd',
'layer5': 'rd',
'layer6': 'rd',
'layer7': 'rd',
'layer8': 'rd',
'layer9': 'rd',
'layer10': 'rd',
'layer11': 'rd',
'layer12': 'rd',
'layer13': 'rd',
'layer14': 'rd',
'layer15': 'rd',
},
'carv4': {
'layer0': 'cd',
'layer1': 'ad',
'layer2': 'rd',
'layer3': 'cv',
'layer4': 'cd',
'layer5': 'ad',
'layer6': 'rd',
'layer7': 'cv',
'layer8': 'cd',
'layer9': 'ad',
'layer10': 'rd',
'layer11': 'cv',
'layer12': 'cd',
'layer13': 'ad',
'layer14': 'rd',
'layer15': 'cv'
}
}
def create_conv_func(op_type):
assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
if op_type == 'cv':
return F.conv2d
if op_type == 'cd':
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
assert padding == dilation, 'padding for cd_conv set wrong'
weights_c = weights.sum(dim=[2, 3], keepdim=True)
yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
return y - yc
return func
elif op_type == 'ad':
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
assert padding == dilation, 'padding for ad_conv set wrong'
shape = weights.shape
weights = weights.view(shape[0], shape[1], -1)
weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
return y
return func
elif op_type == 'rd':
def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
padding = 2 * dilation
shape = weights.shape
if weights.is_cuda:
buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
else:
buffer = torch.zeros(shape[0], shape[1], 5 * 5)
weights = weights.view(shape[0], shape[1], -1)
buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
buffer[:, :, 12] = 0
buffer = buffer.view(shape[0], shape[1], 5, 5)
y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
return y
return func
else:
print('impossible to be here unless you force that', flush=True)
return None
def config_model(model):
model_options = list(CONFIGS.keys())
assert model in model_options, \
'unrecognized model, please choose from %s' % str(model_options)
pdcs = []
for i in range(16):
layer_name = 'layer%d' % i
op = CONFIGS[model][layer_name]
pdcs.append(create_conv_func(op))
return pdcs
def config_model_converted(model):
model_options = list(CONFIGS.keys())
assert model in model_options, \
'unrecognized model, please choose from %s' % str(model_options)
pdcs = []
for i in range(16):
layer_name = 'layer%d' % i
op = CONFIGS[model][layer_name]
pdcs.append(op)
return pdcs
def convert_pdc(op, weight):
if op == 'cv':
return weight
elif op == 'cd':
shape = weight.shape
weight_c = weight.sum(dim=[2, 3])
weight = weight.view(shape[0], shape[1], -1)
weight[:, :, 4] = weight[:, :, 4] - weight_c
weight = weight.view(shape)
return weight
elif op == 'ad':
shape = weight.shape
weight = weight.view(shape[0], shape[1], -1)
weight_conv = (weight - weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape)
return weight_conv
elif op == 'rd':
shape = weight.shape
buffer = torch.zeros(shape[0], shape[1], 5 * 5, device=weight.device)
weight = weight.view(shape[0], shape[1], -1)
buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weight[:, :, 1:]
buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weight[:, :, 1:]
buffer = buffer.view(shape[0], shape[1], 5, 5)
return buffer
raise ValueError("wrong op {}".format(str(op)))
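# --- Sanity-check sketch (not part of the original file) ---
# A hedged illustration of the equivalence used by convert_pidinet below:
# applying the 'cd' pixel-difference convolution directly should match a vanilla
# conv2d whose weights were rewritten by convert_pdc. Shapes and seed are arbitrary.
def _check_cd_equivalence():
    torch.manual_seed(0)
    x = torch.randn(2, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    cd_conv = create_conv_func('cd')
    y_pdc = cd_conv(x, w, padding=1, dilation=1)
    # convert_pdc('cd', ...) modifies the weight in place, so pass a clone
    y_plain = F.conv2d(x, convert_pdc('cd', w.clone()), padding=1)
    assert torch.allclose(y_pdc, y_plain, atol=1e-5)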
def convert_pidinet(state_dict, config):
pdcs = config_model_converted(config)
new_dict = {}
for pname, p in state_dict.items():
if 'init_block.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[0], p)
elif 'block1_1.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[1], p)
elif 'block1_2.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[2], p)
elif 'block1_3.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[3], p)
elif 'block2_1.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[4], p)
elif 'block2_2.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[5], p)
elif 'block2_3.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[6], p)
elif 'block2_4.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[7], p)
elif 'block3_1.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[8], p)
elif 'block3_2.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[9], p)
elif 'block3_3.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[10], p)
elif 'block3_4.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[11], p)
elif 'block4_1.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[12], p)
elif 'block4_2.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[13], p)
elif 'block4_3.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[14], p)
elif 'block4_4.conv1.weight' in pname:
new_dict[pname] = convert_pdc(pdcs[15], p)
else:
new_dict[pname] = p
return new_dict
class Conv2d(nn.Module):
def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
super(Conv2d, self).__init__()
if in_channels % groups != 0:
raise ValueError('in_channels must be divisible by groups')
if out_channels % groups != 0:
raise ValueError('out_channels must be divisible by groups')
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self.pdc = pdc
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input):
return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
class CSAM(nn.Module):
r"""
Compact Spatial Attention Module
"""
def __init__(self, channels):
super(CSAM, self).__init__()
mid_channels = 4
self.relu1 = nn.ReLU()
self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
self.sigmoid = nn.Sigmoid()
nn.init.constant_(self.conv1.bias, 0)
def forward(self, x):
y = self.relu1(x)
y = self.conv1(y)
y = self.conv2(y)
y = self.sigmoid(y)
return x * y
class CDCM(nn.Module):
r"""
Compact Dilation Convolution based Module
"""
def __init__(self, in_channels, out_channels):
super(CDCM, self).__init__()
self.relu1 = nn.ReLU()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
nn.init.constant_(self.conv1.bias, 0)
def forward(self, x):
x = self.relu1(x)
x = self.conv1(x)
x1 = self.conv2_1(x)
x2 = self.conv2_2(x)
x3 = self.conv2_3(x)
x4 = self.conv2_4(x)
return x1 + x2 + x3 + x4
class MapReduce(nn.Module):
r"""
Reduce feature maps into a single edge map
"""
def __init__(self, channels):
super(MapReduce, self).__init__()
self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
nn.init.constant_(self.conv.bias, 0)
def forward(self, x):
return self.conv(x)
class PDCBlock(nn.Module):
def __init__(self, pdc, inplane, ouplane, stride=1):
super(PDCBlock, self).__init__()
        self.stride = stride
if self.stride > 1:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
self.relu2 = nn.ReLU()
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.stride > 1:
x = self.pool(x)
y = self.conv1(x)
y = self.relu2(y)
y = self.conv2(y)
if self.stride > 1:
x = self.shortcut(x)
y = y + x
return y
class PDCBlock_converted(nn.Module):
r"""
CPDC, APDC can be converted to vanilla 3x3 convolution
RPDC can be converted to vanilla 5x5 convolution
"""
def __init__(self, pdc, inplane, ouplane, stride=1):
super(PDCBlock_converted, self).__init__()
self.stride=stride
if self.stride > 1:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
if pdc == 'rd':
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
else:
self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
self.relu2 = nn.ReLU()
self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
def forward(self, x):
if self.stride > 1:
x = self.pool(x)
y = self.conv1(x)
y = self.relu2(y)
y = self.conv2(y)
if self.stride > 1:
x = self.shortcut(x)
y = y + x
return y
class PiDiNet(nn.Module):
def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
super(PiDiNet, self).__init__()
self.sa = sa
if dil is not None:
assert isinstance(dil, int), 'dil should be an int'
self.dil = dil
self.fuseplanes = []
self.inplane = inplane
if convert:
if pdcs[0] == 'rd':
init_kernel_size = 5
init_padding = 2
else:
init_kernel_size = 3
init_padding = 1
self.init_block = nn.Conv2d(3, self.inplane,
kernel_size=init_kernel_size, padding=init_padding, bias=False)
block_class = PDCBlock_converted
else:
self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
block_class = PDCBlock
self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # C
inplane = self.inplane
self.inplane = self.inplane * 2
self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # 2C
inplane = self.inplane
self.inplane = self.inplane * 2
self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # 4C
self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
self.fuseplanes.append(self.inplane) # 4C
self.conv_reduces = nn.ModuleList()
if self.sa and self.dil is not None:
self.attentions = nn.ModuleList()
self.dilations = nn.ModuleList()
for i in range(4):
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
self.attentions.append(CSAM(self.dil))
self.conv_reduces.append(MapReduce(self.dil))
elif self.sa:
self.attentions = nn.ModuleList()
for i in range(4):
self.attentions.append(CSAM(self.fuseplanes[i]))
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
elif self.dil is not None:
self.dilations = nn.ModuleList()
for i in range(4):
self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
self.conv_reduces.append(MapReduce(self.dil))
else:
for i in range(4):
self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
nn.init.constant_(self.classifier.weight, 0.25)
nn.init.constant_(self.classifier.bias, 0)
def get_weights(self):
conv_weights = []
bn_weights = []
relu_weights = []
for pname, p in self.named_parameters():
if 'bn' in pname:
bn_weights.append(p)
elif 'relu' in pname:
relu_weights.append(p)
else:
conv_weights.append(p)
return conv_weights, bn_weights, relu_weights
def forward(self, x):
H, W = x.size()[2:]
x = self.init_block(x)
x1 = self.block1_1(x)
x1 = self.block1_2(x1)
x1 = self.block1_3(x1)
x2 = self.block2_1(x1)
x2 = self.block2_2(x2)
x2 = self.block2_3(x2)
x2 = self.block2_4(x2)
x3 = self.block3_1(x2)
x3 = self.block3_2(x3)
x3 = self.block3_3(x3)
x3 = self.block3_4(x3)
x4 = self.block4_1(x3)
x4 = self.block4_2(x4)
x4 = self.block4_3(x4)
x4 = self.block4_4(x4)
x_fuses = []
if self.sa and self.dil is not None:
for i, xi in enumerate([x1, x2, x3, x4]):
x_fuses.append(self.attentions[i](self.dilations[i](xi)))
elif self.sa:
for i, xi in enumerate([x1, x2, x3, x4]):
x_fuses.append(self.attentions[i](xi))
elif self.dil is not None:
for i, xi in enumerate([x1, x2, x3, x4]):
x_fuses.append(self.dilations[i](xi))
else:
x_fuses = [x1, x2, x3, x4]
e1 = self.conv_reduces[0](x_fuses[0])
e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
e2 = self.conv_reduces[1](x_fuses[1])
e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
e3 = self.conv_reduces[2](x_fuses[2])
e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
e4 = self.conv_reduces[3](x_fuses[3])
e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
outputs = [e1, e2, e3, e4]
output = self.classifier(torch.cat(outputs, dim=1))
outputs.append(output)
outputs = [torch.sigmoid(r) for r in outputs]
return outputs[-1]
def pidinet_bsd_tiny(pretrained=False, vanilla_cnn=True):
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model('carv4')
model = PiDiNet(20, pdcs, dil=8, sa=True, convert=vanilla_cnn)
if pretrained:
state = torch.load(
DOWNLOAD_TO_CACHE(f'models/pidinet/table5_pidinet-tiny.pth'),
map_location='cpu')['state_dict']
if vanilla_cnn:
state = convert_pidinet(state, 'carv4')
state = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state.items()}
model.load_state_dict(state)
return model
def pidinet_bsd_small(pretrained=False, vanilla_cnn=True):
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model('carv4')
model = PiDiNet(30, pdcs, dil=12, sa=True, convert=vanilla_cnn)
if pretrained:
state = torch.load(
DOWNLOAD_TO_CACHE(f'models/pidinet/table5_pidinet-small.pth'),
map_location='cpu')['state_dict']
if vanilla_cnn:
state = convert_pidinet(state, 'carv4')
state = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state.items()}
model.load_state_dict(state)
return model
def pidinet_bsd(pretrained=False, vanilla_cnn=True):
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model('carv4')
model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
if pretrained:
# state = torch.load(
# DOWNLOAD_TO_CACHE(f'models/pidinet/table5_pidinet.pth'),
# map_location='cpu')['state_dict']
state = torch.load(
DOWNLOAD_TO_CACHE(f'VideoComposer/Hangjie/models/pidinet/table5_pidinet.pth'),
map_location='cpu')['state_dict']
if vanilla_cnn:
state = convert_pidinet(state, 'carv4')
state = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state.items()}
model.load_state_dict(state)
return model
def pidinet_nyud(pretrained=False, vanilla_cnn=True):
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model('carv4')
model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
if pretrained:
state = torch.load(
DOWNLOAD_TO_CACHE(f'models/pidinet/table6_pidinet.pth'),
map_location='cpu')['state_dict']
if vanilla_cnn:
state = convert_pidinet(state, 'carv4')
state = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state.items()}
model.load_state_dict(state)
return model
def pidinet_multicue(pretrained=False, vanilla_cnn=True):
pdcs = config_model_converted('carv4') if vanilla_cnn else config_model('carv4')
model = PiDiNet(60, pdcs, dil=24, sa=True, convert=vanilla_cnn)
if pretrained:
state = torch.load(
DOWNLOAD_TO_CACHE(f'models/pidinet/table7_pidinet.pth'),
map_location='cpu')['state_dict']
if vanilla_cnn:
state = convert_pidinet(state, 'carv4')
state = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in state.items()}
model.load_state_dict(state)
return model
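# --- Usage sketch (not part of the original file) ---
# A hedged example of running PiDiNet on one image, using the ImageNet-style
# normalization quoted in the module docstring above. pretrained=False avoids
# the DOWNLOAD_TO_CACHE weight download; the input is a placeholder random
# image, and running this still assumes the surrounding repo imports resolve.
if __name__ == '__main__':
    import numpy as np
    import torchvision.transforms as T
    from PIL import Image

    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
    x = transform(img).unsqueeze(0)                       # (1, 3, H, W)
    model = pidinet_bsd(pretrained=False, vanilla_cnn=True).eval()
    with torch.no_grad():
        edge = model(x)                                   # (1, 1, H, W), values in [0, 1]
    print(edge.shape)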
r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# from canvas import DOWNLOAD_TO_CACHE
from artist import DOWNLOAD_TO_CACHE
__all__ = ['SketchSimplification', 'sketch_simplification_gan', 'sketch_simplification_mse',
'sketch_to_pencil_v1', 'sketch_to_pencil_v2']
class SketchSimplification(nn.Module):
r"""NOTE:
    1. The input image should have a single gray channel.
    2. The input image size should be divisible by 8.
    3. Sketch strokes in the input/output images are dark, while the background is light.
"""
def __init__(self, mean, std):
assert isinstance(mean, float) and isinstance(std, float)
super(SketchSimplification, self).__init__()
self.mean = mean
self.std = std
# layers
self.layers = nn.Sequential(
nn.Conv2d(1, 48, 5, 2, 2),
nn.ReLU(inplace=True),
nn.Conv2d(48, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 256, 4, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 128, 4, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 48, 3, 1, 1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(48, 48, 4, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(48, 24, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(24, 1, 3, 1, 1),
nn.Sigmoid())
def forward(self, x):
r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.
"""
x = (x - self.mean) / self.std
return self.layers(x)
def sketch_simplification_gan(pretrained=False):
model = SketchSimplification(mean=0.9664114577640158, std=0.0858381272736797)
if pretrained:
# model.load_state_dict(torch.load(
# DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_gan.pth'),
# map_location='cpu'))
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('VideoComposer/Hangjie/models/sketch_simplification/sketch_simplification_gan.pth'),
map_location='cpu'))
return model
def sketch_simplification_mse(pretrained=False):
model = SketchSimplification(mean=0.9664423107454593, std=0.08583666033640507)
if pretrained:
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_mse.pth'),
map_location='cpu'))
return model
def sketch_to_pencil_v1(pretrained=False):
model = SketchSimplification(mean=0.9817833515894078, std=0.0925009022585048)
if pretrained:
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v1.pth'),
map_location='cpu'))
return model
def sketch_to_pencil_v2(pretrained=False):
model = SketchSimplification(mean=0.9851298627337799, std=0.07418377454883571)
if pretrained:
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v2.pth'),
map_location='cpu'))
return model
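# --- Usage sketch (not part of the original file) ---
# A hedged example of running the sketch simplification network on a grayscale
# image, following the constraints in the class docstring: one gray channel,
# spatial size divisible by 8, dark strokes on a light background. The input is
# a synthetic placeholder; pretrained=False avoids downloading weights, and the
# module-level imports are assumed to resolve.
if __name__ == '__main__':
    model = sketch_simplification_gan(pretrained=False).eval()
    x = torch.ones(1, 1, 256, 256)         # white background ...
    x[:, :, 100:160, 100:160] = 0.0        # ... with a dark square as a stand-in "sketch"
    with torch.no_grad():
        y = model(x)                       # (1, 1, 256, 256), values in [0, 1]
    print(y.shape)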
import numpy as np
import cv2
import os
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img
from .pretrain_functions import *