import torch
from torch import nn
import torch.nn.functional as F
from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock
from src.facerender.modules.dense_motion import DenseMotionNetwork
class OcclusionAwareGenerator(nn.Module):
"""
Generator follows NVIDIA architecture.
"""
def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
super(OcclusionAwareGenerator, self).__init__()
if dense_motion_params is not None:
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
estimate_occlusion_map=estimate_occlusion_map,
**dense_motion_params)
else:
self.dense_motion_network = None
self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))
down_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** i))
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.down_blocks = nn.ModuleList(down_blocks)
self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
self.reshape_channel = reshape_channel
self.reshape_depth = reshape_depth
self.resblocks_3d = torch.nn.Sequential()
for i in range(num_resblocks):
self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
out_features = block_expansion * (2 ** (num_down_blocks))
self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
self.resblocks_2d = torch.nn.Sequential()
for i in range(num_resblocks):
self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))
up_blocks = []
for i in range(num_down_blocks):
in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.up_blocks = nn.ModuleList(up_blocks)
self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
self.estimate_occlusion_map = estimate_occlusion_map
self.image_channel = image_channel
def deform_input(self, inp, deformation):
_, d_old, h_old, w_old, _ = deformation.shape
_, _, d, h, w = inp.shape
if d_old != d or h_old != h or w_old != w:
deformation = deformation.permute(0, 4, 1, 2, 3)
deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
deformation = deformation.permute(0, 2, 3, 4, 1)
return F.grid_sample(inp, deformation)
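# Note: `deformation` is a dense sampling grid of shape (bs, d, h, w, 3) with values in
# normalized [-1, 1] coordinates, so `F.grid_sample` warps the 3D feature volume by
# looking up, for every output voxel, the source location predicted by the dense motion
# network. The interpolation above only resizes the grid when its spatial size differs
# from that of the feature volume.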
def forward(self, source_image, kp_driving, kp_source):
# Encoding (downsampling) part
out = self.first(source_image)
for i in range(len(self.down_blocks)):
out = self.down_blocks[i](out)
out = self.second(out)
bs, c, h, w = out.shape
# print(out.shape)
feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
feature_3d = self.resblocks_3d(feature_3d)
# Transforming feature representation according to deformation and occlusion
output_dict = {}
if self.dense_motion_network is not None:
dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
kp_source=kp_source)
output_dict['mask'] = dense_motion['mask']
if 'occlusion_map' in dense_motion:
occlusion_map = dense_motion['occlusion_map']
output_dict['occlusion_map'] = occlusion_map
else:
occlusion_map = None
deformation = dense_motion['deformation']
out = self.deform_input(feature_3d, deformation)
bs, c, d, h, w = out.shape
out = out.view(bs, c*d, h, w)
out = self.third(out)
out = self.fourth(out)
if occlusion_map is not None:
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
out = out * occlusion_map
# output_dict["deformed"] = self.deform_input(source_image, deformation) # 3d deformation cannot deform 2d image
# Decoding part
out = self.resblocks_2d(out)
for i in range(len(self.up_blocks)):
out = self.up_blocks[i](out)
out = self.final(out)
out = torch.sigmoid(out)
output_dict["prediction"] = out
return output_dict
class SPADEDecoder(nn.Module):
def __init__(self):
super().__init__()
ic = 256
oc = 64
norm_G = 'spadespectralinstance'
label_nc = 256
self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
self.up = nn.Upsample(scale_factor=2)
def forward(self, feature):
seg = feature
x = self.fc(feature)
x = self.G_middle_0(x, seg)
x = self.G_middle_1(x, seg)
x = self.G_middle_2(x, seg)
x = self.G_middle_3(x, seg)
x = self.G_middle_4(x, seg)
x = self.G_middle_5(x, seg)
x = self.up(x)
x = self.up_0(x, seg) # 256, 128, 128
x = self.up(x)
x = self.up_1(x, seg) # 64, 256, 256
x = self.conv_img(F.leaky_relu(x, 2e-1))
# x = torch.tanh(x)
x = torch.sigmoid(x)
return x
class OcclusionAwareSPADEGenerator(nn.Module):
def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
super(OcclusionAwareSPADEGenerator, self).__init__()
if dense_motion_params is not None:
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
estimate_occlusion_map=estimate_occlusion_map,
**dense_motion_params)
else:
self.dense_motion_network = None
self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
down_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** i))
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.down_blocks = nn.ModuleList(down_blocks)
self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
self.reshape_channel = reshape_channel
self.reshape_depth = reshape_depth
self.resblocks_3d = torch.nn.Sequential()
for i in range(num_resblocks):
self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
out_features = block_expansion * (2 ** (num_down_blocks))
self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
self.estimate_occlusion_map = estimate_occlusion_map
self.image_channel = image_channel
self.decoder = SPADEDecoder()
def deform_input(self, inp, deformation):
_, d_old, h_old, w_old, _ = deformation.shape
_, _, d, h, w = inp.shape
if d_old != d or h_old != h or w_old != w:
deformation = deformation.permute(0, 4, 1, 2, 3)
deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
deformation = deformation.permute(0, 2, 3, 4, 1)
return F.grid_sample(inp, deformation)
def forward(self, source_image, kp_driving, kp_source):
# Encoding (downsampling) part
out = self.first(source_image)
for i in range(len(self.down_blocks)):
out = self.down_blocks[i](out)
out = self.second(out)
bs, c, h, w = out.shape
# print(out.shape)
feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)
feature_3d = self.resblocks_3d(feature_3d)
# Transforming feature representation according to deformation and occlusion
output_dict = {}
if self.dense_motion_network is not None:
dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
kp_source=kp_source)
output_dict['mask'] = dense_motion['mask']
# import pdb; pdb.set_trace()
if 'occlusion_map' in dense_motion:
occlusion_map = dense_motion['occlusion_map']
output_dict['occlusion_map'] = occlusion_map
else:
occlusion_map = None
deformation = dense_motion['deformation']
out = self.deform_input(feature_3d, deformation)
bs, c, d, h, w = out.shape
out = out.view(bs, c*d, h, w)
out = self.third(out)
out = self.fourth(out)
# occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map)
if occlusion_map is not None:
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
out = out * occlusion_map
# Decoding part
out = self.decoder(out)
output_dict["prediction"] = out
return output_dict
from torch import nn
import torch
import torch.nn.functional as F
from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck
class KPDetector(nn.Module):
"""
Detecting canonical keypoints. Returns the keypoint positions and the jacobian near each keypoint.
"""
def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,
num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):
super(KPDetector, self).__init__()
self.predictor = KPHourglass(block_expansion, in_features=image_channel,
max_features=max_features, reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)
# self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)
self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)
if estimate_jacobian:
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
# self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)
self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)
'''
initial as:
[[1 0 0]
[0 1 0]
[0 0 1]]
'''
self.jacobian.weight.data.zero_()
self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
else:
self.jacobian = None
self.temperature = temperature
self.scale_factor = scale_factor
if self.scale_factor != 1:
self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)
def gaussian2kp(self, heatmap):
"""
Extract the mean from a heatmap
"""
shape = heatmap.shape
heatmap = heatmap.unsqueeze(-1)
grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
value = (heatmap * grid).sum(dim=(2, 3, 4))
kp = {'value': value}
return kp
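# Sketch of the idea: `gaussian2kp` is a soft-argmax. With the heatmap already
# softmax-normalized over (d, h, w), each keypoint is the probability-weighted average
# of the normalized grid coordinates,
#
#     kp_k = sum_{z, y, x} heatmap_k(z, y, x) * grid(x, y, z),    grid in [-1, 1]^3
#
# so a sharply peaked heatmap returns (approximately) the coordinate of its peak voxel.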
def forward(self, x):
if self.scale_factor != 1:
x = self.down(x)
feature_map = self.predictor(x)
prediction = self.kp(feature_map)
final_shape = prediction.shape
heatmap = prediction.view(final_shape[0], final_shape[1], -1)
heatmap = F.softmax(heatmap / self.temperature, dim=2)
heatmap = heatmap.view(*final_shape)
out = self.gaussian2kp(heatmap)
if self.jacobian is not None:
jacobian_map = self.jacobian(feature_map)
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],
final_shape[3], final_shape[4])
heatmap = heatmap.unsqueeze(2)
jacobian = heatmap * jacobian_map
jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)
jacobian = jacobian.sum(dim=-1)
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)
out['jacobian'] = jacobian
return out
class HEEstimator(nn.Module):
"""
Estimating head pose and expression.
"""
def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
super(HEEstimator, self).__init__()
self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
self.norm1 = BatchNorm2d(block_expansion, affine=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
self.norm2 = BatchNorm2d(256, affine=True)
self.block1 = nn.Sequential()
for i in range(3):
self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))
self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
self.norm3 = BatchNorm2d(512, affine=True)
self.block2 = ResBottleneck(in_features=512, stride=2)
self.block3 = nn.Sequential()
for i in range(3):
self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))
self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
self.norm4 = BatchNorm2d(1024, affine=True)
self.block4 = ResBottleneck(in_features=1024, stride=2)
self.block5 = nn.Sequential()
for i in range(5):
self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))
self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
self.norm5 = BatchNorm2d(2048, affine=True)
self.block6 = ResBottleneck(in_features=2048, stride=2)
self.block7 = nn.Sequential()
for i in range(2):
self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))
self.fc_roll = nn.Linear(2048, num_bins)
self.fc_pitch = nn.Linear(2048, num_bins)
self.fc_yaw = nn.Linear(2048, num_bins)
self.fc_t = nn.Linear(2048, 3)
self.fc_exp = nn.Linear(2048, 3*num_kp)
def forward(self, x):
out = self.conv1(x)
out = self.norm1(out)
out = F.relu(out)
out = self.maxpool(out)
out = self.conv2(out)
out = self.norm2(out)
out = F.relu(out)
out = self.block1(out)
out = self.conv3(out)
out = self.norm3(out)
out = F.relu(out)
out = self.block2(out)
out = self.block3(out)
out = self.conv4(out)
out = self.norm4(out)
out = F.relu(out)
out = self.block4(out)
out = self.block5(out)
out = self.conv5(out)
out = self.norm5(out)
out = F.relu(out)
out = self.block6(out)
out = self.block7(out)
out = F.adaptive_avg_pool2d(out, 1)
out = out.view(out.shape[0], -1)
# NOTE: the fc_roll / fc_yaw heads are crossed relative to the keys returned below; this
# appears intentional in the original training setup, so it is left unchanged to stay
# compatible with pretrained weights.
yaw = self.fc_roll(out)
pitch = self.fc_pitch(out)
roll = self.fc_yaw(out)
t = self.fc_t(out)
exp = self.fc_exp(out)
return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
from scipy.spatial import ConvexHull
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
use_relative_movement=False, use_relative_jacobian=False):
if adapt_movement_scale:
source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
else:
adapt_movement_scale = 1
kp_new = {k: v for k, v in kp_driving.items()}
if use_relative_movement:
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
kp_value_diff *= adapt_movement_scale
kp_new['value'] = kp_value_diff + kp_source['value']
if use_relative_jacobian:
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
return kp_new
def headpose_pred_to_degree(pred):
device = pred.device
idx_tensor = [idx for idx in range(66)]
idx_tensor = torch.FloatTensor(idx_tensor).type_as(pred).to(device)
pred = F.softmax(pred, dim=1)
degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
return degree
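# Worked example: the 66 logits are treated as pose bins of width 3 degrees centred on
# [-99, 96], so a prediction concentrated on bin i maps to roughly 3 * i - 99 degrees;
# bin 33, for instance, corresponds to 3 * 33 - 99 = 0 degrees (frontal).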
def get_rotation_matrix(yaw, pitch, roll):
yaw = yaw / 180 * 3.14
pitch = pitch / 180 * 3.14
roll = roll / 180 * 3.14
roll = roll.unsqueeze(1)
pitch = pitch.unsqueeze(1)
yaw = yaw.unsqueeze(1)
pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw),
torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
-torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),
torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
return rot_mat
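# Note: pitch_mat, yaw_mat and roll_mat above are the elementary rotations about the
# x, y and z axes respectively (angles converted from degrees to radians), and the
# einsum composes them per batch element as R = R_pitch @ R_yaw @ R_roll, yielding one
# (3, 3) rotation matrix per sample.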
def keypoint_transformation(kp_canonical, he, wo_exp=False):
kp = kp_canonical['value'] # (bs, k, 3)
yaw, pitch, roll= he['yaw'], he['pitch'], he['roll']
yaw = headpose_pred_to_degree(yaw)
pitch = headpose_pred_to_degree(pitch)
roll = headpose_pred_to_degree(roll)
if 'yaw_in' in he:
yaw = he['yaw_in']
if 'pitch_in' in he:
pitch = he['pitch_in']
if 'roll_in' in he:
roll = he['roll_in']
rot_mat = get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
t, exp = he['t'], he['exp']
if wo_exp:
exp = exp*0
# keypoint rotation
kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
# keypoint translation
t[:, 0] = t[:, 0]*0
t[:, 2] = t[:, 2]*0
t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
kp_t = kp_rotated + t
# add expression deviation
exp = exp.view(exp.shape[0], -1, 3)
kp_transformed = kp_t + exp
return {'value': kp_transformed}
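# Summary of the transform above: each canonical keypoint x_c is mapped to
#
#     x = R(yaw, pitch, roll) @ x_c + t + exp
#
# where only the second component of t is kept (the other two are zeroed), exp is
# reshaped to (bs, num_kp, 3) as a per-keypoint expression offset, and exp is dropped
# entirely when wo_exp=True.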
def make_animation(source_image, source_semantics, target_semantics,
generator, kp_detector, he_estimator, mapping,
yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
use_exp=True, use_half=False):
with torch.no_grad():
predictions = []
kp_canonical = kp_detector(source_image)
he_source = mapping(source_semantics)
kp_source = keypoint_transformation(kp_canonical, he_source)
for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
# still check the dimension
# print(target_semantics.shape, source_semantics.shape)
target_semantics_frame = target_semantics[:, frame_idx]
he_driving = mapping(target_semantics_frame)
if yaw_c_seq is not None:
he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
if pitch_c_seq is not None:
he_driving['pitch_in'] = pitch_c_seq[:, frame_idx]
if roll_c_seq is not None:
he_driving['roll_in'] = roll_c_seq[:, frame_idx]
kp_driving = keypoint_transformation(kp_canonical, he_driving)
kp_norm = kp_driving
out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
'''
source_image_new = out['prediction'].squeeze(1)
kp_canonical_new = kp_detector(source_image_new)
he_source_new = he_estimator(source_image_new)
kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
'''
predictions.append(out['prediction'])
predictions_ts = torch.stack(predictions, dim=1)
return predictions_ts
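# Rough usage sketch (shapes are illustrative, not taken from this file):
#
#   >>> # source_image: (bs, 3, H, W); source/target_semantics: windows of 3DMM coefficients,
#   >>> # with target_semantics indexed per frame along dim 1.
#   >>> video = make_animation(source_image, source_semantics, target_semantics,
#   ...                        generator, kp_detector, he_estimator, mapping)
#   >>> video.shape   # (bs, num_frames, 3, H, W)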
class AnimateModel(torch.nn.Module):
"""
Merge all generator-related updates into a single model for better multi-GPU usage
"""
def __init__(self, generator, kp_extractor, mapping):
super(AnimateModel, self).__init__()
self.kp_extractor = kp_extractor
self.generator = generator
self.mapping = mapping
self.kp_extractor.eval()
self.generator.eval()
self.mapping.eval()
def forward(self, x):
source_image = x['source_image']
source_semantics = x['source_semantics']
target_semantics = x['target_semantics']
yaw_c_seq = x['yaw_c_seq']
pitch_c_seq = x['pitch_c_seq']
roll_c_seq = x['roll_c_seq']
predictions_video = make_animation(source_image, source_semantics, target_semantics,
self.generator, self.kp_extractor,
self.mapping, use_exp = True,
yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq)
return predictions_video
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class MappingNet(nn.Module):
def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
super(MappingNet, self).__init__()
self.layer = layer
nonlinearity = nn.LeakyReLU(0.1)
self.first = nn.Sequential(
torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
for i in range(layer):
net = nn.Sequential(nonlinearity,
torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
setattr(self, 'encoder' + str(i), net)
self.pooling = nn.AdaptiveAvgPool1d(1)
self.output_nc = descriptor_nc
self.fc_roll = nn.Linear(descriptor_nc, num_bins)
self.fc_pitch = nn.Linear(descriptor_nc, num_bins)
self.fc_yaw = nn.Linear(descriptor_nc, num_bins)
self.fc_t = nn.Linear(descriptor_nc, 3)
self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp)
def forward(self, input_3dmm):
out = self.first(input_3dmm)
for i in range(self.layer):
model = getattr(self, 'encoder' + str(i))
out = model(out) + out[:,:,3:-3]
out = self.pooling(out)
out = out.view(out.shape[0], -1)
#print('out:', out.shape)
yaw = self.fc_yaw(out)
pitch = self.fc_pitch(out)
roll = self.fc_roll(out)
t = self.fc_t(out)
exp = self.fc_exp(out)
return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
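# Note: each `encoder` block is a kernel-3, dilation-3, padding-0 Conv1d, which shortens
# the temporal axis by 6 samples per block; the residual term `out[:, :, 3:-3]` crops the
# skip connection to the same length before the addition.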
from torch import nn
import torch.nn.functional as F
import torch
from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
import torch.nn.utils.spectral_norm as spectral_norm
def kp2gaussian(kp, spatial_size, kp_variance):
"""
Transform a keypoint into a Gaussian-like representation
"""
mean = kp['value']
coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
number_of_leading_dimensions = len(mean.shape) - 1
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
coordinate_grid = coordinate_grid.view(*shape)
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
coordinate_grid = coordinate_grid.repeat(*repeats)
# Preprocess kp shape
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
mean = mean.view(*shape)
mean_sub = (coordinate_grid - mean)
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
return out
def make_coordinate_grid_2d(spatial_size, type):
"""
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
"""
h, w = spatial_size
x = torch.arange(w).type(type)
y = torch.arange(h).type(type)
x = (2 * (x / (w - 1)) - 1)
y = (2 * (y / (h - 1)) - 1)
yy = y.view(-1, 1).repeat(1, w)
xx = x.view(1, -1).repeat(h, 1)
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
return meshed
def make_coordinate_grid(spatial_size, type):
d, h, w = spatial_size
x = torch.arange(w).type(type)
y = torch.arange(h).type(type)
z = torch.arange(d).type(type)
x = (2 * (x / (w - 1)) - 1)
y = (2 * (y / (h - 1)) - 1)
z = (2 * (z / (d - 1)) - 1)
yy = y.view(1, -1, 1).repeat(d, 1, w)
xx = x.view(1, 1, -1).repeat(d, h, 1)
zz = z.view(-1, 1, 1).repeat(1, h, w)
meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
return meshed
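# Note: both grid helpers return coordinates normalized to [-1, 1], ordered (x, y) in the
# 2D case and (x, y, z) in the 3D case, which is the layout `F.grid_sample` expects for
# its sampling grid.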
class ResBottleneck(nn.Module):
def __init__(self, in_features, stride):
super(ResBottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1)
self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1)
self.norm1 = BatchNorm2d(in_features//4, affine=True)
self.norm2 = BatchNorm2d(in_features//4, affine=True)
self.norm3 = BatchNorm2d(in_features, affine=True)
self.stride = stride
if self.stride != 1:
self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride)
self.norm4 = BatchNorm2d(in_features, affine=True)
def forward(self, x):
out = self.conv1(x)
out = self.norm1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv3(out)
out = self.norm3(out)
if self.stride != 1:
x = self.skip(x)
x = self.norm4(x)
out += x
out = F.relu(out)
return out
class ResBlock2d(nn.Module):
"""
Res block, preserve spatial resolution.
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.norm1 = BatchNorm2d(in_features, affine=True)
self.norm2 = BatchNorm2d(in_features, affine=True)
def forward(self, x):
out = self.norm1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv2(out)
out += x
return out
class ResBlock3d(nn.Module):
"""
Res block, preserve spatial resolution.
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock3d, self).__init__()
self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.norm1 = BatchNorm3d(in_features, affine=True)
self.norm2 = BatchNorm3d(in_features, affine=True)
def forward(self, x):
out = self.norm1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv2(out)
out += x
return out
class UpBlock2d(nn.Module):
"""
Upsampling block for use in decoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(UpBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
def forward(self, x):
out = F.interpolate(x, scale_factor=2)
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class UpBlock3d(nn.Module):
"""
Upsampling block for use in decoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(UpBlock3d, self).__init__()
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm3d(out_features, affine=True)
def forward(self, x):
# out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear')
out = F.interpolate(x, scale_factor=(1, 2, 2))
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class DownBlock2d(nn.Module):
"""
Downsampling block for use in encoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
out = self.pool(out)
return out
class DownBlock3d(nn.Module):
"""
Downsampling block for use in encoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(DownBlock3d, self).__init__()
'''
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups, stride=(1, 2, 2))
'''
self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm3d(out_features, affine=True)
self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
out = self.pool(out)
return out
class SameBlock2d(nn.Module):
"""
Simple block, preserve spatial resolution.
"""
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
super(SameBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
kernel_size=kernel_size, padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
if lrelu:
self.ac = nn.LeakyReLU()
else:
self.ac = nn.ReLU()
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = self.ac(out)
return out
class Encoder(nn.Module):
"""
Hourglass Encoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Encoder, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
kernel_size=3, padding=1))
self.down_blocks = nn.ModuleList(down_blocks)
def forward(self, x):
outs = [x]
for down_block in self.down_blocks:
outs.append(down_block(outs[-1]))
return outs
class Decoder(nn.Module):
"""
Hourglass Decoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Decoder, self).__init__()
up_blocks = []
for i in range(num_blocks)[::-1]:
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
out_filters = min(max_features, block_expansion * (2 ** i))
up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
self.up_blocks = nn.ModuleList(up_blocks)
# self.out_filters = block_expansion
self.out_filters = block_expansion + in_features
self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
self.norm = BatchNorm3d(self.out_filters, affine=True)
def forward(self, x):
out = x.pop()
# for up_block in self.up_blocks[:-1]:
for up_block in self.up_blocks:
out = up_block(out)
skip = x.pop()
out = torch.cat([out, skip], dim=1)
# out = self.up_blocks[-1](out)
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class Hourglass(nn.Module):
"""
Hourglass architecture.
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Hourglass, self).__init__()
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
self.out_filters = self.decoder.out_filters
def forward(self, x):
return self.decoder(self.encoder(x))
class KPHourglass(nn.Module):
"""
Hourglass architecture.
"""
def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):
super(KPHourglass, self).__init__()
self.down_blocks = nn.Sequential()
for i in range(num_blocks):
self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
kernel_size=3, padding=1))
in_filters = min(max_features, block_expansion * (2 ** num_blocks))
self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)
self.up_blocks = nn.Sequential()
for i in range(num_blocks):
in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i)))
out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))
self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
self.reshape_depth = reshape_depth
self.out_filters = out_filters
def forward(self, x):
out = self.down_blocks(x)
out = self.conv(out)
bs, c, h, w = out.shape
out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w)
out = self.up_blocks(out)
return out
class AntiAliasInterpolation2d(nn.Module):
"""
Band-limited downsampling, for better preservation of the input signal.
"""
def __init__(self, channels, scale):
super(AntiAliasInterpolation2d, self).__init__()
sigma = (1 / scale - 1) / 2
kernel_size = 2 * round(sigma * 4) + 1
self.ka = kernel_size // 2
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
kernel_size = [kernel_size, kernel_size]
sigma = [sigma, sigma]
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
self.scale = scale
inv_scale = 1 / scale
self.int_inv_scale = int(inv_scale)
def forward(self, input):
if self.scale == 1.0:
return input
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
out = F.conv2d(out, weight=self.weight, groups=self.groups)
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
return out
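# Worked example: downsampling here is a depthwise Gaussian blur followed by strided
# slicing (low-pass filtering before decimation). With scale=0.25, sigma = (1/0.25 - 1)/2
# = 1.5, the kernel is 13 taps per dimension, and every 4th pixel is kept.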
class SPADE(nn.Module):
def __init__(self, norm_nc, label_nc):
super().__init__()
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
nhidden = 128
self.mlp_shared = nn.Sequential(
nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
nn.ReLU())
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
def forward(self, x, segmap):
normalized = self.param_free_norm(x)
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
actv = self.mlp_shared(segmap)
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
out = normalized * (1 + gamma) + beta
return out
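# Formula implemented above:
#
#     out = InstanceNorm(x) * (1 + gamma(segmap)) + beta(segmap)
#
# i.e. a parameter-free normalization whose scale and shift are predicted spatially from
# the conditioning map (in SPADEDecoder the conditioning map is the input feature itself).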
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
super().__init__()
# Attributes
self.learned_shortcut = (fin != fout)
fmiddle = min(fin, fout)
self.use_se = use_se
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if 'spectral' in norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
self.norm_0 = SPADE(fin, label_nc)
self.norm_1 = SPADE(fmiddle, label_nc)
if self.learned_shortcut:
self.norm_s = SPADE(fin, label_nc)
def forward(self, x, seg1):
x_s = self.shortcut(x, seg1)
dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
out = x_s + dx
return out
def shortcut(self, x, seg1):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg1))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1)
class audio2image(nn.Module):
def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params):
super().__init__()
# Attributes
self.generator = generator
self.kp_extractor = kp_extractor
self.he_estimator_video = he_estimator_video
self.he_estimator_audio = he_estimator_audio
self.train_params = train_params
def headpose_pred_to_degree(self, pred):
device = pred.device
idx_tensor = [idx for idx in range(66)]
idx_tensor = torch.FloatTensor(idx_tensor).to(device)
pred = F.softmax(pred, dim=1)
degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
return degree
def get_rotation_matrix(self, yaw, pitch, roll):
yaw = yaw / 180 * 3.14
pitch = pitch / 180 * 3.14
roll = roll / 180 * 3.14
roll = roll.unsqueeze(1)
pitch = pitch.unsqueeze(1)
yaw = yaw.unsqueeze(1)
roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll),
torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch),
torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
-torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),
torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
return rot_mat
def keypoint_transformation(self, kp_canonical, he):
kp = kp_canonical['value'] # (bs, k, 3)
yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
t, exp = he['t'], he['exp']
yaw = self.headpose_pred_to_degree(yaw)
pitch = self.headpose_pred_to_degree(pitch)
roll = self.headpose_pred_to_degree(roll)
rot_mat = self.get_rotation_matrix(yaw, pitch, roll) # (bs, 3, 3)
# keypoint rotation
kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
# keypoint translation
t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)
kp_t = kp_rotated + t
# add expression deviation
exp = exp.view(exp.shape[0], -1, 3)
kp_transformed = kp_t + exp
return {'value': kp_transformed}
def forward(self, source_image, target_audio):
pose_source = self.he_estimator_video(source_image)
pose_generated = self.he_estimator_audio(target_audio)
kp_canonical = self.kp_extractor(source_image)
kp_source = self.keypoint_transformation(kp_canonical, pose_source)
kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated)
generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated)
return generated
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
from .comm import SyncMaster
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dementions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
return mean, bias_var.clamp(self.eps) ** -0.5
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
only the statistics of that device, which accelerates the computation and
is easy to implement, but the statistics may be inaccurate.
Instead, in this synchronized version, the statistics are computed
over all training samples distributed across multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
only the statistics of that device, which accelerates the computation and
is easy to implement, but the statistics may be inaccurate.
Instead, in this synchronized version, the statistics are computed
over all training samples distributed across multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
only the statistics of that device, which accelerates the computation and
is easy to implement, but the statistics may be inaccurate.
Instead, in this synchronized version, the statistics are computed
over all training samples distributed across multiple devices.
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result hasn\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During replication, as data parallel triggers a callback on each module, all slave devices should
call `register(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`; all messages from slave devices are
collected and passed to a registered callback.
- After receiving the messages, the master device should gather the information and determine the message
to pass back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually is the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages are first collected from each device (including the master device), and then
a callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Note that, as all modules are isomorphic, we assign each sub-module a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called before the callbacks
of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
A replication callback `__data_parallel_replicate__` is invoked on each module after it is created by
the original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have a customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import numpy as np
from torch.autograd import Variable
def as_numpy(v):
if isinstance(v, Variable):
v = v.data
return v.cpu().numpy()
class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            np.allclose(npa, npb, atol=atol, rtol=rtol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
                a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )
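# Minimal usage sketch (illustrative): a test case built on TorchTestCase; run this
# module directly to execute it.
if __name__ == '__main__':
    import torch

    class _DemoCase(TorchTestCase):
        def test_close(self):
            a = torch.ones(4)
            self.assertTensorClose(a, a + 1e-5)

    unittest.main()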
import os
from tqdm import tqdm
import torch
import numpy as np
import random
import scipy.io as scio
import src.utils.audio as audio
def crop_pad_audio(wav, audio_length):
if len(wav) > audio_length:
wav = wav[:audio_length]
elif len(wav) < audio_length:
wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
return wav
def parse_audio_length(audio_length, sr, fps):
bit_per_frames = sr / fps
num_frames = int(audio_length / bit_per_frames)
audio_length = int(num_frames * bit_per_frames)
return audio_length, num_frames
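# Worked example (comment only): at sr=16000 and fps=25, each video frame spans
# 16000 / 25 = 640 audio samples, so a 1.5 s wav (24000 samples) yields
# num_frames = 37 and is trimmed to 37 * 640 = 23680 samples, keeping the audio
# length an exact multiple of the per-frame sample count.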
def generate_blink_seq(num_frames):
ratio = np.zeros((num_frames,1))
frame_id = 0
    while frame_id < num_frames:
start = 80
if frame_id+start+9<=num_frames - 1:
ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]
frame_id = frame_id+start+9
else:
break
return ratio
def generate_blink_seq_randomly(num_frames):
ratio = np.zeros((num_frames,1))
if num_frames<=20:
return ratio
frame_id = 0
    while frame_id < num_frames:
start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70)))
if frame_id+start+5<=num_frames - 1:
ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]
frame_id = frame_id+start+5
else:
break
return ratio
def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False, idlemode=False, length_of_audio=False, use_blink=True):
syncnet_mel_step_size = 16
fps = 25
pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
if idlemode:
num_frames = int(length_of_audio * 25)
indiv_mels = np.zeros((num_frames, 80, 16))
else:
wav = audio.load_wav(audio_path, 16000)
wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
wav = crop_pad_audio(wav, wav_length)
orig_mel = audio.melspectrogram(wav).T
spec = orig_mel.copy() # nframes 80
indiv_mels = []
for i in tqdm(range(num_frames), 'mel:'):
start_frame_num = i-2
start_idx = int(80. * (start_frame_num / float(fps)))
end_idx = start_idx + syncnet_mel_step_size
seq = list(range(start_idx, end_idx))
seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
m = spec[seq, :]
indiv_mels.append(m.T)
indiv_mels = np.asarray(indiv_mels) # T 80 16
ratio = generate_blink_seq_randomly(num_frames) # T
source_semantics_path = first_coeff_path
source_semantics_dict = scio.loadmat(source_semantics_path)
ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
if ref_eyeblink_coeff_path is not None:
ratio[:num_frames] = 0
refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)
refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]
refeyeblink_num_frames = refeyeblink_coeff.shape[0]
if refeyeblink_num_frames<num_frames:
div = num_frames//refeyeblink_num_frames
re = num_frames%refeyeblink_num_frames
refeyeblink_coeff_list = [refeyeblink_coeff for i in range(div)]
refeyeblink_coeff_list.append(refeyeblink_coeff[:re, :64])
refeyeblink_coeff = np.concatenate(refeyeblink_coeff_list, axis=0)
print(refeyeblink_coeff.shape[0])
ref_coeff[:, :64] = refeyeblink_coeff[:num_frames, :64]
indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16
if use_blink:
ratio = torch.FloatTensor(ratio).unsqueeze(0) # bs T
else:
        ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.)  # bs T
    ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0)  # bs T 70
indiv_mels = indiv_mels.to(device)
ratio = ratio.to(device)
ref_coeff = ref_coeff.to(device)
return {'indiv_mels': indiv_mels,
'ref': ref_coeff,
'num_frames': num_frames,
'ratio_gt': ratio,
'audio_name': audio_name, 'pic_name': pic_name}
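# Usage sketch (illustrative; the .mat/.wav paths are placeholders for the outputs of
# the preprocessing step and your own audio file):
if __name__ == '__main__':
    batch = get_data('first_frame_dir/source.mat', 'examples/driven_audio.wav',
                     device='cpu', ref_eyeblink_coeff_path=None)
    print(batch['num_frames'], batch['indiv_mels'].shape)  # T, (1, T, 1, 80, 16)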
import os
import numpy as np
from PIL import Image
from skimage import io, img_as_float32, transform
import torch
import scipy.io as scio
def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
expression_scale=1.0, still_mode = False, preprocess='crop', size = 256):
semantic_radius = 13
video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
txt_path = os.path.splitext(coeff_path)[0]
data={}
img1 = Image.open(pic_path)
source_image = np.array(img1)
source_image = img_as_float32(source_image)
source_image = transform.resize(source_image, (size, size, 3))
source_image = source_image.transpose((2, 0, 1))
source_image_ts = torch.FloatTensor(source_image).unsqueeze(0)
source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1)
data['source_image'] = source_image_ts
source_semantics_dict = scio.loadmat(first_coeff_path)
generated_dict = scio.loadmat(coeff_path)
if 'full' not in preprocess.lower():
source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70
generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
else:
        source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73] #1 73
generated_3dmm = generated_dict['coeff_3dmm'][:,:70]
source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)
source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)
source_semantics_ts = source_semantics_ts.repeat(batch_size, 1, 1)
data['source_semantics'] = source_semantics_ts
# target
generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
if 'full' in preprocess.lower():
generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
if still_mode:
generated_3dmm[:, 64:] = np.repeat(source_semantics[:, 64:], generated_3dmm.shape[0], axis=0)
with open(txt_path+'.txt', 'w') as f:
for coeff in generated_3dmm:
for i in coeff:
f.write(str(i)[:7] + ' '+'\t')
f.write('\n')
target_semantics_list = []
frame_num = generated_3dmm.shape[0]
data['frame_num'] = frame_num
for frame_idx in range(frame_num):
target_semantics = transform_semantic_target(generated_3dmm, frame_idx, semantic_radius)
target_semantics_list.append(target_semantics)
remainder = frame_num%batch_size
if remainder!=0:
for _ in range(batch_size-remainder):
target_semantics_list.append(target_semantics)
target_semantics_np = np.array(target_semantics_list) #frame_num 70 semantic_radius*2+1
target_semantics_np = target_semantics_np.reshape(batch_size, -1, target_semantics_np.shape[-2], target_semantics_np.shape[-1])
data['target_semantics_list'] = torch.FloatTensor(target_semantics_np)
data['video_name'] = video_name
data['audio_path'] = audio_path
if input_yaw_list is not None:
yaw_c_seq = gen_camera_pose(input_yaw_list, frame_num, batch_size)
data['yaw_c_seq'] = torch.FloatTensor(yaw_c_seq)
if input_pitch_list is not None:
pitch_c_seq = gen_camera_pose(input_pitch_list, frame_num, batch_size)
data['pitch_c_seq'] = torch.FloatTensor(pitch_c_seq)
if input_roll_list is not None:
roll_c_seq = gen_camera_pose(input_roll_list, frame_num, batch_size)
data['roll_c_seq'] = torch.FloatTensor(roll_c_seq)
return data
def transform_semantic_1(semantic, semantic_radius):
semantic_list = [semantic for i in range(0, semantic_radius*2+1)]
coeff_3dmm = np.concatenate(semantic_list, 0)
return coeff_3dmm.transpose(1,0)
def transform_semantic_target(coeff_3dmm, frame_index, semantic_radius):
num_frames = coeff_3dmm.shape[0]
seq = list(range(frame_index- semantic_radius, frame_index + semantic_radius+1))
index = [ min(max(item, 0), num_frames-1) for item in seq ]
coeff_3dmm_g = coeff_3dmm[index, :]
return coeff_3dmm_g.transpose(1,0)
def gen_camera_pose(camera_degree_list, frame_num, batch_size):
new_degree_list = []
if len(camera_degree_list) == 1:
for _ in range(frame_num):
new_degree_list.append(camera_degree_list[0])
remainder = frame_num%batch_size
if remainder!=0:
for _ in range(batch_size-remainder):
new_degree_list.append(new_degree_list[-1])
new_degree_np = np.array(new_degree_list).reshape(batch_size, -1)
return new_degree_np
degree_sum = 0.
for i, degree in enumerate(camera_degree_list[1:]):
degree_sum += abs(degree-camera_degree_list[i])
degree_per_frame = degree_sum/(frame_num-1)
for i, degree in enumerate(camera_degree_list[1:]):
degree_last = camera_degree_list[i]
degree_step = degree_per_frame * abs(degree-degree_last)/(degree-degree_last)
new_degree_list = new_degree_list + list(np.arange(degree_last, degree, degree_step))
if len(new_degree_list) > frame_num:
new_degree_list = new_degree_list[:frame_num]
elif len(new_degree_list) < frame_num:
for _ in range(frame_num-len(new_degree_list)):
new_degree_list.append(new_degree_list[-1])
print(len(new_degree_list))
print(frame_num)
remainder = frame_num%batch_size
if remainder!=0:
for _ in range(batch_size-remainder):
new_degree_list.append(new_degree_list[-1])
new_degree_np = np.array(new_degree_list).reshape(batch_size, -1)
return new_degree_np
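# Usage sketch (illustrative): sweeping yaw from -30 to 30 degrees over 50 frames,
# padded so the sequence reshapes cleanly into batches.
if __name__ == '__main__':
    yaw_seq = gen_camera_pose([-30, 30], frame_num=50, batch_size=8)
    print(yaw_seq.shape)  # (8, 7): 50 frames padded to 56, a multiple of batch_size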
import torch, uuid
import os, sys, shutil
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.init_path import init_path
from pydub import AudioSegment
def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
mp3_file = AudioSegment.from_file(file=mp3_filename)
mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
class SadTalker():
def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
if torch.cuda.is_available() :
device = "cuda"
else:
device = "cpu"
self.device = device
os.environ['TORCH_HOME']= checkpoint_path
self.checkpoint_path = checkpoint_path
self.config_path = config_path
def test(self, source_image, driven_audio, preprocess='crop',
still_mode=False, use_enhancer=False, batch_size=1, size=256,
pose_style = 0, exp_scale=1.0,
use_ref_video = False,
ref_video = None,
ref_info = None,
use_idle_mode = False,
length_of_audio = 0, use_blink=True,
result_dir='./results/'):
self.sadtalker_paths = init_path(self.checkpoint_path, self.config_path, size, False, preprocess)
print(self.sadtalker_paths)
self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
time_tag = str(uuid.uuid4())
save_dir = os.path.join(result_dir, time_tag)
os.makedirs(save_dir, exist_ok=True)
input_dir = os.path.join(save_dir, 'input')
os.makedirs(input_dir, exist_ok=True)
print(source_image)
pic_path = os.path.join(input_dir, os.path.basename(source_image))
shutil.move(source_image, input_dir)
if driven_audio is not None and os.path.isfile(driven_audio):
audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
#### mp3 to wav
            if audio_path.endswith('.mp3'):
mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
audio_path = audio_path.replace('.mp3', '.wav')
else:
shutil.move(driven_audio, input_dir)
elif use_idle_mode:
audio_path = os.path.join(input_dir, 'idlemode_'+str(length_of_audio)+'.wav') ## generate audio from this new audio_path
from pydub import AudioSegment
one_sec_segment = AudioSegment.silent(duration=1000*length_of_audio) #duration in milliseconds
one_sec_segment.export(audio_path, format="wav")
else:
print(use_ref_video, ref_info)
            assert use_ref_video and ref_info == 'all'
if use_ref_video and ref_info == 'all': # full ref mode
ref_video_videoname = os.path.basename(ref_video)
audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
print('new audiopath:',audio_path)
# if ref_video contains audio, set the audio from ref_video.
cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s"%(ref_video, audio_path)
os.system(cmd)
os.makedirs(save_dir, exist_ok=True)
#crop image and extract 3dmm from image
first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
os.makedirs(first_frame_dir, exist_ok=True)
first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess, True, size)
if first_coeff_path is None:
raise AttributeError("No face is detected")
if use_ref_video:
            print('using ref video for generation')
ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
os.makedirs(ref_video_frame_dir, exist_ok=True)
print('3DMM Extraction for the reference video providing pose')
ref_video_coeff_path, _, _ = self.preprocess_model.generate(ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
else:
ref_video_coeff_path = None
if use_ref_video:
if ref_info == 'pose':
ref_pose_coeff_path = ref_video_coeff_path
ref_eyeblink_coeff_path = None
elif ref_info == 'blink':
ref_pose_coeff_path = None
ref_eyeblink_coeff_path = ref_video_coeff_path
elif ref_info == 'pose+blink':
ref_pose_coeff_path = ref_video_coeff_path
ref_eyeblink_coeff_path = ref_video_coeff_path
elif ref_info == 'all':
ref_pose_coeff_path = None
ref_eyeblink_coeff_path = None
else:
                raise ValueError('ref_info should be one of: pose, blink, pose+blink, all')
else:
ref_pose_coeff_path = None
ref_eyeblink_coeff_path = None
        # audio2coeff
if use_ref_video and ref_info == 'all':
coeff_path = ref_video_coeff_path # self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
else:
batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
#coeff2video
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess, size=size, expression_scale = exp_scale)
return_path = self.animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
video_name = data['video_name']
print(f'The generated video is named {video_name} in {save_dir}')
del self.preprocess_model
del self.audio_to_coeff
del self.animate_from_coeff
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
import gc; gc.collect()
return return_path
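# End-to-end usage sketch (illustrative; the image/audio paths are placeholders and the
# pretrained checkpoints must already be in ./checkpoints). Note that test() moves its
# input files into the result directory.
if __name__ == '__main__':
    talker = SadTalker(checkpoint_path='checkpoints', config_path='src/config')
    video_path = talker.test('examples/source_image.png', 'examples/driven_audio.wav',
                             preprocess='crop', still_mode=False, batch_size=2)
    print('generated video:', video_path)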
import os
import torch
import numpy as np
from scipy.io import savemat, loadmat
from yacs.config import CfgNode as CN
from scipy.signal import savgol_filter
import safetensors
import safetensors.torch
from src.audio2pose_models.audio2pose import Audio2Pose
from src.audio2exp_models.networks import SimpleWrapperV2
from src.audio2exp_models.audio2exp import Audio2Exp
from src.utils.safetensor_helper import load_x_from_safetensor
def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
if model is not None:
model.load_state_dict(checkpoint['model'])
if optimizer is not None:
optimizer.load_state_dict(checkpoint['optimizer'])
return checkpoint['epoch']
class Audio2Coeff():
def __init__(self, sadtalker_path, device):
#load config
fcfg_pose = open(sadtalker_path['audio2pose_yaml_path'])
cfg_pose = CN.load_cfg(fcfg_pose)
cfg_pose.freeze()
fcfg_exp = open(sadtalker_path['audio2exp_yaml_path'])
cfg_exp = CN.load_cfg(fcfg_exp)
cfg_exp.freeze()
# load audio2pose_model
self.audio2pose_model = Audio2Pose(cfg_pose, None, device=device)
self.audio2pose_model = self.audio2pose_model.to(device)
self.audio2pose_model.eval()
for param in self.audio2pose_model.parameters():
param.requires_grad = False
try:
if sadtalker_path['use_safetensor']:
checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
self.audio2pose_model.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2pose'))
else:
load_cpk(sadtalker_path['audio2pose_checkpoint'], model=self.audio2pose_model, device=device)
        except Exception:
            raise Exception("Failed to load the audio2pose checkpoint")
# load audio2exp_model
netG = SimpleWrapperV2()
netG = netG.to(device)
        for param in netG.parameters():
            param.requires_grad = False
        netG.eval()
try:
if sadtalker_path['use_safetensor']:
checkpoints = safetensors.torch.load_file(sadtalker_path['checkpoint'])
netG.load_state_dict(load_x_from_safetensor(checkpoints, 'audio2exp'))
else:
load_cpk(sadtalker_path['audio2exp_checkpoint'], model=netG, device=device)
        except Exception:
            raise Exception("Failed to load the audio2exp checkpoint")
self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False)
self.audio2exp_model = self.audio2exp_model.to(device)
for param in self.audio2exp_model.parameters():
param.requires_grad = False
self.audio2exp_model.eval()
self.device = device
def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None):
with torch.no_grad():
#test
results_dict_exp= self.audio2exp_model.test(batch)
exp_pred = results_dict_exp['exp_coeff_pred'] #bs T 64
#for class_id in range(1):
#class_id = 0#(i+10)%45
#class_id = random.randint(0,46) #46 styles can be selected
batch['class'] = torch.LongTensor([pose_style]).to(self.device)
results_dict_pose = self.audio2pose_model.test(batch)
pose_pred = results_dict_pose['pose_pred'] #bs T 6
pose_len = pose_pred.shape[1]
if pose_len<13:
pose_len = int((pose_len-1)/2)*2+1
pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device)
else:
pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device)
coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1) #bs T 70
coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy()
if ref_pose_coeff_path is not None:
coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path)
savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])),
{'coeff_3dmm': coeffs_pred_numpy})
return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name']))
def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
num_frames = coeffs_pred_numpy.shape[0]
refpose_coeff_dict = loadmat(ref_pose_coeff_path)
refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70]
refpose_num_frames = refpose_coeff.shape[0]
if refpose_num_frames<num_frames:
div = num_frames//refpose_num_frames
re = num_frames%refpose_num_frames
refpose_coeff_list = [refpose_coeff for i in range(div)]
refpose_coeff_list.append(refpose_coeff[:re, :])
refpose_coeff = np.concatenate(refpose_coeff_list, axis=0)
#### relative head pose
coeffs_pred_numpy[:, 64:70] = coeffs_pred_numpy[:, 64:70] + ( refpose_coeff[:num_frames, :] - refpose_coeff[0:1, :] )
return coeffs_pred_numpy
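# Usage sketch (illustrative; relies on init_path and get_data defined elsewhere in
# this repo, and on placeholder input paths):
if __name__ == '__main__':
    from src.utils.init_path import init_path
    from src.generate_batch import get_data
    sadtalker_paths = init_path('checkpoints', 'src/config', size=256, preprocess='crop')
    audio_to_coeff = Audio2Coeff(sadtalker_paths, device='cpu')
    batch = get_data('first_frame_dir/source.mat', 'examples/driven_audio.wav', 'cpu', None)
    mat_path = audio_to_coeff.generate(batch, 'results', pose_style=0)
    print('coefficients saved to', mat_path)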
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
from src.utils.hparams import hparams as hp
def load_wav(path, sr):
return librosa.core.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
#proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))
def save_wavenet_wav(wav, path, sr):
librosa.output.write_wav(path, wav, sr=sr)
def preemphasis(wav, k, preemphasize=True):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k, inv_preemphasize=True):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def get_hop_size():
hop_size = hp.hop_size
if hop_size is None:
assert hp.frame_shift_ms is not None
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
return hop_size
def linearspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def melspectrogram(wav):
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
if hp.signal_normalization:
return _normalize(S)
return S
def _lws_processor():
import lws
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
def _stft(y):
if hp.use_lws:
        return _lws_processor().stft(y).T
else:
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
##########################################################
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram
"""
pad = (fsize - fshift)
if length % fshift == 0:
M = (length + pad * 2 - fsize) // fshift + 1
else:
M = (length + pad * 2 - fsize) // fshift + 2
return M
def pad_lr(x, fsize, fshift):
"""Compute left and right padding
"""
M = num_frames(len(x), fsize, fshift)
pad = (fsize - fshift)
T = len(x) + 2 * pad
r = (M - 1) * fshift + fsize - T
return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
# Conversions
_mel_basis = None
def _linear_to_mel(spectrogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectrogram)
def _build_mel_basis():
assert hp.fmax <= hp.sample_rate // 2
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
fmin=hp.fmin, fmax=hp.fmax)
def _amp_to_db(x):
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _normalize(S):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
-hp.max_abs_value, hp.max_abs_value)
else:
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
if hp.symmetric_mels:
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
else:
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
def _denormalize(D):
if hp.allow_clipping_in_normalization:
if hp.symmetric_mels:
return (((np.clip(D, -hp.max_abs_value,
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+ hp.min_level_db)
else:
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
if hp.symmetric_mels:
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
else:
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
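# Usage sketch (illustrative; 'examples/driven_audio.wav' is a placeholder path): the
# same mel pipeline that get_data in src/generate_batch.py consumes.
if __name__ == '__main__':
    wav = load_wav('examples/driven_audio.wav', hp.sample_rate)
    mel = melspectrogram(wav)
    print(mel.shape)  # (80, T); at sr=16000 and hop_size=200 there are ~80 frames/second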
import os
import cv2
import time
import glob
import argparse
import scipy
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from itertools import cycle
from src.face3d.extract_kp_videos_safe import KeypointExtractor
from facexlib.alignment import landmark_98_to_68
import numpy as np
from PIL import Image
class Preprocesser:
def __init__(self, device='cuda'):
self.predictor = KeypointExtractor(device)
def get_landmark(self, img_np):
"""get landmark with dlib
:return: np.array shape=(68, 2)
"""
with torch.no_grad():
dets = self.predictor.det_net.detect_faces(img_np, 0.97)
if len(dets) == 0:
return None
det = dets[0]
img = img_np[int(det[1]):int(det[3]), int(det[0]):int(det[2]), :]
lm = landmark_98_to_68(self.predictor.detector.get_landmarks(img)) # [0]
#### keypoints to the original location
lm[:,0] += int(det[0])
lm[:,1] += int(det[1])
return lm
def align_face(self, img, lm, output_size=1024):
"""
:param filepath: str
:return: PIL Image
"""
lm_chin = lm[0: 17] # left-right
lm_eyebrow_left = lm[17: 22] # left-right
lm_eyebrow_right = lm[22: 27] # left-right
lm_nose = lm[27: 31] # top-down
lm_nostrils = lm[31: 36] # top-down
lm_eye_left = lm[36: 42] # left-clockwise
lm_eye_right = lm[42: 48] # left-clockwise
lm_mouth_outer = lm[48: 60] # left-clockwise
lm_mouth_inner = lm[60: 68] # left-clockwise
# Calculate auxiliary vectors.
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg
# Choose oriented crop rectangle.
        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]  # combine the eye-to-eye direction with the rotated eye-to-mouth direction
        x /= np.hypot(*x)  # np.hypot gives the vector length; divide to normalize x to a unit vector
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)  # use the larger of the eye-to-eye and eye-to-mouth spans as the base scale
        y = np.flipud(x) * [-1, 1]
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])  # quad corners: shift the face center c by +/-x and +/-y
        qsize = np.hypot(*x) * 2  # side length of the quad: twice the base scale
# Shrink.
        # If the computed quad is too large, shrink the image and quad proportionally.
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
            img = img.resize(rsize, Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is its replacement
quad /= shrink
qsize /= shrink
else:
rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
# Crop.
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
min(crop[3] + border, img.size[1]))
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
# img = img.crop(crop)
quad -= crop[0:2]
# Pad.
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
max(pad[3] - img.size[1] + border, 0))
# if enable_padding and max(pad) > border - 4:
# pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
# img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
# h, w, _ = img.shape
# y, x, _ = np.ogrid[:h, :w, :1]
# mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
# 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
# blur = qsize * 0.02
# img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
# img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
# img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
# quad += pad[:2]
# Transform.
quad = (quad + 0.5).flatten()
lx = max(min(quad[0], quad[2]), 0)
ly = max(min(quad[1], quad[7]), 0)
rx = min(max(quad[4], quad[6]), img.size[0])
        ry = min(max(quad[3], quad[5]), img.size[1])  # clamp the bottom edge to the image height
# Save aligned image.
return rsize, crop, [lx, ly, rx, ry]
def crop(self, img_np_list, still=False, xsize=512): # first frame for all video
img_np = img_np_list[0]
lm = self.get_landmark(img_np)
if lm is None:
            raise RuntimeError('Cannot detect facial landmarks from the source image')
rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
clx, cly, crx, cry = crop
lx, ly, rx, ry = quad
lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
for _i in range(len(img_np_list)):
_inp = img_np_list[_i]
_inp = cv2.resize(_inp, (rsize[0], rsize[1]))
_inp = _inp[cly:cry, clx:crx]
if not still:
_inp = _inp[ly:ry, lx:rx]
img_np_list[_i] = _inp
return img_np_list, crop, quad
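# Usage sketch (illustrative; 'face.png' is a placeholder path):
if __name__ == '__main__':
    img = np.array(Image.open('face.png').convert('RGB'))
    preprocesser = Preprocesser(device='cuda' if torch.cuda.is_available() else 'cpu')
    cropped, crop_box, quad_box = preprocesser.crop([img], still=False, xsize=512)
    print(cropped[0].shape, crop_box, quad_box)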
import os
import torch
from gfpgan import GFPGANer
from tqdm import tqdm
from src.utils.videoio import load_video_to_cv2
import cv2
class GeneratorWithLen(object):
""" From https://stackoverflow.com/a/7460929 """
def __init__(self, gen, length):
self.gen = gen
self.length = length
def __len__(self):
return self.length
def __iter__(self):
return self.gen
def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
return list(gen)
def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
""" Provide a generator with a __len__ method so that it can passed to functions that
call len()"""
if os.path.isfile(images): # handle video to images
# TODO: Create a generator version of load_video_to_cv2
images = load_video_to_cv2(images)
gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
gen_with_len = GeneratorWithLen(gen, len(images))
return gen_with_len
def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
""" Provide a generator function so that all of the enhanced images don't need
to be stored in memory at the same time. This can save tons of RAM compared to
the enhancer function. """
print('face enhancer....')
if not isinstance(images, list) and os.path.isfile(images): # handle video to images
images = load_video_to_cv2(images)
# ------------------------ set up GFPGAN restorer ------------------------
if method == 'gfpgan':
arch = 'clean'
channel_multiplier = 2
model_name = 'GFPGANv1.4'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
elif method == 'RestoreFormer':
arch = 'RestoreFormer'
channel_multiplier = 2
model_name = 'RestoreFormer'
url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
elif method == 'codeformer': # TODO:
arch = 'CodeFormer'
channel_multiplier = 2
model_name = 'CodeFormer'
url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
else:
raise ValueError(f'Wrong model version {method}.')
# ------------------------ set up background upsampler ------------------------
if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
import warnings
warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.')
bg_upsampler = None
else:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
bg_upsampler = RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=model,
tile=400,
tile_pad=10,
pre_pad=0,
half=True) # need to set False in CPU mode
else:
bg_upsampler = None
# determine model paths
model_path = os.path.join('gfpgan/weights', model_name + '.pth')
if not os.path.isfile(model_path):
model_path = os.path.join('checkpoints', model_name + '.pth')
if not os.path.isfile(model_path):
# download pre-trained models from url
model_path = url
restorer = GFPGANer(
model_path=model_path,
upscale=2,
arch=arch,
channel_multiplier=channel_multiplier,
bg_upsampler=bg_upsampler)
# ------------------------ restore ------------------------
for idx in tqdm(range(len(images)), 'Face Enhancer:'):
img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
# restore faces and background if necessary
cropped_faces, restored_faces, r_img = restorer.enhance(
img,
has_aligned=False,
only_center_face=False,
paste_back=True)
r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
yield r_img
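# Usage sketch (illustrative; 'video.mp4' is a placeholder path): frames are restored
# lazily, so only one enhanced frame is held in memory at a time.
if __name__ == '__main__':
    for frame in enhancer_generator_with_len('video.mp4', method='gfpgan'):
        print(frame.shape)  # one RGB frame at a time, upscaled 2x by GFPGAN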
from glob import glob
import os
class HParams:
def __init__(self, **kwargs):
self.data = {}
for key, value in kwargs.items():
self.data[key] = value
def __getattr__(self, key):
if key not in self.data:
raise AttributeError("'HParams' object has no attribute %s" % key)
return self.data[key]
def set_hparam(self, key, value):
self.data[key] = value
# Default hyperparameters
hparams = HParams(
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
# network
rescale=True, # Whether to rescale audio prior to preprocessing
rescaling_max=0.9, # Rescaling value
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
# Does not work if n_ffit is not multiple of hop_size!!
use_lws=False,
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
# Mel and Linear spectrograms normalization/scaling and clipping
signal_normalization=True,
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
symmetric_mels=True,
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
# faster and cleaner convergence)
max_abs_value=4.,
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
# be too big to avoid gradient explosion,
# not too small for fast convergence)
# Contribution by @begeekmyfriend
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
# levels. Also allows for better G&L phase reconstruction)
preemphasize=True, # whether to apply filter
preemphasis=0.97, # filter coefficient.
# Limits
min_level_db=-100,
ref_level_db=20,
fmin=55,
# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
fmax=7600, # To be increased/reduced depending on data.
###################### Our training parameters #################################
img_size=96,
fps=25,
batch_size=16,
initial_learning_rate=1e-4,
nepochs=300000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
num_workers=20,
checkpoint_interval=3000,
eval_interval=3000,
writer_interval=300,
save_optimizer_state=True,
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
syncnet_batch_size=64,
syncnet_lr=1e-4,
syncnet_eval_interval=1000,
syncnet_checkpoint_interval=10000,
disc_wt=0.07,
disc_initial_learning_rate=1e-4,
)
# Default hyperparameters
hparamsdebug = HParams(
num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality
# network
rescale=True, # Whether to rescale audio prior to preprocessing
rescaling_max=0.9, # Rescaling value
# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
# Does not work if n_ffit is not multiple of hop_size!!
use_lws=False,
n_fft=800, # Extra window size is filled with 0 paddings to match this parameter
hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i <filename>)
frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5)
# Mel and Linear spectrograms normalization/scaling and clipping
signal_normalization=True,
# Whether to normalize mel spectrograms to some predefined range (following below parameters)
allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True
symmetric_mels=True,
# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2,
# faster and cleaner convergence)
max_abs_value=4.,
# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not
# be too big to avoid gradient explosion,
# not too small for fast convergence)
# Contribution by @begeekmyfriend
# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude
# levels. Also allows for better G&L phase reconstruction)
preemphasize=True, # whether to apply filter
preemphasis=0.97, # filter coefficient.
# Limits
min_level_db=-100,
ref_level_db=20,
fmin=55,
# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To
# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
fmax=7600, # To be increased/reduced depending on data.
###################### Our training parameters #################################
img_size=96,
fps=25,
batch_size=2,
initial_learning_rate=1e-3,
nepochs=100000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
num_workers=0,
checkpoint_interval=10000,
eval_interval=10,
writer_interval=5,
save_optimizer_state=True,
syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence.
syncnet_batch_size=64,
syncnet_lr=1e-4,
syncnet_eval_interval=10000,
syncnet_checkpoint_interval=10000,
disc_wt=0.07,
disc_initial_learning_rate=1e-4,
)
def hparams_debug_string():
    values = hparams.data
    hp = ["  %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
    return "Hyperparameters:\n" + "\n".join(hp)
import os
import glob
def init_path(checkpoint_dir, config_dir, size=512, old_version=False, preprocess='crop'):
if old_version:
        #### load all the `pth` checkpoints
sadtalker_paths = {
'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
}
use_safetensor = False
elif len(glob.glob(os.path.join(checkpoint_dir, '*.safetensors'))):
print('using safetensor as default')
sadtalker_paths = {
"checkpoint":os.path.join(checkpoint_dir, 'SadTalker_V0.0.2_'+str(size)+'.safetensors'),
}
use_safetensor = True
else:
print("WARNING: The new version of the model will be updated by safetensor, you may need to download it mannully. We run the old version of the checkpoint this time!")
use_safetensor = False
sadtalker_paths = {
'wav2lip_checkpoint' : os.path.join(checkpoint_dir, 'wav2lip.pth'),
'audio2pose_checkpoint' : os.path.join(checkpoint_dir, 'auido2pose_00140-model.pth'),
'audio2exp_checkpoint' : os.path.join(checkpoint_dir, 'auido2exp_00300-model.pth'),
'free_view_checkpoint' : os.path.join(checkpoint_dir, 'facevid2vid_00189-model.pth.tar'),
'path_of_net_recon_model' : os.path.join(checkpoint_dir, 'epoch_20.pth')
}
sadtalker_paths['dir_of_BFM_fitting'] = os.path.join(config_dir) # , 'BFM_Fitting'
sadtalker_paths['audio2pose_yaml_path'] = os.path.join(config_dir, 'auido2pose.yaml')
sadtalker_paths['audio2exp_yaml_path'] = os.path.join(config_dir, 'auido2exp.yaml')
    sadtalker_paths['use_safetensor'] = use_safetensor
if 'full' in preprocess:
sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar')
sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender_still.yaml')
else:
sadtalker_paths['mappingnet_checkpoint'] = os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar')
sadtalker_paths['facerender_yaml'] = os.path.join(config_dir, 'facerender.yaml')
return sadtalker_paths
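# Usage sketch (illustrative): resolve checkpoint/config paths for the 256-px model;
# falls back to the old pth checkpoints when no .safetensors file is present.
if __name__ == '__main__':
    paths = init_path('checkpoints', 'src/config', size=256, old_version=False, preprocess='crop')
    print(paths['use_safetensor'])
    print(paths['facerender_yaml'])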
import torch
import yaml
import os
import safetensors
from safetensors.torch import save_file
from yacs.config import CfgNode as CN
import sys
sys.path.append('/apdcephfs/private_shadowcun/SadTalker')
from src.face3d.models import networks
from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
from src.facerender.modules.mapping import MappingNet
from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
from src.audio2pose_models.audio2pose import Audio2Pose
from src.audio2exp_models.networks import SimpleWrapperV2
from src.test_audio2coeff import load_cpk
size = 256
############ face vid2vid
config_path = os.path.join('src', 'config', 'facerender.yaml')
current_root_path = '.'
path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')
checkpoint = torch.load(path_of_net_recon_model, map_location='cpu')
net_recon.load_state_dict(checkpoint['net_recon'])
with open(config_path) as f:
config = yaml.safe_load(f)
generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
**config['model_params']['common_params'])
mapping = MappingNet(**config['model_params']['mapping_params'])
def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None,
kp_detector=None, he_estimator=None, optimizer_generator=None,
optimizer_discriminator=None, optimizer_kp_detector=None,
optimizer_he_estimator=None, device="cpu"):
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
if generator is not None:
generator.load_state_dict(checkpoint['generator'])
if kp_detector is not None:
kp_detector.load_state_dict(checkpoint['kp_detector'])
if he_estimator is not None:
he_estimator.load_state_dict(checkpoint['he_estimator'])
if discriminator is not None:
        try:
            discriminator.load_state_dict(checkpoint['discriminator'])
        except Exception:
            print('No discriminator in the state-dict. Discriminator will be randomly initialized')
if optimizer_generator is not None:
optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
if optimizer_discriminator is not None:
try:
optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
except RuntimeError as e:
            print('No discriminator optimizer in the state-dict. Optimizer will not be initialized')
if optimizer_kp_detector is not None:
optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
if optimizer_he_estimator is not None:
optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
return checkpoint['epoch']
def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
kp_detector=None, he_estimator=None,
device="cpu"):
checkpoint = safetensors.torch.load_file(checkpoint_path)
if generator is not None:
x_generator = {}
for k,v in checkpoint.items():
if 'generator' in k:
x_generator[k.replace('generator.', '')] = v
generator.load_state_dict(x_generator)
if kp_detector is not None:
x_generator = {}
for k,v in checkpoint.items():
if 'kp_extractor' in k:
x_generator[k.replace('kp_extractor.', '')] = v
kp_detector.load_state_dict(x_generator)
if he_estimator is not None:
x_generator = {}
for k,v in checkpoint.items():
if 'he_estimator' in k:
x_generator[k.replace('he_estimator.', '')] = v
he_estimator.load_state_dict(x_generator)
return None
free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'
load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')
audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
fcfg_pose = open(audio2pose_yaml_path)
cfg_pose = CN.load_cfg(fcfg_pose)
cfg_pose.freeze()
audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)
audio2pose_model.eval()
load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')
# load audio2exp_model
netG = SimpleWrapperV2()
netG.eval()
load_cpk(audio2exp_checkpoint, model=netG, device='cpu')
class SadTalker(torch.nn.Module):
def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):
super(SadTalker, self).__init__()
self.kp_extractor = kp_extractor
self.generator = generator
self.audio2exp = netG
self.audio2pose = audio2pose
self.face_3drecon = face_3drecon
model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)
# here, we want to convert it to safetensor
save_file(model.state_dict(), "checkpoints/SadTalker_V0.0.2_"+str(size)+".safetensors")
### test
load_cpk_facevid2vid_safetensor('checkpoints/SadTalker_V0.0.2_'+str(size)+'.safetensors', kp_detector=kp_extractor, generator=generator, he_estimator=None)