from face3d.options.base_options import BaseOptions
class InferenceOptions(BaseOptions):
"""This class includes test options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser) # define shared options
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
parser.add_argument('--input_dir', type=str, help='the folder of the input files')
parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')
parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')
parser.add_argument('--save_split_files', action='store_true', help='save split files or not')
parser.add_argument('--inference_batch_size', type=int, default=8)
# Dropout and BatchNorm have different behaviors during training and test.
self.isTrain = False
return parser
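# Hypothetical usage sketch: parse() is assumed to be inherited from
# BaseOptions, as in the pix2pix-style option classes this file extends.
if __name__ == '__main__':
    opt = InferenceOptions().parse()
    print(opt.input_dir, opt.keypoint_dir, opt.output_dir)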
"""This script contains the test options for Deep3DFaceRecon_pytorch
"""
from .base_options import BaseOptions
class TestOptions(BaseOptions):
"""This class includes test options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser) # define shared options
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
# Dropout and BatchNorm have different behaviors during training and test.
self.isTrain = False
return parser
"""This script contains the training options for Deep3DFaceRecon_pytorch
"""
from .base_options import BaseOptions
from util import util
class TrainOptions(BaseOptions):
"""This class includes training options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
# dataset parameters
# for train
parser.add_argument('--data_root', type=str, default='./', help='dataset root')
parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use data augmentation')
# for val
parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
parser.add_argument('--batch_size_val', type=int, default=32)
# visualization parameters
parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
# network saving and loading parameters
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
# training parameters
parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epochs')
self.isTrain = True
return parser
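# A minimal sketch of how --lr_policy could map onto torch schedulers.
# The real project defines this in its networks module; the gamma and
# patience values here are assumptions, and 'linear' is omitted.
def _example_get_scheduler(optimizer, opt):
    import torch
    if opt.lr_policy == 'step':
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
    elif opt.lr_policy == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=opt.n_epochs)
    elif opt.lr_policy == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.2, patience=5)
    raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)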
"""This package includes a miscellaneous collection of useful helper functions."""
from src.face3d.util import *
import os
import cv2
import numpy as np
from scipy.io import loadmat
import tensorflow as tf
from util.preprocess import align_for_lm
from shutil import move
mean_face = np.loadtxt('util/test_mean_face.txt')
mean_face = mean_face.reshape([68, 2])
def save_label(labels, save_path):
np.savetxt(save_path, labels)
def draw_landmarks(img, landmark, save_name):
lm_img = np.zeros([img.shape[0], img.shape[1], 3])
lm_img[:] = img.astype(np.float32)
landmark = np.round(landmark).astype(np.int32)
for i in range(len(landmark)):
for j in range(-1, 1):
for k in range(-1, 1):
if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
landmark[i, 0]+k > 0 and \
landmark[i, 0]+k < img.shape[1]:
lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
:] = np.array([0, 0, 255])
lm_img = lm_img.astype(np.uint8)
cv2.imwrite(save_name, lm_img)
def load_data(img_name, txt_name):
return cv2.imread(img_name), np.loadtxt(txt_name)
# create tensorflow graph for landmark detector
def load_lm_graph(graph_filename):
with tf.gfile.GFile(graph_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='net')
img_224 = graph.get_tensor_by_name('net/input_imgs:0')
output_lm = graph.get_tensor_by_name('net/lm:0')
lm_sess = tf.Session(graph=graph)
return lm_sess, img_224, output_lm
# landmark detection
def detect_68p(img_path,sess,input_op,output_op):
print('detecting landmarks......')
names = [i for i in sorted(os.listdir(
img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
vis_path = os.path.join(img_path, 'vis')
remove_path = os.path.join(img_path, 'remove')
save_path = os.path.join(img_path, 'landmarks')
if not os.path.isdir(vis_path):
os.makedirs(vis_path)
if not os.path.isdir(remove_path):
os.makedirs(remove_path)
if not os.path.isdir(save_path):
os.makedirs(save_path)
for i in range(0, len(names)):
name = names[i]
print('%05d' % (i), ' ', name)
full_image_name = os.path.join(img_path, name)
txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
# if an image does not have detected 5 facial landmarks, remove it from the training list
if not os.path.isfile(full_txt_name):
move(full_image_name, os.path.join(remove_path, name))
continue
# load data
img, five_points = load_data(full_image_name, full_txt_name)
input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
# if the alignment fails, remove corresponding image from the training list
if scale == 0:
move(full_txt_name, os.path.join(
remove_path, txt_name))
move(full_image_name, os.path.join(remove_path, name))
continue
# detect landmarks
input_img = np.reshape(
input_img, [1, 224, 224, 3]).astype(np.float32)
landmark = sess.run(
output_op, feed_dict={input_op: input_img})
# transform back to original image coordinate
landmark = landmark.reshape([68, 2]) + mean_face
landmark[:, 1] = 223 - landmark[:, 1]
landmark = landmark / scale
landmark[:, 0] = landmark[:, 0] + bbox[0]
landmark[:, 1] = landmark[:, 1] + bbox[1]
landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
if i % 100 == 0:
draw_landmarks(img, landmark, os.path.join(vis_path, name))
save_label(landmark, os.path.join(save_path, txt_name))
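# Hypothetical driver: the .pb path follows the Deep3DFaceRecon layout but
# is an assumption; 5-point detections must already exist under
# <img_path>/detections for detect_68p to keep an image.
if __name__ == '__main__':
    sess, input_op, output_op = load_lm_graph('checkpoints/lm_model/68lm_detector.pb')
    detect_68p('datasets/examples', sess, input_op, output_op)
    sess.close()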
"""This script is to generate training list files for Deep3DFaceRecon_pytorch
"""
import os
# write the landmark/image/mask path lists of the training data to text files
def write_list(lms_list, imgs_list, msks_list, mode='train', save_folder='datalist', save_name=''):
save_path = os.path.join(save_folder, mode)
if not os.path.isdir(save_path):
os.makedirs(save_path)
with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
fd.writelines([i + '\n' for i in lms_list])
with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
fd.writelines([i + '\n' for i in imgs_list])
with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
fd.writelines([i + '\n' for i in msks_list])
# keep only the entries whose landmark, image, and mask files all exist
def check_list(rlms_list, rimgs_list, rmsks_list):
lms_list, imgs_list, msks_list = [], [], []
for i in range(len(rlms_list)):
flag = 'false'
lm_path = rlms_list[i]
im_path = rimgs_list[i]
msk_path = rmsks_list[i]
if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
flag = 'true'
lms_list.append(rlms_list[i])
imgs_list.append(rimgs_list[i])
msks_list.append(rmsks_list[i])
print(i, rlms_list[i], flag)
return lms_list, imgs_list, msks_list
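# Hypothetical driver: build and validate a training list. The folder
# layout (landmarks/, mask/, images alongside) mirrors the helpers above
# but is an assumption; adapt the three lists to your dataset.
if __name__ == '__main__':
    root = 'datasets/examples'
    names = sorted(os.listdir(os.path.join(root, 'landmarks')))
    lms = [os.path.join(root, 'landmarks', n) for n in names]
    imgs = [os.path.join(root, n.replace('.txt', '.png')) for n in names]
    msks = [os.path.join(root, 'mask', n.replace('.txt', '.png')) for n in names]
    write_list(*check_list(lms, imgs, msks), mode='train')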
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os
class HTML:
"""This HTML class allows us to save images and write texts into a single HTML file.
It consists of functions such as <add_header> (add a text header to the HTML file),
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
"""
def __init__(self, web_dir, title, refresh=0):
"""Initialize the HTML classes
Parameters:
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
title (str) -- the webpage name
refresh (int) -- how often the website refresh itself; if 0; no refreshing
"""
self.title = title
self.web_dir = web_dir
self.img_dir = os.path.join(self.web_dir, 'images')
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
self.doc = dominate.document(title=title)
if refresh > 0:
with self.doc.head:
meta(http_equiv="refresh", content=str(refresh))
def get_image_dir(self):
"""Return the directory that stores images"""
return self.img_dir
def add_header(self, text):
"""Insert a header to the HTML file
Parameters:
text (str) -- the header text
"""
with self.doc:
h3(text)
def add_images(self, ims, txts, links, width=400):
"""add images to the HTML file
Parameters:
ims (str list) -- a list of image paths
txts (str list) -- a list of image names shown on the website
links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
"""
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
self.doc.add(self.t)
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('images', link)):
img(style="width:%dpx" % width, src=os.path.join('images', im))
br()
p(txt)
def save(self):
"""save the current content to the HTML file"""
html_file = '%s/index.html' % self.web_dir
with open(html_file, 'wt') as f:
    f.write(self.doc.render())
if __name__ == '__main__': # we show an example usage here.
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims, txts, links = [], [], []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()
"""This script is to load 3D face model for Deep3DFaceRecon_pytorch
"""
import numpy as np
from PIL import Image
from scipy.io import loadmat, savemat
from array import array
import os.path as osp
# load expression basis
def LoadExpBasis(bfm_folder='BFM'):
n_vertex = 53215
Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
exp_dim = array('i')
exp_dim.fromfile(Expbin, 1)
expMU = array('f')
expPC = array('f')
expMU.fromfile(Expbin, 3*n_vertex)
expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
Expbin.close()
expPC = np.array(expPC)
expPC = np.reshape(expPC, [exp_dim[0], -1])
expPC = np.transpose(expPC)
expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
return expPC, expEV
# transfer original BFM09 to our face model
def transferBFM09(bfm_folder='BFM'):
print('Transfer BFM09 to BFM_model_front......')
original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
shapePC = original_BFM['shapePC'] # shape basis
shapeEV = original_BFM['shapeEV'] # corresponding eigen value
shapeMU = original_BFM['shapeMU'] # mean face
texPC = original_BFM['texPC'] # texture basis
texEV = original_BFM['texEV'] # eigen value
texMU = original_BFM['texMU'] # mean texture
expPC, expEV = LoadExpBasis(bfm_folder)
# transfer BFM09 to our face model
idBase = shapePC*np.reshape(shapeEV, [-1, 199])
idBase = idBase/1e5 # unify the scale to decimeter
idBase = idBase[:, :80] # use only first 80 basis
exBase = expPC*np.reshape(expEV, [-1, 79])
exBase = exBase/1e5 # unify the scale to decimeter
exBase = exBase[:, :64] # use only first 64 basis
texBase = texPC*np.reshape(texEV, [-1, 199])
texBase = texBase[:, :80] # use only first 80 basis
# our face model is cropped along the facial landmarks and contains only 35709 vertices.
# the original BFM09 contains 53490 vertices, and the expression basis provided by Guo et al. contains 53215 vertices.
# thus we select the corresponding vertices to obtain our face model.
index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215)
index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
index_shape = index_shape['trimIndex'].astype(
np.int32) - 1 # starts from 0 (to 53490)
index_shape = index_shape[index_exp]
idBase = np.reshape(idBase, [-1, 3, 80])
idBase = idBase[index_shape, :, :]
idBase = np.reshape(idBase, [-1, 80])
texBase = np.reshape(texBase, [-1, 3, 80])
texBase = texBase[index_shape, :, :]
texBase = np.reshape(texBase, [-1, 80])
exBase = np.reshape(exBase, [-1, 3, 64])
exBase = exBase[index_exp, :, :]
exBase = np.reshape(exBase, [-1, 64])
meanshape = np.reshape(shapeMU, [-1, 3])/1e5
meanshape = meanshape[index_shape, :]
meanshape = np.reshape(meanshape, [1, -1])
meantex = np.reshape(texMU, [-1, 3])
meantex = meantex[index_shape, :]
meantex = np.reshape(meantex, [1, -1])
# other info contains the triangles, the region used for computing the photometric loss,
# the region used for skin texture regularization, the 68 landmark indices, etc.
other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
frontmask2_idx = other_info['frontmask2_idx']
skinmask = other_info['skinmask']
keypoints = other_info['keypoints']
point_buf = other_info['point_buf']
tri = other_info['tri']
tri_mask2 = other_info['tri_mask2']
# save our face model
savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
# load the landmarks of a standard face, used for image preprocessing
def load_lm3d(bfm_folder):
Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
Lm3D = Lm3D['lm']
# calculate 5 facial landmarks using 68 landmarks
lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
return Lm3D
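# For reference, the truncated bases above (80 id + 64 exp + 80 tex) plus
# 3 Euler angles, 27 SH illumination coefficients, and a 3-dim translation
# give the 257-dim coefficient vector used downstream; the offsets below
# are inferred from the slicing elsewhere in this codebase, so treat them
# as assumptions:
COEFF_LAYOUT = {
    'id':    (0, 80),     # identity (shape) coefficients
    'exp':   (80, 144),   # expression coefficients
    'tex':   (144, 224),  # texture coefficients
    'angle': (224, 227),  # rotation as Euler angles
    'gamma': (227, 254),  # spherical-harmonics illumination
    'trans': (254, 257),  # translation
}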
if __name__ == '__main__':
transferBFM09()
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def calculate_points(heatmaps):
# change heatmaps to landmarks
B, N, H, W = heatmaps.shape
HW = H * W
BN_range = np.arange(B * N)
heatline = heatmaps.reshape(B, N, HW)
indexes = np.argmax(heatline, axis=2)
preds = np.stack((indexes % W, indexes // W), axis=2)
preds = preds.astype(float, copy=False)
inr = indexes.ravel()
heatline = heatline.reshape(B * N, HW)
# clamp neighbour indices into [0, HW - 1] so maxima on the heatmap
# border do not index out of range
x_up = heatline[BN_range, np.clip(inr + 1, 0, HW - 1)]
x_down = heatline[BN_range, np.clip(inr - 1, 0, HW - 1)]
y_up = heatline[BN_range, np.clip(inr + W, 0, HW - 1)]
y_down = heatline[BN_range, np.clip(inr - W, 0, HW - 1)]
diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1))
diff *= .25
preds += diff.reshape(B, N, 2)
preds += .5
return preds
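# Hypothetical sanity check: recover sub-pixel (x, y) landmarks from a
# dummy batch of 64x64 heatmaps (2 images, 68 maps each).
if __name__ == '__main__':
    dummy_heatmaps = np.random.rand(2, 68, 64, 64).astype(np.float32)
    print(calculate_points(dummy_heatmaps).shape)  # (2, 68, 2)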
class AddCoordsTh(nn.Module):
def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
super(AddCoordsTh, self).__init__()
self.x_dim = x_dim
self.y_dim = y_dim
self.with_r = with_r
self.with_boundary = with_boundary
def forward(self, input_tensor, heatmap=None):
"""
input_tensor: (batch, c, x_dim, y_dim)
"""
batch_size_tensor = input_tensor.shape[0]
xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device)
xx_ones = xx_ones.unsqueeze(-1)
xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)
xx_range = xx_range.unsqueeze(1)
xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
xx_channel = xx_channel.unsqueeze(-1)
yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device)
yy_ones = yy_ones.unsqueeze(1)
yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)
yy_range = yy_range.unsqueeze(-1)
yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
yy_channel = yy_channel.unsqueeze(-1)
xx_channel = xx_channel.permute(0, 3, 2, 1)
yy_channel = yy_channel.permute(0, 3, 2, 1)
xx_channel = xx_channel / (self.x_dim - 1)
yy_channel = yy_channel / (self.y_dim - 1)
xx_channel = xx_channel * 2 - 1
yy_channel = yy_channel * 2 - 1
xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
if self.with_boundary and heatmap is not None:
    boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
    zero_tensor = torch.zeros_like(xx_channel)
    xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor).to(input_tensor.device)
    yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor).to(input_tensor.device)
ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
if self.with_r:
rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
rr = rr / torch.max(rr)
ret = torch.cat([ret, rr], dim=1)
if self.with_boundary and heatmap is not None:
ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1)
return ret
class CoordConvTh(nn.Module):
"""CoordConv layer as in the paper."""
def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs):
super(CoordConvTh, self).__init__()
self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary)
in_channels += 2
if with_r:
in_channels += 1
if with_boundary and not first_one:
in_channels += 2
self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
def forward(self, input_tensor, heatmap=None):
ret = self.addcoords(input_tensor, heatmap)
last_channel = ret[:, -2:, :, :]
ret = self.conv(ret)
return ret, last_channel
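# Shape sketch: AddCoordsTh appends 2 coordinate channels (+1 radius
# channel when with_r=True), so a 3-channel input reaches the wrapped
# conv with 6 channels here; all sizes below are illustrative.
if __name__ == '__main__':
    layer = CoordConvTh(x_dim=64, y_dim=64, with_r=True, with_boundary=False,
                        in_channels=3, out_channels=8, kernel_size=3, padding=1)
    out, last = layer(torch.randn(1, 3, 64, 64))
    print(out.shape, last.shape)  # (1, 8, 64, 64) and (1, 2, 64, 64)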
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1):
'3x3 convolution with padding'
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
# self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
# self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.relu(out)
out = self.conv2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ConvBlock(nn.Module):
def __init__(self, in_planes, out_planes):
super(ConvBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1)
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1)
if in_planes != out_planes:
self.downsample = nn.Sequential(
nn.BatchNorm2d(in_planes),
nn.ReLU(True),
nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
)
else:
self.downsample = None
def forward(self, x):
residual = x
out1 = self.bn1(x)
out1 = F.relu(out1, True)
out1 = self.conv1(out1)
out2 = self.bn2(out1)
out2 = F.relu(out2, True)
out2 = self.conv2(out2)
out3 = self.bn3(out2)
out3 = F.relu(out3, True)
out3 = self.conv3(out3)
out3 = torch.cat((out1, out2, out3), 1)
if self.downsample is not None:
residual = self.downsample(residual)
out3 += residual
return out3
class HourGlass(nn.Module):
def __init__(self, num_modules, depth, num_features, first_one=False):
super(HourGlass, self).__init__()
self.num_modules = num_modules
self.depth = depth
self.features = num_features
self.coordconv = CoordConvTh(
x_dim=64,
y_dim=64,
with_r=True,
with_boundary=True,
in_channels=256,
first_one=first_one,
out_channels=256,
kernel_size=1,
stride=1,
padding=0)
self._generate_network(self.depth)
def _generate_network(self, level):
self.add_module('b1_' + str(level), ConvBlock(256, 256))
self.add_module('b2_' + str(level), ConvBlock(256, 256))
if level > 1:
self._generate_network(level - 1)
else:
self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
self.add_module('b3_' + str(level), ConvBlock(256, 256))
def _forward(self, level, inp):
# Upper branch
up1 = inp
up1 = self._modules['b1_' + str(level)](up1)
# Lower branch
low1 = F.avg_pool2d(inp, 2, stride=2)
low1 = self._modules['b2_' + str(level)](low1)
if level > 1:
low2 = self._forward(level - 1, low1)
else:
low2 = low1
low2 = self._modules['b2_plus_' + str(level)](low2)
low3 = low2
low3 = self._modules['b3_' + str(level)](low3)
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
return up1 + up2
def forward(self, x, heatmap):
x, last_channel = self.coordconv(x, heatmap)
return self._forward(self.depth, x), last_channel
class FAN(nn.Module):
def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'):
super(FAN, self).__init__()
self.device = device
self.num_modules = num_modules
self.gray_scale = gray_scale
self.end_relu = end_relu
self.num_landmarks = num_landmarks
# Base part
# NOTE: gray-scale input currently uses the same 3-channel stem as RGB,
# so the stem is constructed once instead of in duplicated branches
self.conv1 = CoordConvTh(
x_dim=256,
y_dim=256,
with_r=True,
with_boundary=False,
in_channels=3,
out_channels=64,
kernel_size=7,
stride=2,
padding=3)
self.bn1 = nn.BatchNorm2d(64)
self.conv2 = ConvBlock(64, 128)
self.conv3 = ConvBlock(128, 128)
self.conv4 = ConvBlock(128, 256)
# Stacking part
for hg_module in range(self.num_modules):
if hg_module == 0:
first_one = True
else:
first_one = False
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, first_one))
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
self.add_module('conv_last' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
self.add_module('l' + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0))
if hg_module < self.num_modules - 1:
self.add_module('bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('al' + str(hg_module),
nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0))
def forward(self, x):
x, _ = self.conv1(x)
x = F.relu(self.bn1(x), True)
# x = F.relu(self.bn1(self.conv1(x)), True)
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
x = self.conv3(x)
x = self.conv4(x)
previous = x
outputs = []
boundary_channels = []
tmp_out = None
for i in range(self.num_modules):
hg, boundary_channel = self._modules['m' + str(i)](previous, tmp_out)
ll = hg
ll = self._modules['top_m_' + str(i)](ll)
ll = F.relu(self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)), True)
# Predict heatmaps
tmp_out = self._modules['l' + str(i)](ll)
if self.end_relu:
tmp_out = F.relu(tmp_out) # HACK: Added relu
outputs.append(tmp_out)
boundary_channels.append(boundary_channel)
if i < self.num_modules - 1:
ll = self._modules['bl' + str(i)](ll)
tmp_out_ = self._modules['al' + str(i)](tmp_out)
previous = previous + ll + tmp_out_
return outputs, boundary_channels
def get_landmarks(self, img):
H, W, _ = img.shape
offset = W / 64, H / 64, 0, 0
img = cv2.resize(img, (256, 256))
inp = img[..., ::-1]
inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float()
inp = inp.to(self.device)
inp.div_(255.0).unsqueeze_(0)
outputs, _ = self.forward(inp)
out = outputs[-1][:, :-1, :, :]
heatmaps = out.detach().cpu().numpy()
pred = calculate_points(heatmaps).reshape(-1, 2)
pred *= offset[:2]
pred += offset[-2:]
return pred
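# Hypothetical usage: 68-point detection on a BGR image. Without real
# weights the output is meaningless; the checkpoint path and its key
# layout are assumptions, so adapt the loading to the weights you ship.
if __name__ == '__main__':
    model = FAN(num_modules=1, device='cpu')
    # model.load_state_dict(torch.load('checkpoints/fan.pth', map_location='cpu'))
    model.eval()
    img = cv2.imread('examples/face.png')       # HxWx3, BGR
    with torch.no_grad():
        landmarks = model.get_landmarks(img)    # (68, 2) in original image coords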
"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
Note: the antialiasing step is missing in the current version.
"""
import pytorch3d.ops
import torch
import torch.nn.functional as F
import kornia
from kornia.geometry.camera import pixel2cam
import numpy as np
from typing import List
from scipy.io import loadmat
from torch import nn
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
DirectionalLights,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
TexturesUV,
)
# def ndc_projection(x=0.1, n=1.0, f=50.0):
# return np.array([[n/x, 0, 0, 0],
# [ 0, n/-x, 0, 0],
# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
# [ 0, 0, -1, 0]]).astype(np.float32)
class MeshRenderer(nn.Module):
def __init__(self,
rasterize_fov,
znear=0.1,
zfar=10,
rasterize_size=224):
super(MeshRenderer, self).__init__()
# x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
# self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
# torch.diag(torch.tensor([1., -1, -1, 1])))
self.rasterize_size = rasterize_size
self.fov = rasterize_fov
self.znear = znear
self.zfar = zfar
self.rasterizer = None
def forward(self, vertex, tri, feat=None):
"""
Return:
mask -- torch.tensor, size (B, 1, H, W)
depth -- torch.tensor, size (B, 1, H, W)
features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
Parameters:
vertex -- torch.tensor, size (B, N, 3)
tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
feat(optional) -- torch.tensor, size (B, N ,C), features
"""
device = vertex.device
rsize = int(self.rasterize_size)
# ndc_proj = self.ndc_proj.to(device)
# convert the 3D vertices to homogeneous coordinates; the y direction is the same as v
if vertex.shape[-1] == 3:
vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
vertex[..., 0] = -vertex[..., 0]
# vertex_ndc = vertex @ ndc_proj.t()
if self.rasterizer is None:
self.rasterizer = MeshRasterizer()
print("create rasterizer on device cuda:%d"%device.index)
# ranges = None
# if isinstance(tri, List) or len(tri.shape) == 3:
# vum = vertex_ndc.shape[1]
# fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
# fstartidx = torch.cumsum(fnum, dim=0) - fnum
# ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
# for i in range(tri.shape[0]):
# tri[i] = tri[i] + i*vum
# vertex_ndc = torch.cat(vertex_ndc, dim=0)
# tri = torch.cat(tri, dim=0)
# for range_mode vertex: [B*N, 4], tri: [B*M, 3]; for instance_mode vertex: [B, N, 4], tri: [M, 3]
tri = tri.type(torch.int32).contiguous()
# rasterize
cameras = FoVPerspectiveCameras(
device=device,
fov=self.fov,
znear=self.znear,
zfar=self.zfar,
)
raster_settings = RasterizationSettings(
image_size=rsize
)
# print(vertex.shape, tri.shape)
mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1)))
fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
rast_out = fragments.pix_to_face.squeeze(-1)
depth = fragments.zbuf
# render depth
depth = depth.permute(0, 3, 1, 2)
mask = (rast_out > 0).float().unsqueeze(1)
depth = mask * depth
image = None
if feat is not None:
attributes = feat.reshape(-1,3)[mesh.faces_packed()]
image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
fragments.bary_coords,
attributes)
# print(image.shape)
image = image.squeeze(-2).permute(0, 3, 1, 2)
image = mask * image
return mask, depth, image
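# Hypothetical smoke test: rasterize a single triangle placed inside the
# frustum; all values are arbitrary.
if __name__ == '__main__':
    renderer = MeshRenderer(rasterize_fov=12.6, znear=5., zfar=15., rasterize_size=224)
    verts = torch.tensor([[[-0.5, -0.5, 10.], [0.5, -0.5, 10.], [0., 0.5, 10.]]])
    tris = torch.tensor([[0, 1, 2]])
    feat = torch.ones(1, 3, 3)                   # per-vertex RGB features
    mask, depth, image = renderer(verts, tris, feat=feat)
    print(mask.shape, depth.shape, image.shape)  # (1,1,224,224) twice, (1,3,224,224)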
"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
"""
import numpy as np
from scipy.io import loadmat
from PIL import Image
import cv2
import os
from skimage import transform as trans
import torch
import warnings
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# solve a least-squares problem to estimate scale and translation for image alignment
def POS(xp, x):
npts = xp.shape[1]
A = np.zeros([2*npts, 8])
A[0:2*npts-1:2, 0:3] = x.transpose()
A[0:2*npts-1:2, 3] = 1
A[1:2*npts:2, 4:7] = x.transpose()
A[1:2*npts:2, 7] = 1
b = np.reshape(xp.transpose(), [2*npts, 1])
k, _, _, _ = np.linalg.lstsq(A, b)
R1 = k[0:3]
R2 = k[4:7]
sTx = k[3]
sTy = k[7]
s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
t = np.stack([sTx, sTy], axis=0)
return t, s
# resize and crop images for face reconstruction
def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
w0, h0 = img.size
w = (w0*s).astype(np.int32)
h = (h0*s).astype(np.int32)
left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
right = left + target_size
up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
below = up + target_size
img = img.resize((w, h), resample=Image.BICUBIC)
img = img.crop((left, up, right, below))
if mask is not None:
mask = mask.resize((w, h), resample=Image.BICUBIC)
mask = mask.crop((left, up, right, below))
lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
t[1] + h0/2], axis=1)*s
lm = lm - np.reshape(
np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
return img, lm, mask
# utils for face reconstruction
def extract_5p(lm):
lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
lm5p = lm5p[[1, 2, 0, 3, 4], :]
return lm5p
# utils for face reconstruction
def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
"""
Return:
transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
img_new --PIL.Image (target_size, target_size, 3)
lm_new --numpy.array (68, 2), y direction is opposite to v direction
mask_new --PIL.Image (target_size, target_size)
Parameters:
img --PIL.Image (raw_H, raw_W, 3)
lm --numpy.array (68, 2), y direction is opposite to v direction
lm3D --numpy.array (5, 3)
mask --PIL.Image (raw_H, raw_W, 3)
"""
w0, h0 = img.size
if lm.shape[0] != 5:
lm5p = extract_5p(lm)
else:
lm5p = lm
# calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
t, s = POS(lm5p.transpose(), lm3D.transpose())
s = rescale_factor/s
# processing the image
img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
trans_params = np.array([w0, h0, s, t[0], t[1]])
return trans_params, img_new, lm_new, mask_new
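# Hypothetical usage: align one face image given its landmarks. load_lm3d
# is the helper defined in the BFM loader above; its import path here and
# the example file paths are assumptions.
if __name__ == '__main__':
    from face3d.util.load_mats import load_lm3d  # assumed module path
    img = Image.open('examples/000002.jpg').convert('RGB')
    lm = np.loadtxt('examples/detections/000002.txt').reshape(-1, 2)
    trans_params, img_new, lm_new, _ = align_img(img, lm, load_lm3d('BFM'))
    print(trans_params)  # [raw_W, raw_H, scale, tx, ty]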
"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
"""
import math
import numpy as np
import os
import cv2
class GMM:
def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
self.dim = dim # feature dimension
self.num = num # number of Gaussian components
self.w = w # weights of Gaussian components (a list of scalars)
self.mu = mu # mean of Gaussian components (a list of 1xdim vectors)
self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)
self.cov_det = cov_det # pre-computed determinant of covariance matrices (a list of scalars)
self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)
self.factor = [0]*num
for i in range(self.num):
self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5
def likelihood(self, data):
assert(data.shape[1] == self.dim)
N = data.shape[0]
lh = np.zeros(N)
for i in range(self.num):
data_ = data - self.mu[i]
tmp = np.matmul(data_,self.cov_inv[i]) * data_
tmp = np.sum(tmp,axis=1)
power = -0.5 * tmp
p = np.exp(power) / self.factor[i]
lh += p*self.w[i]
return lh
def _rgb2ycbcr(rgb):
m = np.array([[65.481, 128.553, 24.966],
[-37.797, -74.203, 112],
[112, -93.786, -18.214]])
shape = rgb.shape
rgb = rgb.reshape((shape[0] * shape[1], 3))
ycbcr = np.dot(rgb, m.transpose() / 255.)
ycbcr[:, 0] += 16.
ycbcr[:, 1:] += 128.
return ycbcr.reshape(shape)
def _bgr2ycbcr(bgr):
rgb = bgr[..., ::-1]
return _rgb2ycbcr(rgb)
gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]),
np.array([150.19858, 105.18467, 155.51428]),
np.array([183.92976, 107.62468, 152.71820]),
np.array([114.90524, 113.59782, 151.38217])]
gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),
np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),
np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),
np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]
gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),
np.array([110.91392, 125.52969, 130.19237]),
np.array([129.75864, 129.96107, 126.96808]),
np.array([112.29587, 128.85121, 129.05431])]
gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]
gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),
np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),
np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),
np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]
gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
prior_skin = 0.8
prior_nonskin = 1 - prior_skin
# calculate skin attention mask
def skinmask(imbgr):
im = _bgr2ycbcr(imbgr)
data = im.reshape((-1,3))
lh_skin = gmm_skin.likelihood(data)
lh_nonskin = gmm_nonskin.likelihood(data)
tmp1 = prior_skin * lh_skin
tmp2 = prior_nonskin * lh_nonskin
post_skin = tmp1 / (tmp1+tmp2) # posterior probability
post_skin = post_skin.reshape((im.shape[0],im.shape[1]))
post_skin = np.round(post_skin*255)
post_skin = post_skin.astype(np.uint8)
post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3
return post_skin
def get_skin_mask(img_path):
print('generating skin masks......')
names = [i for i in sorted(os.listdir(
img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
save_path = os.path.join(img_path, 'mask')
if not os.path.isdir(save_path):
os.makedirs(save_path)
for i in range(0, len(names)):
name = names[i]
print('%05d' % (i), ' ', name)
full_image_name = os.path.join(img_path, name)
img = cv2.imread(full_image_name).astype(np.float32)
skin_img = skinmask(img)
cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
# (contents of util/test_mean_face.txt: 136 floats, i.e. the 68 mean-face
# (x, y) landmark coordinates loaded via np.loadtxt('util/test_mean_face.txt')
# and reshaped to [68, 2] above; the raw values are omitted here)
"""This script contains basic utilities for Deep3DFaceRecon_pytorch
"""
from __future__ import print_function
import numpy as np
import torch
from PIL import Image
import os
import importlib
import argparse
from argparse import Namespace
import torchvision
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def copyconf(default_opt, **kwargs):
conf = Namespace(**vars(default_opt))
for key in kwargs:
setattr(conf, key, kwargs[key])
return conf
def genvalconf(train_opt, **kwargs):
conf = Namespace(**vars(train_opt))
attr_dict = train_opt.__dict__
for key, value in attr_dict.items():
if 'val' in key and key.split('_')[0] in attr_dict:
setattr(conf, key.split('_')[0], value)
for key in kwargs:
setattr(conf, key, kwargs[key])
return conf
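# Sketch of genvalconf's promotion rule: a key like 'flist_val' overrides
# 'flist' because key.split('_')[0] == 'flist' exists in the namespace.
# Note the subtlety: 'batch_size_val' does NOT override 'batch_size',
# since its split('_')[0] is 'batch', which is not a key.
if __name__ == '__main__':
    train_opt = Namespace(flist='datalist/train/masks.txt',
                          flist_val='datalist/val/masks.txt')
    val_opt = genvalconf(train_opt, isTrain=False)
    print(val_opt.flist, val_opt.isTrain)  # datalist/val/masks.txt False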
def find_class_in_module(target_cls_name, module):
target_cls_name = target_cls_name.replace('_', '').lower()
clslib = importlib.import_module(module)
cls = None
for name, clsobj in clslib.__dict__.items():
if name.lower() == target_cls_name:
cls = clsobj
assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
return cls
def tensor2im(input_image, imtype=np.uint8):
""""Converts a Tensor array into a numpy image array.
Parameters:
input_image (tensor) -- the input image tensor array, range(0, 1)
imtype (type) -- the desired type of the converted numpy array
"""
if not isinstance(input_image, np.ndarray):
if isinstance(input_image, torch.Tensor): # get the data from a variable
image_tensor = input_image.data
else:
return input_image
image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array
if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
def diagnose_network(net, name='network'):
"""Calculate and print the mean of average absolute(gradients)
Parameters:
net (torch network) -- Torch network
name (str) -- the name of the network
"""
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def save_image(image_numpy, image_path, aspect_ratio=1.0):
"""Save a numpy image to the disk
Parameters:
image_numpy (numpy array) -- input numpy array
image_path (str) -- the path of the image
"""
image_pil = Image.fromarray(image_numpy)
h, w, _ = image_numpy.shape
if aspect_ratio is None:
pass
elif aspect_ratio > 1.0:
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
elif aspect_ratio < 1.0:
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
"""Print the mean, min, max, median, std, and size of a numpy array
Parameters:
val (bool) -- whether to print the values of the numpy array
shp (bool) -- whether to print the shape of the numpy array
"""
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
"""create empty directories if they don't exist
Parameters:
paths (str list) -- a list of directory paths
"""
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
"""create a single empty directory if it didn't exist
Parameters:
path (str) -- a single directory path
"""
if not os.path.exists(path):
os.makedirs(path)
def correct_resize_label(t, size):
device = t.device
t = t.detach().cpu()
resized = []
for i in range(t.size(0)):
one_t = t[i, :1]
one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
one_np = one_np[:, :, 0]
one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
resized_t = torch.from_numpy(np.array(one_image)).long()
resized.append(resized_t)
return torch.stack(resized, dim=0).to(device)
def correct_resize(t, size, mode=Image.BICUBIC):
device = t.device
t = t.detach().cpu()
resized = []
for i in range(t.size(0)):
one_t = t[i:i + 1]
one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
resized.append(resized_t)
return torch.stack(resized, dim=0).to(device)
def draw_landmarks(img, landmark, color='r', step=2):
"""
Return:
img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
Parameters:
img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
color -- str, 'r' or 'b' (red or blue)
"""
if color =='r':
c = np.array([255., 0, 0])
else:
c = np.array([0, 0, 255.])
_, H, W, _ = img.shape
img, landmark = img.copy(), landmark.copy()
landmark[..., 1] = H - 1 - landmark[..., 1]
landmark = np.round(landmark).astype(np.int32)
for i in range(landmark.shape[1]):
x, y = landmark[:, i, 0], landmark[:, i, 1]
for j in range(-step, step):
for k in range(-step, step):
u = np.clip(x + j, 0, W - 1)
v = np.clip(y + k, 0, H - 1)
for m in range(landmark.shape[0]):
img[m, v[m], u[m]] = c
return img
"""This script defines the visualizer for Deep3DFaceRecon_pytorch
"""
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE
from torch.utils.tensorboard import SummaryWriter
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
"""Save images to the disk.
Parameters:
webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
image_path (str) -- the string is used to create image paths
aspect_ratio (float) -- the aspect ratio of saved images
width (int) -- the images will be resized to width x width
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
"""
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
webpage.add_header(name)
ims, txts, links = [], [], []
for label, im_data in visuals.items():
im = util.tensor2im(im_data)
image_name = '%s/%s.png' % (label, name)
os.makedirs(os.path.join(image_dir, label), exist_ok=True)
save_path = os.path.join(image_dir, image_name)
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=width)
class Visualizer():
"""This class includes several functions that can display/save images and print/save logging information.
It uses torch.utils.tensorboard for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
"""
def __init__(self, opt):
"""Initialize the Visualizer class
Parameters:
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
Step 1: Cache the training/test options
Step 2: create a tensorboard writer
Step 3: create an HTML object for saving HTML files
Step 4: create a logging file to store training losses
"""
self.opt = opt # cache the option
self.use_html = opt.isTrain and not opt.no_html
self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
self.win_size = opt.display_winsize
self.name = opt.name
self.saved = False
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
# create a logging file to store training losses
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
def reset(self):
"""Reset the self.saved status"""
self.saved = False
def display_current_results(self, visuals, total_iters, epoch, save_result):
"""Display current results on tensorboad; save current results to an HTML file.
Parameters:
visuals (OrderedDict) - - dictionary of images to display or save
total_iters (int) -- total iterations
epoch (int) - - the current epoch
save_result (bool) - - if save the current results to an HTML file
"""
for label, image in visuals.items():
self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
self.saved = True
# save images to the disk
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
ims, txts, links = [], [], []
for label, image_numpy in visuals.items():
image_numpy = util.tensor2im(image_numpy)
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
def plot_current_losses(self, total_iters, losses):
# G_loss_collection = {}
# D_loss_collection = {}
# for name, value in losses.items():
# if 'G' in name or 'NCE' in name or 'idt' in name:
# G_loss_collection[name] = value
# else:
# D_loss_collection[name] = value
# self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
# self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
for name, value in losses.items():
self.writer.add_scalar(name, value, total_iters)
# losses: same format as |losses| of plot_current_losses
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
"""print current losses on console; also save the losses to the disk
Parameters:
epoch (int) -- current epoch
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
t_comp (float) -- computational time per data point (normalized by batch_size)
t_data (float) -- data loading time per data point (normalized by batch_size)
"""
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
class MyVisualizer:
def __init__(self, opt):
"""Initialize the Visualizer class
Parameters:
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
Step 1: Cache the training/test options
Step 2: create a tensorboard writer
Step 3: create an HTML object for saving HTML files
Step 4: create a logging file to store training losses
"""
self.opt = opt # cache the options
self.name = opt.name
self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
if opt.phase != 'test':
self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
# create a logging file to store training losses
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
add_image=True):
"""Display current results on tensorboad; save current results to an HTML file.
Parameters:
visuals (OrderedDict) - - dictionary of images to display or save
total_iters (int) -- total iterations
epoch (int) - - the current epoch
dataset (str) - - 'train' or 'val' or 'test'
"""
# if (not add_image) and (not save_results): return
for label, image in visuals.items():
for i in range(image.shape[0]):
image_numpy = util.tensor2im(image[i])
if add_image:
self.writer.add_image(label + '%s_%02d'%(dataset, i + count),
image_numpy, total_iters, dataformats='HWC')
if save_results:
save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))
if not os.path.isdir(save_path):
os.makedirs(save_path)
if name is not None:
img_path = os.path.join(save_path, '%s.png' % name)
else:
img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))
util.save_image(image_numpy, img_path)
def plot_current_losses(self, total_iters, losses, dataset='train'):
for name, value in losses.items():
self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)
# losses: same format as |losses| of plot_current_losses
def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):
"""print current losses on console; also save the losses to the disk
Parameters:
epoch (int) -- current epoch
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
t_comp (float) -- computational time per data point (normalized by batch_size)
t_data (float) -- data loading time per data point (normalized by batch_size)
"""
message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (
dataset, epoch, iters, t_comp, t_data)
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
# check the sync between the 3DMM features and the audio
import cv2
import numpy as np
from src.face3d.models.bfm import ParametricFaceModel
from src.face3d.models.facerecon_model import FaceReconModel
import torch
import subprocess, platform
import scipy.io as scio
from tqdm import tqdm
# draft
def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64):
coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
coeff_full[:, 80:144] = coeff_pred[:, 0:64] # 64-dim expression
coeff_full[:, 224:227] = coeff_pred[:, 64:67] # 3-dim rotation (Euler angles)
coeff_full[:, 254:] = coeff_pred[:, 67:] # 3-dim translation
tmp_video_path = '/tmp/face3dtmp.mp4'
facemodel = FaceReconModel(args)
video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
facemodel.forward(cur_coeff_full, device)
predicted_landmark = facemodel.pred_lm # TODO.
predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
rendered_img = facemodel.pred_face
rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
out_img = rendered_img[:, :, :3].astype(np.uint8)
video.write(np.uint8(out_img[:,:,::-1]))
video.release()
command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
subprocess.call(command, shell=platform.system() != 'Windows')
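# Hypothetical invocation: all paths and the `args` namespace are
# placeholders; `args` must carry the FaceReconModel settings (BFM paths
# etc.) expected by this repo.
if __name__ == '__main__':
    from argparse import Namespace
    args = Namespace()  # fill with the face-recon options your setup uses
    gen_composed_video(args, 'cuda',
                       first_frame_coeff='results/first_frame.mat',
                       coeff_path='results/pred_coeffs.mat',
                       audio_path='examples/driven_audio.wav',
                       save_path='results/3dmm_preview.mp4')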
import os
import cv2
import yaml
import numpy as np
import warnings
from skimage import img_as_ubyte
import safetensors
import safetensors.torch
warnings.filterwarnings('ignore')
import imageio
import torch
import torchvision
from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
from src.facerender.modules.mapping import MappingNet
from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from src.facerender.modules.make_animation import make_animation
from pydub import AudioSegment
from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list
from src.utils.paste_pic import paste_pic
from src.utils.videoio import save_video_with_watermark
try:
    import webui  # only importable when running inside webui
    in_webui = True
except ImportError:
    in_webui = False
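# Keys expected in `sadtalker_path` (inferred from the lookups in __init__ below):
#   'facerender_yaml'       -- path to the facerender model config
#   'checkpoint'            -- safetensors checkpoint (preferred when present)
#   'free_view_checkpoint'  -- legacy .pth checkpoint used otherwise
#   'mappingnet_checkpoint' -- checkpoint for the MappingNet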
class AnimateFromCoeff():
def __init__(self, sadtalker_path, device):
with open(sadtalker_path['facerender_yaml']) as f:
config = yaml.safe_load(f)
generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
**config['model_params']['common_params'])
mapping = MappingNet(**config['model_params']['mapping_params'])
generator.to(device)
kp_extractor.to(device)
he_estimator.to(device)
mapping.to(device)
for param in generator.parameters():
param.requires_grad = False
for param in kp_extractor.parameters():
param.requires_grad = False
for param in he_estimator.parameters():
param.requires_grad = False
for param in mapping.parameters():
param.requires_grad = False
        if sadtalker_path is not None:
            if 'checkpoint' in sadtalker_path:  # load from a safetensors checkpoint
                self.load_cpk_facevid2vid_safetensor(sadtalker_path['checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=None)
            else:
                self.load_cpk_facevid2vid(sadtalker_path['free_view_checkpoint'], kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
        else:
            raise AttributeError("Checkpoint should be specified for the face renderer.")
        if sadtalker_path['mappingnet_checkpoint'] is not None:
            self.load_cpk_mapping(sadtalker_path['mappingnet_checkpoint'], mapping=mapping)
        else:
            raise AttributeError("Checkpoint should be specified for the MappingNet.")
self.kp_extractor = kp_extractor
self.generator = generator
self.he_estimator = he_estimator
self.mapping = mapping
self.kp_extractor.eval()
self.generator.eval()
self.he_estimator.eval()
self.mapping.eval()
self.device = device
def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
kp_detector=None, he_estimator=None,
device="cpu"):
checkpoint = safetensors.torch.load_file(checkpoint_path)
        if generator is not None:
            generator_state = {k.replace('generator.', ''): v for k, v in checkpoint.items() if 'generator' in k}
            generator.load_state_dict(generator_state)
        if kp_detector is not None:
            kp_state = {k.replace('kp_extractor.', ''): v for k, v in checkpoint.items() if 'kp_extractor' in k}
            kp_detector.load_state_dict(kp_state)
        if he_estimator is not None:
            he_state = {k.replace('he_estimator.', ''): v for k, v in checkpoint.items() if 'he_estimator' in k}
            he_estimator.load_state_dict(he_state)
        return None
def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None,
kp_detector=None, he_estimator=None, optimizer_generator=None,
optimizer_discriminator=None, optimizer_kp_detector=None,
optimizer_he_estimator=None, device="cpu"):
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
if generator is not None:
generator.load_state_dict(checkpoint['generator'])
if kp_detector is not None:
kp_detector.load_state_dict(checkpoint['kp_detector'])
if he_estimator is not None:
he_estimator.load_state_dict(checkpoint['he_estimator'])
        if discriminator is not None:
            try:
                discriminator.load_state_dict(checkpoint['discriminator'])
            except (KeyError, RuntimeError):
                print('No discriminator in the state-dict. Discriminator will be randomly initialized')
if optimizer_generator is not None:
optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
if optimizer_discriminator is not None:
            try:
                optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
            except RuntimeError:
                print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
if optimizer_kp_detector is not None:
optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
if optimizer_he_estimator is not None:
optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
return checkpoint['epoch']
def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
if mapping is not None:
mapping.load_state_dict(checkpoint['mapping'])
if discriminator is not None:
discriminator.load_state_dict(checkpoint['discriminator'])
if optimizer_mapping is not None:
optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
if optimizer_discriminator is not None:
optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
return checkpoint['epoch']
def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop', img_size=256):
source_image=x['source_image'].type(torch.FloatTensor)
source_semantics=x['source_semantics'].type(torch.FloatTensor)
target_semantics=x['target_semantics_list'].type(torch.FloatTensor)
source_image=source_image.to(self.device)
source_semantics=source_semantics.to(self.device)
target_semantics=target_semantics.to(self.device)
        if 'yaw_c_seq' in x:
            yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
            yaw_c_seq = yaw_c_seq.to(self.device)
        else:
            yaw_c_seq = None
        if 'pitch_c_seq' in x:
            pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
            pitch_c_seq = pitch_c_seq.to(self.device)
        else:
            pitch_c_seq = None
        if 'roll_c_seq' in x:
            roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor)
            roll_c_seq = roll_c_seq.to(self.device)
        else:
            roll_c_seq = None
frame_num = x['frame_num']
predictions_video = make_animation(source_image, source_semantics, target_semantics,
self.generator, self.kp_extractor, self.he_estimator, self.mapping,
yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
predictions_video = predictions_video[:frame_num]
video = []
for idx in range(predictions_video.shape[0]):
image = predictions_video[idx]
image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
video.append(image)
result = img_as_ubyte(video)
        ### the generated video is square (img_size x img_size); resize it back to the source aspect ratio
original_size = crop_info[0]
if original_size:
result = [ cv2.resize(result_i,(img_size, int(img_size * original_size[1]/original_size[0]) )) for result_i in result ]
video_name = x['video_name'] + '.mp4'
path = os.path.join(video_save_dir, 'temp_'+video_name)
imageio.mimsave(path, result, fps=float(25))
av_path = os.path.join(video_save_dir, video_name)
return_path = av_path
audio_path = x['audio_path']
audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
        start_time = 0
        # cog will not keep the .mp3 filename, so re-export the audio as a
        # 16 kHz wav trimmed to the video length (25 fps)
        sound = AudioSegment.from_file(audio_path)
        end_time = start_time + frame_num * 1 / 25 * 1000  # duration in ms
        word = sound.set_frame_rate(16000)[start_time:end_time]
        word.export(new_audio_path, format="wav")
save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
print(f'The generated video is named {video_save_dir}/{video_name}')
if 'full' in preprocess.lower():
# only add watermark to the full image.
video_name_full = x['video_name'] + '_full.mp4'
full_video_path = os.path.join(video_save_dir, video_name_full)
return_path = full_video_path
paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path, extended_crop= True if 'ext' in preprocess.lower() else False)
print(f'The generated video is named {video_save_dir}/{video_name_full}')
else:
full_video_path = av_path
        #### paste back first, then run the enhancer
if enhancer:
video_name_enhancer = x['video_name'] + '_enhanced.mp4'
enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
return_path = av_path_enhancer
            try:
                enhanced_images_gen_with_len = enhancer_generator_with_len(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
                imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
            except Exception:
                # fall back to the list-based enhancer if the streaming generator fails
                enhanced_images_gen_with_len = enhancer_list(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
                imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(25))
save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
os.remove(enhanced_path)
os.remove(path)
os.remove(new_audio_path)
return return_path
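# A minimal usage sketch (hypothetical paths; the batch dict `x` must carry the
# keys read by generate() above -- 'source_image', 'source_semantics',
# 'target_semantics_list', 'frame_num', 'video_name', 'audio_path', plus the
# optional pose sequences -- which the SadTalker pipeline normally builds):
#
#     animate_from_coeff = AnimateFromCoeff(sadtalker_path, device='cuda')
#     result_mp4 = animate_from_coeff.generate(x, video_save_dir='./results',
#                                              pic_path='examples/source.png',
#                                              crop_info=crop_info,
#                                              enhancer=None, preprocess='crop')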
from torch import nn
import torch.nn.functional as F
import torch
from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian
from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
class DenseMotionNetwork(nn.Module):
"""
    Module that predicts a dense motion field from the sparse motion representation given by kp_source and kp_driving
"""
def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
estimate_occlusion_map=False):
super(DenseMotionNetwork, self).__init__()
# self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)
self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)
self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
self.norm = BatchNorm3d(compress, affine=True)
if estimate_occlusion_map:
# self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)
self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
else:
self.occlusion = None
self.num_kp = num_kp
def create_sparse_motions(self, feature, kp_driving, kp_source):
bs, _, d, h, w = feature.shape
identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
identity_grid = identity_grid.view(1, 1, d, h, w, 3)
coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)
        if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
coordinate_grid = coordinate_grid.squeeze(-1)
driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3) # (bs, num_kp, d, h, w, 3)
#adding background feature
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) #bs num_kp+1 d h w 3
# sparse_motions = driving_to_source
return sparse_motions
def create_deformed_feature(self, feature, sparse_motions):
bs, _, d, h, w = feature.shape
feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1) # (bs, num_kp+1, 1, c, d, h, w)
feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w) # (bs*(num_kp+1), c, d, h, w)
        sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1))  # (bs*(num_kp+1), d, h, w, 3)
        sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)  # explicit current default; silences the warning
sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w)) # (bs, num_kp+1, c, d, h, w)
return sparse_deformed
def create_heatmap_representations(self, feature, kp_driving, kp_source):
spatial_size = feature.shape[3:]
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
heatmap = gaussian_driving - gaussian_source
# adding background feature
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
heatmap = torch.cat([zeros, heatmap], dim=1)
heatmap = heatmap.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
return heatmap
def forward(self, feature, kp_driving, kp_source):
bs, _, d, h, w = feature.shape
feature = self.compress(feature)
feature = self.norm(feature)
feature = F.relu(feature)
out_dict = dict()
sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
deformed_feature = self.create_deformed_feature(feature, sparse_motion)
heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)
input_ = torch.cat([heatmap, deformed_feature], dim=2)
input_ = input_.view(bs, -1, d, h, w)
# input = deformed_feature.view(bs, -1, d, h, w) # (bs, num_kp+1 * c, d, h, w)
prediction = self.hourglass(input_)
mask = self.mask(prediction)
mask = F.softmax(mask, dim=1)
out_dict['mask'] = mask
mask = mask.unsqueeze(2) # (bs, num_kp+1, 1, d, h, w)
zeros_mask = torch.zeros_like(mask)
mask = torch.where(mask < 1e-3, zeros_mask, mask)
sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4) # (bs, num_kp+1, 3, d, h, w)
deformation = (sparse_motion * mask).sum(dim=1) # (bs, 3, d, h, w)
deformation = deformation.permute(0, 2, 3, 4, 1) # (bs, d, h, w, 3)
out_dict['deformation'] = deformation
if self.occlusion:
bs, c, d, h, w = prediction.shape
prediction = prediction.view(bs, -1, h, w)
occlusion_map = torch.sigmoid(self.occlusion(prediction))
out_dict['occlusion_map'] = occlusion_map
return out_dict
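# A minimal shape-check sketch, using parameter values assumed to match the
# facerender yaml config in this repo (adjust to your config if they differ):
if __name__ == '__main__':
    net = DenseMotionNetwork(block_expansion=32, num_blocks=5, max_features=1024,
                             num_kp=15, feature_channel=32, reshape_depth=16,
                             compress=4, estimate_occlusion_map=True)
    feature = torch.randn(1, 32, 16, 64, 64)  # (bs, feature_channel, d, h, w)
    kp = {'value': torch.randn(1, 15, 3)}     # keypoints only; no jacobian branch
    out = net(feature, kp_driving=kp, kp_source=kp)
    print(out['mask'].shape)           # (1, num_kp+1, d, h, w)
    print(out['deformation'].shape)    # (1, d, h, w, 3)
    print(out['occlusion_map'].shape)  # (1, 1, h, w)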
from torch import nn
import torch.nn.functional as F
from src.facerender.modules.util import kp2gaussian
import torch
class DownBlock2d(nn.Module):
"""
Simple block for processing video (encoder).
"""
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
if norm:
self.norm = nn.InstanceNorm2d(out_features, affine=True)
else:
self.norm = None
self.pool = pool
def forward(self, x):
out = x
out = self.conv(out)
if self.norm:
out = self.norm(out)
out = F.leaky_relu(out, 0.2)
if self.pool:
out = F.avg_pool2d(out, (2, 2))
return out
class Discriminator(nn.Module):
"""
Discriminator similar to Pix2Pix
"""
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
sn=False, **kwargs):
super(Discriminator, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(
DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
def forward(self, x):
feature_maps = []
out = x
for down_block in self.down_blocks:
feature_maps.append(down_block(out))
out = feature_maps[-1]
prediction_map = self.conv(out)
return feature_maps, prediction_map
class MultiScaleDiscriminator(nn.Module):
"""
    Multi-scale discriminator: one Discriminator instance per scale
"""
def __init__(self, scales=(), **kwargs):
super(MultiScaleDiscriminator, self).__init__()
self.scales = scales
discs = {}
for scale in scales:
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
self.discs = nn.ModuleDict(discs)
def forward(self, x):
out_dict = {}
for scale, disc in self.discs.items():
scale = str(scale).replace('-', '.')
key = 'prediction_' + scale
feature_maps, prediction_map = disc(x[key])
out_dict['feature_maps_' + scale] = feature_maps
out_dict['prediction_map_' + scale] = prediction_map
return out_dict
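# A minimal shape-check sketch (hypothetical values; the input dict keys follow
# the 'prediction_<scale>' convention used in MultiScaleDiscriminator.forward):
if __name__ == '__main__':
    disc = MultiScaleDiscriminator(scales=(1,), num_channels=3,
                                   block_expansion=64, num_blocks=4, max_features=512)
    x = {'prediction_1': torch.randn(2, 3, 64, 64)}
    out = disc(x)
    print(len(out['feature_maps_1']))     # 4 intermediate feature maps
    print(out['prediction_map_1'].shape)  # (2, 1, 2, 2) patch predictions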