Commit 1ad55bb4 authored by mashun1

i2vgen-xl
import os
import sys
import json
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from thop import profile
from ptflops import get_model_complexity_info
import artist.data as data
from tools.modules.config import cfg
from tools.modules.unet.util import *
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, MODEL
def save_temporal_key():
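    # Collect the parameter names that belong to temporal modules of the UNet
    # (TemporalTransformer, TemporalAttention*, TemporalConvBlock*, ...) and dump them
    # to a JSON key list, presumably for selectively loading or freezing these weights.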
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
model = MODEL.build(cfg.UNet)
temp_name = ''
temp_key_list = []
spth = 'workspace/module_list/UNetSD_I2V_vs_Text_temporal_key_list.json'
for name, module in model.named_modules():
if isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
temp_name = name
print(f'Model: {name}')
elif isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
temp_name = ''
if hasattr(module, 'weight'):
if temp_name != '' and (temp_name in name):
temp_key_list.append(name)
print(f'{name}')
# print(name)
save_module_list = []
for k, p in model.named_parameters():
for item in temp_key_list:
if item in k:
print(f'{item} --> {k}')
save_module_list.append(k)
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
# spth = 'workspace/module_list/{}'
json.dump(save_module_list, open(spth, 'w'))
a = 0
def save_spatial_key():
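    # Same as save_temporal_key, but collects the keys of spatial modules
    # (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample).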
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
model = MODEL.build(cfg.UNet)
temp_name = ''
temp_key_list = []
spth = 'workspace/module_list/UNetSD_I2V_HQ_P_spatial_key_list.json'
for name, module in model.named_modules():
if isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)):
temp_name = name
print(f'Model: {name}')
elif isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)):
temp_name = ''
if hasattr(module, 'weight'):
if temp_name != '' and (temp_name in name):
temp_key_list.append(name)
print(f'{name}')
# print(name)
save_module_list = []
for k, p in model.named_parameters():
for item in temp_key_list:
if item in k:
print(f'{item} --> {k}')
save_module_list.append(k)
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
# spth = 'workspace/module_list/{}'
json.dump(save_module_list, open(spth, 'w'))
a = 0
if __name__ == '__main__':
# save_temporal_key()
save_spatial_key()
# print([k for (k, _) in self.input_blocks.named_parameters()])
import os
import sys
import cv2
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from PIL import Image, ImageDraw, ImageFont
from einops import rearrange
from tools import *
import utils.transforms as data
from utils.seed import setup_seed
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER
def test_enc_dec(gpu=0):
setup_seed(0)
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type'])
os.system('rm -rf %s' % (save_dir))
os.makedirs(save_dir, exist_ok=True)
train_trans = data.Compose([
data.CenterCropWide(size=cfg.resolution),
data.ToTensor(),
data.Normalize(mean=cfg.mean, std=cfg.std)])
vit_trans = data.Compose([
data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution),
data.Resize(cfg.vit_resolution),
data.ToTensor(),
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
txt_size = cfg.resolution[1]
nc = int(38 * (txt_size / 256))
font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans)
print('There are %d videos' % (len(dataset)))
autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
autoencoder.eval() # freeze
for param in autoencoder.parameters():
param.requires_grad = False
autoencoder.to(gpu)
for idx, item in enumerate(dataset):
local_path = os.path.join(save_dir, '%04d.mp4' % idx)
# ref_frame, video_data, caption = item
ref_frame, vit_frame, video_data = item[:3]
video_data = video_data.to(gpu)
image_list = []
video_data_list = torch.chunk(video_data, video_data.shape[0]//cfg.chunk_size,dim=0)
with torch.no_grad():
decode_data = []
for chunk_data in video_data_list:
latent_z = autoencoder.encode_firsr_stage(chunk_data).detach()
# latent_z = get_first_stage_encoding(encoder_posterior).detach()
kwargs = {"timesteps": chunk_data.shape[0]}
recons_data = autoencoder.decode(latent_z, **kwargs)
vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu()
vis_data = vis_data.mul_(video_std).add_(video_mean) # 8x3x16x256x384
vis_data = vis_data.cpu()
vis_data.clamp_(0, 1)
vis_data = vis_data.permute(0, 2, 3, 1)
vis_data = [(image.numpy() * 255).astype('uint8') for image in vis_data]
image_list.extend(vis_data)
num_image = len(image_list)
frame_dir = os.path.join(save_dir, 'temp')
os.makedirs(frame_dir, exist_ok=True)
for idx in range(num_image):
tpth = os.path.join(frame_dir, '%04d.png' % (idx+1))
            cv2.imwrite(tpth, image_list[idx][:, :, ::-1])  # BGR order for OpenCV; PNG output, so JPEG quality flags are not needed
cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
os.system(cmd); os.system(f'rm -rf {frame_dir}')
if __name__ == '__main__':
test_enc_dec()
import os
import sys
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as T
import utils.transforms as data
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS
from tools import *
def test_video_dataset():
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name)
os.system('rm -rf %s' % (save_dir))
os.makedirs(save_dir, exist_ok=True)
train_trans = data.Compose([
data.CenterCropWide(size=cfg.resolution),
data.ToTensor(),
data.Normalize(mean=cfg.mean, std=cfg.std)])
vit_trans = T.Compose([
data.CenterCropWide(cfg.vit_resolution),
T.ToTensor(),
T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w
video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w
img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
txt_size = cfg.resolution[1]
nc = int(38 * (txt_size / 256))
font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)
dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0], transforms=train_trans, vit_transforms=vit_trans)
print('There are %d videos' % (len(dataset)))
for idx, item in enumerate(dataset):
ref_frame, vit_frame, video_data, caption, video_key = item
video_data = video_data.mul_(video_std).add_(video_mean)
video_data.clamp_(0, 1)
video_data = video_data.permute(0, 2, 3, 1)
video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
# Single Image
        ref_frame = ref_frame.mul_(img_std).add_(img_mean)
ref_frame.clamp_(0, 1)
ref_frame = ref_frame.permute(1, 2, 0)
ref_frame = (ref_frame.numpy() * 255).astype('uint8')
# Text image
txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
draw = ImageDraw.Draw(txt_img)
lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
draw.text((0, 0), lines, fill="black", font=font)
txt_img = np.array(txt_img)
video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data]
spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)
# if idx > 100: break
def test_vit_image(test_video_flag=True):
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name)
os.system('rm -rf %s' % (save_dir))
os.makedirs(save_dir, exist_ok=True)
train_trans = data.Compose([
data.CenterCropWide(size=cfg.resolution),
data.ToTensor(),
data.Normalize(mean=cfg.mean, std=cfg.std)])
vit_trans = data.Compose([
data.CenterCropWide(cfg.resolution),
data.Resize(cfg.vit_resolution),
data.ToTensor(),
data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w
img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w
vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w
vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w
txt_size = cfg.resolution[1]
nc = int(38 * (txt_size / 256))
font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13)
dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans)
print('There are %d videos' % (len(dataset)))
for idx, item in enumerate(dataset):
ref_frame, vit_frame, video_data, caption, video_key = item
video_data = video_data.mul_(img_std).add_(img_mean)
video_data.clamp_(0, 1)
video_data = video_data.permute(0, 2, 3, 1)
video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]
# Single Image
vit_frame = vit_frame.mul_(vit_std).add_(vit_mean)
vit_frame.clamp_(0, 1)
vit_frame = vit_frame.permute(1, 2, 0)
vit_frame = (vit_frame.numpy() * 255).astype('uint8')
zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8)
zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame
# Text image
txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
draw = ImageDraw.Draw(txt_img)
lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
draw.text((0, 0), lines, fill="black", font=font)
txt_img = np.array(txt_img)
video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data]
spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)
# if idx > 100: break
if __name__ == '__main__':
# test_video_dataset()
test_vit_image()
import os
import sys
import torch
import imageio
import numpy as np
import os.path as osp
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))
from thop import profile
from ptflops import get_model_complexity_info
import artist.data as data
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, MODEL
def test_model():
cfg_update = pConfig(load=True)
for k, v in cfg_update.cfg_dict.items():
if isinstance(v, dict) and k in cfg:
cfg[k].update(v)
else:
cfg[k] = v
model = MODEL.build(cfg.UNet)
print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters')
# state_dict = torch.load('cache/pretrain_model/jiuniu_0600000.pth', map_location='cpu')
# model.load_state_dict(state_dict, strict=False)
model = model.cuda()
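    # Dummy inputs for a single forward pass; shapes follow the UNet's expected inputs:
    # x is a latent video [B, C=4, F=16, H/8=32, W/8=56] (i.e. a 256x448 clip), t a timestep,
    # y a 1024-d conditioning embedding, and `image` the reference frame.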
x = torch.Tensor(1, 4, 16, 32, 56).cuda()
t = torch.Tensor(1).cuda()
sims = torch.Tensor(1, 32).cuda()
fps = torch.Tensor([8]).cuda()
y = torch.Tensor(1, 1, 1024).cuda()
image = torch.Tensor(1, 3, 256, 448).cuda()
ret = model(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps)
    print('Out shape is {}'.format(ret.shape))
# flops, params = profile(model=model, inputs=(x, t, y, image, sims, fps))
# print('Model: {:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6))
    def prepare_input(resolution):
        # ptflops passes this dict as keyword arguments to model.forward
        return dict(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps)
flops, params = get_model_complexity_info(model, (1, 4, 16, 32, 56),
input_constructor = prepare_input,
as_strings=True, print_per_layer_stat=True)
print(' - Flops: ' + flops)
print(' - Params: ' + params)
if __name__ == '__main__':
test_model()
import numpy as np
import cv2
cap = cv2.VideoCapture('workspace/img_dir/tst.mp4')
fourcc = cv2.VideoWriter_fourcc(*'H264')
ret, frame = cap.read()
vid_size = frame.shape[:2][::-1]
out = cv2.VideoWriter('workspace/img_dir/testwrite.mp4',fourcc, 8, vid_size)
out.write(frame)
while cap.isOpened():
ret, frame = cap.read()
if not ret: break
out.write(frame)
cap.release()
out.release()
from .annotator import *
from .datasets import *
from .modules import *
from .train import *
from .hooks import *
from .inferences import *
# from .prior import *
from .basic_funcs import *
import cv2
import torch
import numpy as np
from tools.annotator.util import HWC3
# import gradio as gr
class CannyDetector:
    def __call__(self, img, low_threshold=None, high_threshold=None, random_threshold=True):
### GPT-4 suggestions
# In the cv2.Canny() function, the low threshold and high threshold are used to determine the edges based on the gradient values in the image.
# There isn't a one-size-fits-all solution for these threshold values, as the optimal values depend on the specific image and the application.
# However, there are some general guidelines and empirical values you can use as a starting point:
# 1. Ratio: A common recommendation is to use a ratio of 1:2 or 1:3 between the low threshold and the high threshold.
# This means if your low threshold is 50, the high threshold should be around 100 or 150.
# 2. Empirical values: As a starting point, you can use low threshold values in the range of 50-100 and high threshold values in the range of 100-200.
# You may need to fine-tune these values based on the specific image and desired edge detection results.
# 3. Automatic threshold calculation: To automatically calculate the threshold values, you can use the median or mean value of the image's pixel intensities as the low threshold,
# and the high threshold can be set as twice or three times the low threshold.
### Convert to numpy
if isinstance(img, torch.Tensor): # (h, w, c)
img = img.cpu().numpy()
img_np = cv2.convertScaleAbs((img * 255.))
elif isinstance(img, np.ndarray): # (h, w, c)
img_np = img # we assume values are in the range from 0 to 255.
else:
assert False
### Select the threshold
if (low_threshold is None) and (high_threshold is None):
median_intensity = np.median(img_np)
if random_threshold is False:
low_threshold = int(max(0, (1 - 0.33) * median_intensity))
high_threshold = int(min(255, (1 + 0.33) * median_intensity))
else:
random_canny = np.random.uniform(0.1, 0.4)
# Might try other values
low_threshold = int(max(0, (1 - random_canny) * median_intensity))
high_threshold = 2 * low_threshold
### Detect canny edge
canny_edge = cv2.Canny(img_np, low_threshold, high_threshold)
### Convert to 3 channels
# canny_edge = HWC3(canny_edge)
        canny_condition = torch.from_numpy(canny_edge.copy()).unsqueeze(dim=-1).float().cuda() / 255.0
# canny_condition = torch.stack([canny_condition for _ in range(num_samples)], dim=0)
# canny_condition = einops.rearrange(canny_condition, 'h w c -> b c h w').clone()
# return cv2.Canny(img, low_threshold, high_threshold)
return canny_condition
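# A minimal usage sketch (not part of the original file): assumes a CUDA device and an
# HxWx3 uint8 frame; the detector returns an (H, W, 1) float tensor on the GPU in [0, 1].
#
#     detector = CannyDetector()
#     frame = (np.random.rand(256, 448, 3) * 255).astype(np.uint8)   # hypothetical test frame
#     edges = detector(frame)                                         # random median-based thresholds
#     edges_fixed = detector(frame, low_threshold=100, high_threshold=200, random_threshold=False)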
from .palette import *
r"""Modified from ``https://github.com/sergeyk/rayleigh''.
"""
import os
import os.path as osp
import numpy as np
from skimage.color import hsv2rgb, rgb2lab, lab2rgb
from skimage.io import imsave
from sklearn.metrics import euclidean_distances
__all__ = ['Palette']
def rgb2hex(rgb):
return '#%02x%02x%02x' % tuple([int(round(255.0 * u)) for u in rgb])
def hex2rgb(hex):
rgb = hex.strip('#')
fn = lambda u: round(int(u, 16) / 255.0, 5)
return fn(rgb[:2]), fn(rgb[2:4]), fn(rgb[4:6])
class Palette(object):
r"""Create a color palette (codebook) in the form of a 2D grid of colors.
Further, the rightmost column has num_hues gradations from black to white.
Parameters:
num_hues: number of colors with full lightness and saturation, in the middle.
num_sat: number of rows above middle row that show the same hues with decreasing saturation.
"""
def __init__(self, num_hues=11, num_sat=5, num_light=4):
n = num_sat + 2 * num_light
# hues
if num_hues == 8:
hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.51, 0.58, 0.77, 0.85]), (n, 1))
elif num_hues == 9:
hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.7, 0.87]), (n, 1))
elif num_hues == 10:
hues = np.tile(np.array([0., 0.10, 0.15, 0.28, 0.49, 0.54, 0.60, 0.66, 0.76, 0.87]), (n, 1))
elif num_hues == 11:
hues = np.tile(np.array([0.0, 0.0833, 0.166, 0.25, 0.333, 0.5, 0.56333, 0.666, 0.73, 0.803, 0.916]), (n, 1))
else:
hues = np.tile(np.linspace(0, 1, num_hues + 1)[:-1], (n, 1))
# saturations
sats = np.hstack((
np.linspace(0, 1, num_sat + 2)[1:-1],
1,
[1] * num_light,
[0.4] * (num_light - 1)))
sats = np.tile(np.atleast_2d(sats).T, (1, num_hues))
# lights
lights = np.hstack((
[1] * num_sat,
1,
np.linspace(1, 0.2, num_light + 2)[1:-1],
np.linspace(1, 0.2, num_light + 2)[1:-2]))
lights = np.tile(np.atleast_2d(lights).T, (1, num_hues))
# colors
rgb = hsv2rgb(np.dstack([hues, sats, lights]))
gray = np.tile(np.linspace(1, 0, n)[:, np.newaxis, np.newaxis], (1, 1, 3))
self.thumbnail = np.hstack([rgb, gray])
# flatten
rgb = rgb.T.reshape(3, -1).T
gray = gray.T.reshape(3, -1).T
self.rgb = np.vstack((rgb, gray))
self.lab = rgb2lab(self.rgb[np.newaxis, :, :]).squeeze()
self.hex = [rgb2hex(u) for u in self.rgb]
self.lab_dists = euclidean_distances(self.lab, squared=True)
def histogram(self, rgb_img, sigma=20):
# compute histogram
lab = rgb2lab(rgb_img).reshape((-1, 3))
min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
hist = 1.0 * np.bincount(min_ind, minlength=self.lab.shape[0]) / lab.shape[0]
# smooth histogram
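        # (Gaussian kernel over squared Lab distances: w_ij = exp(-d_ij / (2 * sigma^2)),
        #  rows normalized to sum to 1, then hist <- W @ hist)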
if sigma > 0:
weight = np.exp(-self.lab_dists / (2.0 * sigma ** 2))
weight = weight / weight.sum(1)[:, np.newaxis]
hist = (weight * hist).sum(1)
hist[hist < 1e-5] = 0
return hist
def get_palette_image(self, hist, percentile=90, width=200, height=50):
# curate histogram
ind = np.argsort(-hist)
ind = ind[hist[ind] > np.percentile(hist, percentile)]
hist = hist[ind] / hist[ind].sum()
# draw palette
nums = np.array(hist * width, dtype=int)
array = np.vstack([np.tile(np.array(u), (v, 1)) for u, v in zip(self.rgb[ind], nums)])
array = np.tile(array[np.newaxis, :, :], (height, 1, 1))
if array.shape[1] < width:
array = np.concatenate([array, np.zeros((height, width - array.shape[1], 3))], axis=1)
return array
def quantize_image(self, rgb_img):
lab = rgb2lab(rgb_img).reshape((-1, 3))
min_ind = np.argmin(euclidean_distances(lab, self.lab, squared=True), axis=1)
quantized_lab = self.lab[min_ind]
img = lab2rgb(quantized_lab.reshape(rgb_img.shape))
return img
def export(self, dirname):
if not osp.exists(dirname):
os.makedirs(dirname)
# save thumbnail
imsave(osp.join(dirname, 'palette.png'), self.thumbnail)
# save html
with open(osp.join(dirname, 'palette.html'), 'w') as f:
html = '''
<style>
span {
width: 20px;
height: 20px;
margin: 2px;
padding: 0px;
display: inline-block;
}
</style>
'''
for row in self.thumbnail:
for col in row:
html += '<a id="{0}"><span style="background-color: {0}" /></a>\n'.format(rgb2hex(col))
html += '<br />\n'
f.write(html)
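# A minimal usage sketch (assumptions: a scikit-image-style float RGB array `img` in [0, 1]
# with shape (H, W, 3); `img` and the output directory are hypothetical):
#
#     palette = Palette(num_hues=11, num_sat=5, num_light=4)
#     hist = palette.histogram(img, sigma=20)        # smoothed color histogram over the codebook
#     strip = palette.get_palette_image(hist)        # (50, 200, 3) dominant-color strip
#     quantized = palette.quantize_image(img)        # img mapped to nearest palette colors
#     palette.export('workspace/palette_vis')        # writes palette.png and palette.html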
from .pidinet import *
from .sketch_simplification import *
r"""PyTorch re-implementation adapted from the Lua code in ``https://github.com/bobbens/sketch_simplification''.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# from canvas import DOWNLOAD_TO_CACHE
from artist import DOWNLOAD_TO_CACHE
__all__ = ['SketchSimplification', 'sketch_simplification_gan', 'sketch_simplification_mse',
'sketch_to_pencil_v1', 'sketch_to_pencil_v2']
class SketchSimplification(nn.Module):
r"""NOTE:
1. Input image should has only one gray channel.
2. Input image size should be divisible by 8.
3. Sketch in the input/output image is in dark color while background in light color.
"""
def __init__(self, mean, std):
assert isinstance(mean, float) and isinstance(std, float)
super(SketchSimplification, self).__init__()
self.mean = mean
self.std = std
# layers
self.layers = nn.Sequential(
nn.Conv2d(1, 48, 5, 2, 2),
nn.ReLU(inplace=True),
nn.Conv2d(48, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 1024, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(1024, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 256, 4, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 128, 4, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 48, 3, 1, 1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(48, 48, 4, 2, 1),
nn.ReLU(inplace=True),
nn.Conv2d(48, 24, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(24, 1, 3, 1, 1),
nn.Sigmoid())
def forward(self, x):
r"""x: [B, 1, H, W] within range [0, 1]. Sketch pixels in dark color.
"""
x = (x - self.mean) / self.std
return self.layers(x)
def sketch_simplification_gan(pretrained=False):
model = SketchSimplification(mean=0.9664114577640158, std=0.0858381272736797)
if pretrained:
# model.load_state_dict(torch.load(
# DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_gan.pth'),
# map_location='cpu'))
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('VideoComposer/Hangjie/models/sketch_simplification/sketch_simplification_gan.pth'),
map_location='cpu'))
return model
def sketch_simplification_mse(pretrained=False):
model = SketchSimplification(mean=0.9664423107454593, std=0.08583666033640507)
if pretrained:
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_simplification_mse.pth'),
map_location='cpu'))
return model
def sketch_to_pencil_v1(pretrained=False):
model = SketchSimplification(mean=0.9817833515894078, std=0.0925009022585048)
if pretrained:
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v1.pth'),
map_location='cpu'))
return model
def sketch_to_pencil_v2(pretrained=False):
model = SketchSimplification(mean=0.9851298627337799, std=0.07418377454883571)
if pretrained:
model.load_state_dict(torch.load(
DOWNLOAD_TO_CACHE('models/sketch_simplification/sketch_to_pencil_v2.pth'),
map_location='cpu'))
return model
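# A minimal usage sketch (assumptions: the pretrained weights are reachable via
# DOWNLOAD_TO_CACHE, and `sketch` is a hypothetical grayscale image in [0, 1] whose
# height and width are divisible by 8, per the class docstring):
#
#     model = sketch_simplification_gan(pretrained=True).eval().cuda()
#     x = torch.from_numpy(sketch).float()[None, None]   # -> [1, 1, H, W]
#     with torch.no_grad():
#         simplified = model(x.cuda())                    # [1, 1, H, W] in [0, 1], dark strokes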
import numpy as np
import cv2
import os
annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
x = x[:, :, None]
assert x.ndim == 3
H, W, C = x.shape
assert C == 1 or C == 3 or C == 4
if C == 3:
return x
if C == 1:
return np.concatenate([x, x, x], axis=2)
if C == 4:
color = x[:, :, 0:3].astype(np.float32)
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
y = color * alpha + 255.0 * (1.0 - alpha)
y = y.clip(0, 255).astype(np.uint8)
return y
def resize_image(input_image, resolution):
H, W, C = input_image.shape
H = float(H)
W = float(W)
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(np.round(H / 64.0)) * 64
W = int(np.round(W / 64.0)) * 64
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
return img
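# A minimal usage sketch (assumption: `frame` is a hypothetical HxWxC uint8 image):
#
#     frame3 = HWC3(frame)                  # force 3 channels, compositing alpha over white
#     frame512 = resize_image(frame3, 512)  # shorter side ~512, both sides snapped to multiples of 64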
from .pretrain_functions import *