Commit c36d19db authored by mashun1's avatar mashun1
Browse files

liveportrait

parents
Pipeline #1402 canceled with stages
# coding: utf-8
"""
Parameters used for cropping faces
"""
from dataclasses import dataclass
from .base_config import PrintableConfig
@dataclass(repr=False)  # the repr comes from PrintableConfig
class CropConfig(PrintableConfig):
    # Settings consumed by the face cropper, for both the source input and
    # the driving video.

    # ---------- model files and device ----------
    insightface_root: str = "../../pretrained_weights/insightface"
    landmark_ckpt_path: str = "../../pretrained_weights/liveportrait/landmark.onnx"
    device_id: int = 0  # index of the GPU device
    flag_force_cpu: bool = False  # force CPU inference (marked WIP by the authors)

    # ---------- cropping of the source image or video ----------
    dsize: int = 512  # crop size
    scale: float = 2.8  # scale factor
    vx_ratio: float = 0  # vx ratio
    vy_ratio: float = -0.125  # vy ratio: positive moves up, negative moves down
    max_face_num: int = 0  # maximum number of faces; 0 means no limit
    flag_do_rot: bool = True  # whether to apply the rotation when flag_do_crop is True

    # ---------- automatic cropping of the driving video ----------
    scale_crop_driving_video: float = 2.2  # scale factor for cropping the driving video (original comment also mentions 2.0)
    # NOTE(review): the original comments described vx as "adjust y offset" and
    # vy as "adjust x offset", which looks swapped relative to the names — confirm.
    vx_ratio_crop_driving_video: float = 0.0  # offset ratio for the driving-video crop
    vy_ratio_crop_driving_video: float = -0.1  # offset ratio for the driving-video crop
    direction: str = "large-small"  # direction of cropping
\ No newline at end of file
# coding: utf-8
"""
config dataclass used for inference
"""
import cv2
from numpy import ndarray
from dataclasses import dataclass, field
from typing import Literal, Tuple
from .base_config import PrintableConfig, make_abs_path
@dataclass(repr=False)  # the repr comes from PrintableConfig
class InferenceConfig(PrintableConfig):
    # Settings consumed by the inference wrapper. Field order is kept as-is
    # because dataclass field order is part of the constructor signature.

    # ---------- model config (not exported) ----------
    models_config: str = make_abs_path('./models.yaml')  # portrait animation config
    checkpoint_F: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/appearance_feature_extractor.pth')  # checkpoint of F
    checkpoint_M: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/motion_extractor.pth')  # checkpoint of M
    checkpoint_G: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/spade_generator.pth')  # checkpoint of G
    checkpoint_W: str = make_abs_path('../../pretrained_weights/liveportrait/base_models/warping_module.pth')  # checkpoint of W
    checkpoint_S: str = make_abs_path('../../pretrained_weights/liveportrait/retargeting_models/stitching_retargeting_module.pth')  # checkpoint of S and R_eyes, R_lip

    # ---------- exported params ----------
    flag_use_half_precision: bool = True
    flag_crop_driving_video: bool = False
    device_id: int = 0
    flag_normalize_lip: bool = True
    flag_source_video_eye_retargeting: bool = False
    flag_video_editing_head_rotation: bool = False
    flag_eye_retargeting: bool = False
    flag_lip_retargeting: bool = False
    flag_stitching: bool = True
    flag_relative_motion: bool = True
    flag_pasteback: bool = True
    flag_do_crop: bool = True
    flag_do_rot: bool = True
    flag_force_cpu: bool = False
    flag_do_torch_compile: bool = False

    # ---------- not exported params ----------
    lip_normalize_threshold: float = 0.03  # threshold used with flag_normalize_lip
    source_video_eye_retargeting_threshold: float = 0.18  # eye-retargeting threshold when the input is a source video
    # Smoothing strength for the animated video when the input is a source
    # video: larger means smoother, but too much smoothing loses motion accuracy.
    driving_smooth_observation_variance: float = 1e-7
    anchor_frame: int = 0  # TO IMPLEMENT
    input_shape: Tuple[int, int] = (256, 256)  # input shape
    output_format: Literal['mp4', 'gif'] = 'mp4'  # output video format
    crf: int = 15  # crf for the output video
    output_fps: int = 25  # default output fps
    # The mask template is read from disk each time a config instance is created.
    mask_crop: ndarray = field(default_factory=lambda: cv2.imread(make_abs_path('../utils/resources/mask_template.png'), cv2.IMREAD_COLOR))
    size_gif: int = 256  # default gif size, TO IMPLEMENT
    source_max_dim: int = 1280  # maximum height/width of the source image or video
    source_division: int = 2  # height and width of the source must be divisible by this number
\ No newline at end of file
model_params:
appearance_feature_extractor_params: # the F in the paper
image_channel: 3
block_expansion: 64
num_down_blocks: 2
max_features: 512
reshape_channel: 32
reshape_depth: 16
num_resblocks: 6
motion_extractor_params: # the M in the paper
num_kp: 21
backbone: convnextv2_tiny
warping_module_params: # the W in the paper
num_kp: 21
block_expansion: 64
max_features: 512
num_down_blocks: 2
reshape_channel: 32
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
reshape_depth: 16
compress: 4
spade_generator_params: # the G in the paper
upscale: 2 # represents upsample factor 256x256 -> 512x512
block_expansion: 64
max_features: 512
num_down_blocks: 2
stitching_retargeting_module_params: # the S in the paper
stitching:
input_size: 126 # (21*3)*2
hidden_sizes: [128, 128, 64]
output_size: 65 # (21*3)+2(tx,ty)
lip:
input_size: 65 # (21*3)+2
hidden_sizes: [128, 128, 64]
output_size: 63 # (21*3)
eye:
input_size: 66 # (21*3)+3
hidden_sizes: [256, 256, 128, 128, 64]
output_size: 63 # (21*3)
# coding: utf-8
"""
Pipeline for gradio
"""
import os.path as osp
import gradio as gr
from .config.argument_config import ArgumentConfig
from .live_portrait_pipeline import LivePortraitPipeline
from .utils.io import load_img_online
from .utils.rprint import rlog as log
from .utils.crop import prepare_paste_back, paste_back
from .utils.camera import get_rotation_matrix
from .utils.helper import is_square_video
def update_args(args, user_args):
    """Overwrite attributes of ``args`` with the matching entries of ``user_args``.

    Only keys that already exist as attributes on ``args`` are applied; any
    other key is silently ignored. ``args`` is modified in place and returned.
    """
    recognised = (
        (name, value) for name, value in user_args.items() if hasattr(args, name)
    )
    for name, value in recognised:
        setattr(args, name, value)
    return args
class GradioPipeline(LivePortraitPipeline):
    """LivePortrait pipeline with entry points shaped for Gradio callbacks."""

    def __init__(self, inference_cfg, crop_cfg, args: ArgumentConfig):
        super().__init__(inference_cfg, crop_cfg)
        # Argument object that each request's user inputs are merged into.
        self.args = args

    def execute_video(
        self,
        input_source_image_path=None,
        input_source_video_path=None,
        input_driving_video_path=None,
        flag_relative_input=True,
        flag_do_crop_input=True,
        flag_remap_input=True,
        flag_crop_driving_video_input=True,
        flag_video_editing_head_rotation=False,
        scale=2.3,
        vx_ratio=0.0,
        vy_ratio=-0.125,
        scale_crop_driving_video=2.2,
        vx_ratio_crop_driving_video=0.0,
        vy_ratio_crop_driving_video=-0.1,
        driving_smooth_observation_variance=1e-7,
        tab_selection=None,
    ):
        """Run video-driven portrait animation or video editing.

        The source is the video input when ``tab_selection`` is ``'Video'``;
        for ``'Image'`` and any other value (including ``None``) it is the
        image input.

        Returns:
            The pair ``(video_path, video_path_concat)`` produced by ``self.execute``.

        Raises:
            gr.Error: if the selected source or the driving video is missing.
        """
        if tab_selection == 'Video':
            input_source_path = input_source_video_path
        else:
            input_source_path = input_source_image_path

        if input_source_path is None or input_driving_video_path is None:
            raise gr.Error("Please upload the source portrait or source video, and driving video 🤗🤗🤗", duration=5)

        # A non-square driving video is always cropped, whatever the user chose.
        # The message used to say "source video", but the check is on the driving video.
        if osp.exists(input_driving_video_path) and is_square_video(input_driving_video_path) is False:
            flag_crop_driving_video_input = True
            log("The driving video is not square, it will be cropped to square automatically.")
            gr.Info("The driving video is not square, it will be cropped to square automatically.", duration=2)

        args_user = {
            'source': input_source_path,
            'driving': input_driving_video_path,
            'flag_relative_motion': flag_relative_input,
            'flag_do_crop': flag_do_crop_input,
            'flag_pasteback': flag_remap_input,
            'flag_crop_driving_video': flag_crop_driving_video_input,
            'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
            'scale': scale,
            'vx_ratio': vx_ratio,
            'vy_ratio': vy_ratio,
            'scale_crop_driving_video': scale_crop_driving_video,
            'vx_ratio_crop_driving_video': vx_ratio_crop_driving_video,
            'vy_ratio_crop_driving_video': vy_ratio_crop_driving_video,
            'driving_smooth_observation_variance': driving_smooth_observation_variance,
        }
        # Merge the user input into the stored args, then push it to the
        # wrapper and the cropper so all three agree.
        self.args = update_args(self.args, args_user)
        self.live_portrait_wrapper.update_config(self.args.__dict__)
        self.cropper.update_config(self.args.__dict__)

        # video driven animation
        video_path, video_path_concat = self.execute(self.args)
        gr.Info("Run successfully!", duration=2)
        return video_path, video_path_concat

    def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
        """Retarget the eyes and lips of a single image.

        Returns:
            ``(out, out_to_ori_blend)``: the generated crop and the same result
            pasted back onto the original image.

        Raises:
            gr.Error: if the image is missing (raised by ``prepare_retargeting``)
                or if either ratio is ``None``.
        """
        # Features are recomputed on every call (the original marks them "disposable").
        f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
            self.prepare_retargeting(input_image, flag_do_crop)

        if input_eye_ratio is None or input_lip_ratio is None:
            raise gr.Error("Invalid ratio input 💥!", duration=5)

        wrapper = self.live_portrait_wrapper
        x_s_user = x_s_user.to(wrapper.device)
        f_s_user = f_s_user.to(wrapper.device)
        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
        combined_eye_ratio_tensor = wrapper.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
        eyes_delta = wrapper.retarget_eye(x_s_user, combined_eye_ratio_tensor)
        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
        combined_lip_ratio_tensor = wrapper.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
        lip_delta = wrapper.retarget_lip(x_s_user, combined_lip_ratio_tensor)
        # default: start from x_s and add both deltas
        x_d_new = x_s_user + eyes_delta + lip_delta
        # D(W(f_s; x_s, x′_d))
        out = wrapper.warp_decode(f_s_user, x_s_user, x_d_new)
        out = wrapper.parse_output(out['out'])[0]
        out_to_ori_blend = paste_back(out, crop_M_c2o, img_rgb, mask_ori)
        gr.Info("Run successfully!", duration=2)
        return out, out_to_ori_blend

    def prepare_retargeting(self, input_image, flag_do_crop=True):
        """Load a source image and compute what single-image retargeting needs.

        Returns:
            ``(f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb)``.

        Raises:
            gr.Error: if ``input_image`` is ``None`` (e.g. after the clear button).
        """
        if input_image is None:
            # when press the clear button, go here
            raise gr.Error("Please upload a source portrait as the retargeting input 🤗🤗🤗", duration=5)

        wrapper = self.live_portrait_wrapper
        inference_cfg = wrapper.inference_cfg

        ######## process source portrait ########
        img_rgb = load_img_online(input_image, mode='rgb', max_dim=1280, n=16)
        log(f"Load source image from {input_image}.")
        # The crop is always computed: its landmarks and M_c2o are needed even
        # when the uncropped image is fed to the network.
        crop_info = self.cropper.crop_source_image(img_rgb, self.cropper.crop_cfg)
        if flag_do_crop:
            I_s = wrapper.prepare_source(crop_info['img_crop_256x256'])
        else:
            I_s = wrapper.prepare_source(img_rgb)
        x_s_info = wrapper.get_kp_info(I_s)
        ############################################
        f_s_user = wrapper.extract_feature_3d(I_s)
        x_s_user = wrapper.transform_keypoint(x_s_info)
        source_lmk_user = crop_info['lmk_crop']
        crop_M_c2o = crop_info['M_c2o']
        mask_ori = prepare_paste_back(inference_cfg.mask_crop, crop_info['M_c2o'], dsize=(img_rgb.shape[1], img_rgb.shape[0]))
        return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
\ No newline at end of file
This diff is collapsed.
# coding: utf-8
"""
Wrapper for LivePortrait core functions
"""
import contextlib
import os.path as osp
import numpy as np
import cv2
import torch
import yaml
from .utils.timer import Timer
from .utils.helper import load_model, concat_feat
from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from .config.inference_config import InferenceConfig
from .utils.rprint import rlog as log
class LivePortraitWrapper(object):
    """Loads the LivePortrait sub-networks and exposes their inference steps.

    The wrapper holds the appearance feature extractor (F), motion extractor
    (M), warping module (W), SPADE generator (G) and, when its checkpoint
    exists, the stitching/retargeting module (S and R).
    """

    def __init__(self, inference_cfg: "InferenceConfig"):
        self.inference_cfg = inference_cfg
        self.device_id = inference_cfg.device_id
        self.compile = inference_cfg.flag_do_torch_compile

        # Device priority: forced CPU, then MPS when available, otherwise CUDA.
        if inference_cfg.flag_force_cpu:
            self.device = 'cpu'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.device = 'cuda:' + str(self.device_id)

        # Use a context manager so the config file handle is always closed.
        with open(inference_cfg.models_config, 'r') as cfg_file:
            model_config = yaml.load(cfg_file, Loader=yaml.SafeLoader)

        # init F
        self.appearance_feature_extractor = load_model(inference_cfg.checkpoint_F, model_config, self.device, 'appearance_feature_extractor')
        log('Load appearance_feature_extractor done.')
        # init M
        self.motion_extractor = load_model(inference_cfg.checkpoint_M, model_config, self.device, 'motion_extractor')
        log('Load motion_extractor done.')
        # init W
        self.warping_module = load_model(inference_cfg.checkpoint_W, model_config, self.device, 'warping_module')
        log('Load warping_module done.')
        # init G
        self.spade_generator = load_model(inference_cfg.checkpoint_G, model_config, self.device, 'spade_generator')
        log('Load spade_generator done.')
        # init S and R (optional: skipped when the checkpoint is missing)
        if inference_cfg.checkpoint_S is not None and osp.exists(inference_cfg.checkpoint_S):
            self.stitching_retargeting_module = load_model(inference_cfg.checkpoint_S, model_config, self.device, 'stitching_retargeting_module')
            log('Load stitching_retargeting_module done.')
        else:
            self.stitching_retargeting_module = None

        # Optimize for inference
        if self.compile:
            torch._dynamo.config.suppress_errors = True  # Suppress errors and fall back to eager execution
            self.warping_module = torch.compile(self.warping_module, mode='max-autotune')
            self.spade_generator = torch.compile(self.spade_generator, mode='max-autotune')

        self.timer = Timer()

    def inference_ctx(self):
        """Return the context manager used around network forward passes.

        On MPS this is a no-op context; otherwise it is a float16 autocast,
        enabled according to ``flag_use_half_precision``.
        """
        if self.device == "mps":
            return contextlib.nullcontext()
        # self.device[:4] maps 'cuda:N' to 'cuda' and leaves 'cpu' unchanged.
        return torch.autocast(device_type=self.device[:4], dtype=torch.float16,
                              enabled=self.inference_cfg.flag_use_half_precision)

    def update_config(self, user_args):
        """Copy entries of ``user_args`` onto the inference config.

        Keys that are not existing attributes of the config are ignored.
        """
        for k, v in user_args.items():
            if hasattr(self.inference_cfg, k):
                setattr(self.inference_cfg, k, v)

    def prepare_source(self, img: np.ndarray) -> torch.Tensor:
        """Convert an image (or batch of images) into the network input tensor.

        img: HxWx3 or BxHxWx3, uint8. Images whose height/width differ from
        ``inference_cfg.input_shape`` (height, width) are resized.
        return: Bx3xHxW float tensor in [0, 1] on ``self.device``.

        Raises:
            ValueError: if ``img`` is neither 3- nor 4-dimensional.
        """
        if img.ndim not in (3, 4):
            raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')

        target_h, target_w = self.inference_cfg.input_shape[0], self.inference_cfg.input_shape[1]
        # Take H and W from the trailing axes so the batched layout is handled too.
        h, w = img.shape[-3:-1]
        if h != target_h or w != target_w:
            # cv2.resize expects dsize as (width, height).
            if img.ndim == 3:
                x = cv2.resize(img, (target_w, target_h))
            else:
                x = np.stack([cv2.resize(frame, (target_w, target_h)) for frame in img])
        else:
            x = img.copy()

        if x.ndim == 3:
            x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
        elif x.ndim == 4:
            x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
        else:
            # cv2.resize can drop a trailing singleton channel axis
            raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
        x = np.clip(x, 0, 1)  # clip to 0~1
        x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
        return x.to(self.device)

    def prepare_videos(self, imgs) -> torch.Tensor:
        """Convert video frames into the network input tensor.

        imgs: a list of HxWx3 frames, or an ndarray that already has the
        TxHxWx3x1 layout (the ndarray is used as-is).
        return: Tx1x3xHxW float tensor in [0, 1] on ``self.device``.

        Raises:
            ValueError: if ``imgs`` is neither a list nor an ndarray.
        """
        if isinstance(imgs, list):
            _imgs = np.array(imgs)[..., np.newaxis]  # TxHxWx3x1
        elif isinstance(imgs, np.ndarray):
            _imgs = imgs
        else:
            raise ValueError(f'imgs type error: {type(imgs)}')

        y = _imgs.astype(np.float32) / 255.
        y = np.clip(y, 0, 1)  # clip to 0~1
        y = torch.from_numpy(y).permute(0, 4, 3, 1, 2)  # TxHxWx3x1 -> Tx1x3xHxW
        return y.to(self.device)

    def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
        """Get the appearance feature volume of the image from F.

        x: Bx3xHxW, normalized to 0~1
        return: the feature volume cast to float32
        """
        with torch.no_grad(), self.inference_ctx():
            feature_3d = self.appearance_feature_extractor(x)
        return feature_3d.float()

    def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
        """Get the implicit keypoint information from M.

        x: Bx3xHxW, normalized to 0~1
        flag_refine_info (keyword, default True): whether to convert the pose
            predictions to degrees and reshape 'kp' and 'exp' to BxNx3
        return: a dict with keys 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
        """
        with torch.no_grad(), self.inference_ctx():
            kp_info = self.motion_extractor(x)

        if self.inference_cfg.flag_use_half_precision:
            # cast tensor entries back to float32
            for k, v in kp_info.items():
                if isinstance(v, torch.Tensor):
                    kp_info[k] = v.float()

        flag_refine_info: bool = kwargs.get('flag_refine_info', True)
        if flag_refine_info:
            bs = kp_info['kp'].shape[0]
            kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
            kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
            kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
            kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
            kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3
        return kp_info

    def get_pose_dct(self, kp_info: dict) -> dict:
        """Return pitch, yaw and roll as Python scalars (single-sample input)."""
        pose_dct = dict(
            pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
            yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
            roll=headpose_pred_to_degree(kp_info['roll']).item(),
        )
        return pose_dct

    def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
        """Compute keypoint info and rotation for source and first driving frame, plus the source features.

        return: (source_kp_info, source_rotation, source_feature_3d,
                 driving_first_frame_kp_info, driving_first_frame_rotation)
        """
        # get the canonical keypoints of source image by M
        source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
        source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])

        # get the canonical keypoints of first driving frame by M
        driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
        driving_first_frame_rotation = get_rotation_matrix(
            driving_first_frame_kp_info['pitch'],
            driving_first_frame_kp_info['yaw'],
            driving_first_frame_kp_info['roll']
        )

        # get feature volume by F
        source_feature_3d = self.extract_feature_3d(source_prepared)
        return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation

    def transform_keypoint(self, kp_info: dict):
        """Transform the implicit keypoints with pose, shift and expression deformation.

        kp_info['kp']: BxNx3 or Bx(N*3)
        return: BxNx3 transformed keypoints
        """
        kp = kp_info['kp']  # (bs, k, 3)
        pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']
        t, exp = kp_info['t'], kp_info['exp']
        scale = kp_info['scale']

        pitch = headpose_pred_to_degree(pitch)
        yaw = headpose_pred_to_degree(yaw)
        roll = headpose_pred_to_degree(roll)

        bs = kp.shape[0]
        if kp.ndim == 2:
            num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
        else:
            num_kp = kp.shape[1]  # Bxnum_kpx3

        rot_mat = get_rotation_matrix(pitch, yaw, roll)  # (bs, 3, 3)

        # Eqn.2: s * (R * x_c,s + exp) + t
        kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
        kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
        kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty
        return kp_transformed

    def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
        """Predict the eye retargeting delta.

        kp_source: BxNx3
        eye_close_ratio: Bx3
        return: BxNx3 (the module output reshaped to the keypoint layout)
        """
        feat_eye = concat_feat(kp_source, eye_close_ratio)
        with torch.no_grad():
            delta = self.stitching_retargeting_module['eye'](feat_eye)
        return delta.reshape(-1, kp_source.shape[1], 3)

    def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
        """Predict the lip retargeting delta.

        kp_source: BxNx3
        lip_close_ratio: Bx2
        return: BxNx3 (the module output reshaped to the keypoint layout)
        """
        feat_lip = concat_feat(kp_source, lip_close_ratio)
        with torch.no_grad():
            delta = self.stitching_retargeting_module['lip'](feat_lip)
        return delta.reshape(-1, kp_source.shape[1], 3)

    def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        """Run the raw stitching network.

        kp_source: BxNx3
        kp_driving: BxNx3
        return: Bx(3*num_kp+2)
        """
        feat_stiching = concat_feat(kp_source, kp_driving)
        with torch.no_grad():
            delta = self.stitching_retargeting_module['stitching'](feat_stiching)
        return delta

    def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        """Apply stitching to the driving keypoints.

        kp_source: Bxnum_kpx3
        kp_driving: Bxnum_kpx3
        return: the corrected driving keypoints, or ``kp_driving`` itself
            (unchanged) when no stitching module was loaded
        """
        if self.stitching_retargeting_module is None:
            return kp_driving

        bs, num_kp = kp_source.shape[:2]
        kp_driving_new = kp_driving.clone()  # keep the caller's tensor untouched
        delta = self.stitch(kp_source, kp_driving_new)

        delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # bs x num_kp x 3
        delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)  # bs x 1 x 2

        kp_driving_new += delta_exp
        kp_driving_new[..., :2] += delta_tx_ty
        return kp_driving_new

    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict:
        """Warp the feature volume with the implicit keypoints and decode the image.

        feature_3d: Bx32x16x64x64, feature volume
        kp_source: BxNx3
        kp_driving: BxNx3
        return: the warping module's output dict with 'out' replaced by the decoded image
        """
        # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
        with torch.no_grad(), self.inference_ctx():
            if self.compile:
                # Mark the beginning of a new CUDA Graph step
                torch.compiler.cudagraph_mark_step_begin()
            # get decoder input
            ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
            # decode
            ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])

        # cast tensor entries back to float32
        if self.inference_cfg.flag_use_half_precision:
            for k, v in ret_dct.items():
                if isinstance(v, torch.Tensor):
                    ret_dct[k] = v.float()
        return ret_dct

    def parse_output(self, out: torch.Tensor) -> np.ndarray:
        """Convert a network output tensor into uint8 images.

        out: Bx3xHxW
        return: BxHxWx3, uint8
        """
        out = np.transpose(out.data.cpu().numpy(), [0, 2, 3, 1])  # Bx3xHxW -> BxHxWx3
        out = np.clip(out, 0, 1)  # clip to 0~1
        out = np.clip(out * 255, 0, 255).astype(np.uint8)  # 0~1 -> 0~255
        return out

    def calc_ratio(self, lmk_lst):
        """Compute per-frame eye and lip close ratios for a list of landmarks.

        return: (input_eye_ratio_lst, input_lip_ratio_lst)
        """
        input_eye_ratio_lst = []
        input_lip_ratio_lst = []
        for lmk in lmk_lst:
            # for eyes retargeting
            input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
            # for lip retargeting
            input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
        return input_eye_ratio_lst, input_lip_ratio_lst

    def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
        """Concatenate the source eye ratio with the driving value ``c_d_eyes_i[0][0]``."""
        c_s_eyes = calc_eye_close_ratio(source_lmk[None])
        c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
        c_d_eyes_i_tensor = torch.Tensor([c_d_eyes_i[0][0]]).reshape(1, 1).to(self.device)
        # [c_s,eyes, c_d,eyes,i]
        combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
        return combined_eye_ratio_tensor

    def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
        """Concatenate the source lip ratio with the driving value ``c_d_lip_i[0]``."""
        c_s_lip = calc_lip_close_ratio(source_lmk[None])
        c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
        c_d_lip_i_tensor = torch.Tensor([c_d_lip_i[0]]).to(self.device).reshape(1, 1)  # 1x1
        # [c_s,lip, c_d,lip,i]
        combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1)  # 1x2
        return combined_lip_ratio_tensor
\ No newline at end of file
# coding: utf-8
"""
Appearance extractor(F) defined in paper, which maps the source image s to a 3D appearance feature volume.
"""
import torch
from torch import nn
from .util import SameBlock2d, DownBlock2d, ResBlock3d
class AppearanceFeatureExtractor(nn.Module):
    """Appearance extractor (F): maps a source image to a 3D feature volume."""

    def __init__(self, image_channel, block_expansion, num_down_blocks, max_features, reshape_channel, reshape_depth, num_resblocks):
        super(AppearanceFeatureExtractor, self).__init__()
        self.image_channel = image_channel
        self.block_expansion = block_expansion
        self.num_down_blocks = num_down_blocks
        self.max_features = max_features
        self.reshape_channel = reshape_channel
        self.reshape_depth = reshape_depth

        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))

        down_blocks = []
        # Start from the channel count leaving `first`, so `second` can still
        # be built when num_down_blocks is 0.
        out_features = block_expansion
        for i in range(num_down_blocks):
            in_features = min(max_features, block_expansion * (2 ** i))
            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.down_blocks = nn.ModuleList(down_blocks)

        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)

        self.resblocks_3d = torch.nn.Sequential()
        for i in range(num_resblocks):
            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))

    def forward(self, source_image):
        """Return the 3D feature volume for ``source_image``.

        The 2D feature map's channel axis is split into
        (reshape_channel, reshape_depth) before the 3D residual blocks.
        Shape comments below are those given by the original authors
        (presumably for the shipped config).
        """
        out = self.first(source_image)  # Bx3x256x256 -> Bx64x256x256
        for down_block in self.down_blocks:
            out = down_block(out)
        out = self.second(out)
        bs, c, h, w = out.shape  # ->Bx512x64x64

        f_s = out.view(bs, self.reshape_channel, self.reshape_depth, h, w)  # ->Bx32x16x64x64
        f_s = self.resblocks_3d(f_s)  # ->Bx32x16x64x64
        return f_s
# coding: utf-8
"""
This module is the ConvNeXtV2 backbone adapted for the extraction of implicit keypoints, poses, and expression deformation.
"""
import torch
import torch.nn as nn
# from timm.models.layers import trunc_normal_, DropPath
from .util import LayerNorm, DropPath, trunc_normal_, GRN
# Names exported by `from <module> import *`: only the tiny-backbone factory.
__all__ = ['convnextv2_tiny']
class Block(nn.Module):
    """ ConvNeXtV2 residual block.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block and add the result to the unmodified input (skip connection)."""
        shortcut = x
        branch = self.dwconv(x)
        branch = branch.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        # channels-last section: norm -> expand -> activation -> GRN -> project
        for layer in (self.norm, self.pwconv1, self.act, self.grn, self.pwconv2):
            branch = layer(branch)
        branch = branch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(branch)
class ConvNeXtV2(nn.Module):
    """ ConvNeXt V2 backbone with regression heads for keypoints, pose and expression.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (list(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (list(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        **kwargs: ``num_bins`` (int, default 66) sets the size of each pose head;
            ``num_kp`` (int, default 24) sets the number of implicit keypoints.
            Other keys are ignored.
    """

    def __init__(
        self,
        in_chans=3,
        depths=[3, 3, 9, 3],  # default lists are only read, never mutated
        dims=[96, 192, 384, 768],
        drop_path_rate=0.,
        **kwargs
    ):
        super().__init__()
        self.depths = depths
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        # One drop-path rate per block, spaced linearly from 0 to drop_path_rate.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer

        # NOTE: the output semantic items
        num_bins = kwargs.get('num_bins', 66)
        num_kp = kwargs.get('num_kp', 24)  # the number of implicit keypoints
        self.fc_kp = nn.Linear(dims[-1], 3 * num_kp)  # implicit keypoints

        self.fc_scale = nn.Linear(dims[-1], 1)  # scale
        self.fc_pitch = nn.Linear(dims[-1], num_bins)  # pitch bins
        self.fc_yaw = nn.Linear(dims[-1], num_bins)  # yaw bins
        self.fc_roll = nn.Linear(dims[-1], num_bins)  # roll bins
        self.fc_t = nn.Linear(dims[-1], 3)  # translation
        self.fc_exp = nn.Linear(dims[-1], 3 * num_kp)  # expression / delta

    def _init_weights(self, m):
        """Initialise a Conv2d/Linear layer: truncated-normal weights, zero bias.

        Not applied automatically by ``__init__``; intended for ``self.apply``.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers built with bias=False have no bias tensor to reset.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run the downsample/stage pairs and return the pooled, normalised features."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        """Return a dict with keys 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'."""
        x = self.forward_features(x)

        # implicit keypoints
        kp = self.fc_kp(x)

        # pose and expression deformation
        pitch = self.fc_pitch(x)
        yaw = self.fc_yaw(x)
        roll = self.fc_roll(x)
        t = self.fc_t(x)
        exp = self.fc_exp(x)
        scale = self.fc_scale(x)

        ret_dct = {
            'pitch': pitch,
            'yaw': yaw,
            'roll': roll,
            't': t,
            'exp': exp,
            'scale': scale,
            'kp': kp,  # canonical keypoint
        }
        return ret_dct
def convnextv2_tiny(**kwargs):
    """Build a ConvNeXtV2 with depths [3, 3, 9, 3] and dims [96, 192, 384, 768].

    All keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    tiny_depths = [3, 3, 9, 3]
    tiny_dims = [96, 192, 384, 768]
    return ConvNeXtV2(depths=tiny_depths, dims=tiny_dims, **kwargs)
# coding: utf-8
"""
The module that predicts a dense motion from the sparse motion representation given by kp_source and kp_driving
"""
from torch import nn
import torch.nn.functional as F
import torch
from .util import Hourglass, make_coordinate_grid, kp2gaussian
class DenseMotionNetwork(nn.Module):
    """Predicts a dense deformation field (and optionally an occlusion map) from sparse keypoint motion."""

    def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress, estimate_occlusion_map=True):
        super(DenseMotionNetwork, self).__init__()
        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)  # ~60+G

        self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)  # 65G! NOTE: computation cost is large
        self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)  # 0.8G
        self.norm = nn.BatchNorm3d(compress, affine=True)
        self.num_kp = num_kp
        self.flag_estimate_occlusion_map = estimate_occlusion_map

        if self.flag_estimate_occlusion_map:
            self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
        else:
            self.occlusion = None

    def create_sparse_motions(self, feature, kp_driving, kp_source):
        """Build one translation field per keypoint plus an identity (background) field.

        return: (bs, 1+num_kp, d, h, w, 3)
        """
        bs, _, d, h, w = feature.shape  # (bs, 4, 16, 64, 64)
        identity_grid = make_coordinate_grid((d, h, w), ref=kp_source)  # (16, 64, 64, 3)
        identity_grid = identity_grid.view(1, 1, d, h, w, 3)  # (1, 1, d=16, h=64, w=64, 3)
        coordinate_grid = identity_grid - kp_driving.view(bs, self.num_kp, 1, 1, 1, 3)

        # NOTE: there lacks an one-order flow
        driving_to_source = coordinate_grid + kp_source.view(bs, self.num_kp, 1, 1, 1, 3)  # (bs, num_kp, d, h, w, 3)

        # adding background feature
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)  # (bs, 1+num_kp, d, h, w, 3)
        return sparse_motions

    def create_deformed_feature(self, feature, sparse_motions):
        """Warp ``feature`` once per sparse motion field.

        feature: (bs, c, d, h, w)
        sparse_motions: (bs, num_kp+1, d, h, w, 3)
        return: (bs, num_kp+1, c, d, h, w)
        """
        bs, _, d, h, w = feature.shape
        n_fields = self.num_kp + 1
        feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, n_fields, 1, 1, 1, 1, 1)  # (bs, num_kp+1, 1, c, d, h, w)
        feature_repeat = feature_repeat.view(bs * n_fields, -1, d, h, w)  # (bs*(num_kp+1), c, d, h, w)
        sparse_motions = sparse_motions.view((bs * n_fields, d, h, w, -1))  # (bs*(num_kp+1), d, h, w, 3)
        sparse_deformed = F.grid_sample(feature_repeat, sparse_motions, align_corners=False)
        sparse_deformed = sparse_deformed.view((bs, n_fields, -1, d, h, w))  # (bs, num_kp+1, c, d, h, w)
        return sparse_deformed

    def create_heatmap_representations(self, feature, kp_driving, kp_source):
        """Return the driving-minus-source keypoint heatmaps with a zero background channel.

        feature: the deformed feature, (bs, 1+num_kp, c, d, h, w)
        return: (bs, 1+num_kp, 1, d, h, w)
        """
        spatial_size = feature.shape[3:]  # (d=16, h=64, w=64)
        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)  # (bs, num_kp, d, h, w)
        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)  # (bs, num_kp, d, h, w)
        heatmap = gaussian_driving - gaussian_source  # (bs, num_kp, d, h, w)

        # adding background feature; allocate it directly with the heatmap's dtype and device
        zeros = heatmap.new_zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2])
        heatmap = torch.cat([zeros, heatmap], dim=1)
        heatmap = heatmap.unsqueeze(2)  # (bs, 1+num_kp, 1, d, h, w)
        return heatmap

    def forward(self, feature, kp_driving, kp_source):
        """Return a dict with 'mask', 'deformation' and, when enabled, 'occlusion_map'."""
        bs, _, d, h, w = feature.shape  # (bs, 32, 16, 64, 64)

        feature = self.compress(feature)  # (bs, 4, 16, 64, 64)
        feature = self.norm(feature)  # (bs, 4, 16, 64, 64)
        feature = F.relu(feature)  # (bs, 4, 16, 64, 64)

        out_dict = dict()

        # 1. deform 3d feature
        sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)  # (bs, 1+num_kp, d, h, w, 3)
        deformed_feature = self.create_deformed_feature(feature, sparse_motion)  # (bs, 1+num_kp, c=4, d=16, h=64, w=64)

        # 2. (bs, 1+num_kp, d, h, w)
        heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)  # (bs, 1+num_kp, 1, d, h, w)

        hourglass_in = torch.cat([heatmap, deformed_feature], dim=2)  # (bs, 1+num_kp, c=5, d=16, h=64, w=64)
        hourglass_in = hourglass_in.view(bs, -1, d, h, w)  # (bs, (1+num_kp)*c=105, d=16, h=64, w=64)

        prediction = self.hourglass(hourglass_in)

        mask = self.mask(prediction)
        mask = F.softmax(mask, dim=1)  # (bs, 1+num_kp, d=16, h=64, w=64)
        out_dict['mask'] = mask
        mask = mask.unsqueeze(2)  # (bs, num_kp+1, 1, d, h, w)
        sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4)  # (bs, num_kp+1, 3, d, h, w)
        deformation = (sparse_motion * mask).sum(dim=1)  # (bs, 3, d, h, w) mask take effect in this place
        deformation = deformation.permute(0, 2, 3, 4, 1)  # (bs, d, h, w, 3)

        out_dict['deformation'] = deformation

        if self.flag_estimate_occlusion_map:
            # Fold the depth axis into channels, using the prediction's own shape.
            pred_bs, _, _, pred_h, pred_w = prediction.shape
            prediction_reshape = prediction.view(pred_bs, -1, pred_h, pred_w)
            occlusion_map = torch.sigmoid(self.occlusion(prediction_reshape))  # Bx1x64x64
            out_dict['occlusion_map'] = occlusion_map

        return out_dict
# coding: utf-8
"""
Motion extractor(M), which directly predicts the canonical keypoints, head pose and expression deformation of the input image
"""
from torch import nn
import torch
from .convnextv2 import convnextv2_tiny
from .util import filter_state_dict
# Registry mapping a backbone name (the 'backbone' kwarg of MotionExtractor)
# to the factory that builds it.
model_dict = {
    'convnextv2_tiny': convnextv2_tiny,
}
class MotionExtractor(nn.Module):
    """Thin wrapper that builds a backbone from ``model_dict`` and runs it.

    All keyword arguments are forwarded unchanged to the backbone factory;
    the optional ``backbone`` key selects which factory is used.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # 'convnextv2_tiny' is the default (and the only registered) backbone
        backbone = kwargs.get('backbone', 'convnextv2_tiny')
        if backbone not in model_dict:
            # fail with a readable message instead of "'NoneType' object is not callable"
            raise ValueError(
                f"Unknown backbone {backbone!r}; available backbones: {sorted(model_dict)}"
            )
        self.detector = model_dict[backbone](**kwargs)

    def load_pretrained(self, init_path: str):
        """Load backbone weights from ``init_path``; no-op when it is None or ''.

        The checkpoint is expected to hold the weights under the 'model' key.
        Entries whose key contains 'head' are dropped and the rest is loaded
        non-strictly, so missing/unexpected keys are reported, not raised.
        """
        if init_path in (None, ''):
            return

        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        state_dict = torch.load(init_path, map_location=lambda storage, loc: storage)['model']
        state_dict = filter_state_dict(state_dict, remove_name='head')
        ret = self.detector.load_state_dict(state_dict, strict=False)
        print(f'Load pretrained model from {init_path}, ret: {ret}')

    def forward(self, x):
        """Return the backbone output for ``x``."""
        return self.detector(x)
# coding: utf-8
"""
SPADE decoder (G) defined in the paper, which takes the warped feature as input and generates the animated image.
"""
import torch
from torch import nn
import torch.nn.functional as F
from .util import SPADEResnetBlock
class SPADEDecoder(nn.Module):
    """SPADE-based decoder that turns the warped feature map into an image.

    Args:
        upscale: when None or <= 1 the output keeps the decoder resolution;
            otherwise a conv + PixelShuffle(2) head doubles it.
        max_features: upper bound on the internal channel count.
        block_expansion: base channel count.
        out_channels: channels produced by the last SPADE block.
        num_down_blocks: number of down blocks of the matching encoder; the
            input channel count is
            ``min(max_features, block_expansion * 2 ** num_down_blocks)``.

    Raises:
        ValueError: if ``num_down_blocks`` is smaller than 1.
    """

    def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
        # nn.Module must be initialised before attributes are assigned on self
        super().__init__()

        if num_down_blocks < 1:
            raise ValueError(f'num_down_blocks must be >= 1, got {num_down_blocks}')

        self.upscale = upscale

        # same value the original loop ended with (its last iteration)
        input_channels = min(max_features, block_expansion * (2 ** num_down_blocks))

        norm_G = 'spadespectralinstance'
        label_num_channels = input_channels  # 256 with the default arguments

        self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
        self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
        self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
        self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
        self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
        self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
        self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
        self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels)
        self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
        self.up = nn.Upsample(scale_factor=2)

        if self.upscale is None or self.upscale <= 1:
            self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
        else:
            # 3 * (2 * 2) channels are rearranged by PixelShuffle into a 2x larger RGB image
            self.conv_img = nn.Sequential(
                nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
                nn.PixelShuffle(upscale_factor=2)
            )

    def forward(self, feature):
        """Decode ``feature`` into an image with values in (0, 1).

        The input feature is also used as the SPADE conditioning map of every
        block. Shape comments are the example shapes noted by the original
        authors (default arguments).
        """
        seg = feature  # Bx256x64x64
        x = self.fc(feature)  # Bx512x64x64
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)
        x = self.G_middle_2(x, seg)
        x = self.G_middle_3(x, seg)
        x = self.G_middle_4(x, seg)
        x = self.G_middle_5(x, seg)

        x = self.up(x)  # Bx512x64x64 -> Bx512x128x128
        x = self.up_0(x, seg)  # Bx512x128x128 -> Bx256x128x128
        x = self.up(x)  # Bx256x128x128 -> Bx256x256x256
        x = self.up_1(x, seg)  # Bx256x256x256 -> Bx64x256x256

        x = self.conv_img(F.leaky_relu(x, 2e-1))  # Bx64x256x256 -> Bx3xHxW
        x = torch.sigmoid(x)  # Bx3xHxW

        return x
\ No newline at end of file
# coding: utf-8
"""
Stitching module(S) and two retargeting modules(R) defined in the paper.
- The stitching module pastes the animated portrait back into the original image space without pixel misalignment, such as in
the stitching region.
- The eyes retargeting module is designed to address the issue of incomplete eye closure during cross-id reenactment, especially
when a person with small eyes drives a person with larger eyes.
- The lip retargeting module is designed similarly to the eye retargeting module, and can also normalize the input by ensuring that
the lips are in a closed state, which facilitates better animation driving.
"""
from torch import nn
class StitchingRetargetingNetwork(nn.Module):
    """Plain MLP: Linear + ReLU per hidden size, followed by an output Linear.

    Args:
        input_size: number of input features.
        hidden_sizes: iterable of hidden layer widths; may be empty, in which
            case the network is a single ``Linear(input_size, output_size)``.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        layers = []
        in_features = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU(inplace=True))
            in_features = hidden_size
        # in_features is still input_size when no hidden layer was requested
        layers.append(nn.Linear(in_features, output_size))
        self.mlp = nn.Sequential(*layers)

    def initialize_weights_to_zero(self):
        """Set the weight and bias of every Linear layer to zero."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.mlp(x)
# coding: utf-8
"""
This file defines various neural network modules and utility functions, including convolutional and residual blocks,
normalizations, and functions for spatial transformation and tensor manipulation.
"""
from torch import nn
import torch.nn.functional as F
import torch
import torch.nn.utils.spectral_norm as spectral_norm
import math
import warnings
def kp2gaussian(kp, spatial_size, kp_variance):
    """
    Turn keypoints into gaussian-like heatmaps over a 3d coordinate grid.

    kp: tensor whose last dimension holds the 3 coordinates; all leading
        dimensions are kept in the output
    spatial_size: (d, h, w) of the grid
    kp_variance: divisor applied to the squared distance
    return: tensor of shape kp.shape[:-1] + (d, h, w)
    """
    lead = kp.shape[:-1]

    # grid of shape (d, h, w, 3), broadcast to every leading index of kp
    grid = make_coordinate_grid(spatial_size, kp)
    grid = grid.view(*((1,) * len(lead)), *grid.shape)
    grid = grid.repeat(*lead, 1, 1, 1, 1)

    # keypoints reshaped so they line up with the grid's last axis
    centers = kp.view(*lead, 1, 1, 1, 3)

    squared_dist = ((grid - centers) ** 2).sum(-1)
    return torch.exp(-0.5 * squared_dist / kp_variance)
def make_coordinate_grid(spatial_size, ref, **kwargs):
    """
    Build a (d, h, w, 3) grid of (x, y, z) coordinates, each axis spanning [-1, 1].

    spatial_size: (d, h, w)
    ref: tensor whose dtype and device are used for the grid
    NOTE: the axis convention must be right-down-in
    """
    depth, height, width = spatial_size

    def _axis(n):
        # n samples 0..n-1 rescaled to [-1, 1], created with ref's dtype/device
        idx = torch.arange(n).type(ref.dtype).to(ref.device)
        return 2 * (idx / (n - 1)) - 1

    xs = _axis(width).view(1, 1, -1).repeat(depth, height, 1)  # x axis faces to the right
    ys = _axis(height).view(1, -1, 1).repeat(depth, 1, width)  # y axis faces to the bottom
    zs = _axis(depth).view(-1, 1, 1).repeat(1, height, width)  # z axis faces to the inner

    return torch.stack([xs, ys, zs], dim=3)
class ConvT2d(nn.Module):
    """
    Upsampling block for use in decoder: transposed conv -> instance norm -> leaky relu.
    """

    def __init__(self, in_features, out_features, kernel_size=3, stride=2, padding=1, output_padding=1):
        super().__init__()
        self.convT = nn.ConvTranspose2d(
            in_features,
            out_features,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        self.norm = nn.InstanceNorm2d(out_features)

    def forward(self, x):
        return F.leaky_relu(self.norm(self.convT(x)))
class ResBlock3d(nn.Module):
    """
    Residual block of two (batch norm -> relu -> conv3d) stages; the input is
    added back at the end. Both convs keep the channel count.
    """

    def __init__(self, in_features, kernel_size, padding):
        super().__init__()
        conv_args = dict(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, padding=padding)
        self.conv1 = nn.Conv3d(**conv_args)
        self.conv2 = nn.Conv3d(**conv_args)
        self.norm1 = nn.BatchNorm3d(in_features, affine=True)
        self.norm2 = nn.BatchNorm3d(in_features, affine=True)

    def forward(self, x):
        out = self.conv1(F.relu(self.norm1(x)))
        out = self.conv2(F.relu(self.norm2(out)))
        out += x  # skip connection
        return out
class UpBlock3d(nn.Module):
    """
    Upsampling block for use in decoder: doubles the last two dims (the first
    spatial dim is kept), then conv3d -> batch norm -> relu.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=groups,
        )
        self.norm = nn.BatchNorm3d(out_features, affine=True)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=(1, 2, 2))
        return F.relu(self.norm(self.conv(upsampled)))
class DownBlock2d(nn.Module):
    """
    Downsampling block for use in encoder: conv2d -> batch norm -> relu -> 2x2 average pool.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=groups,
        )
        self.norm = nn.BatchNorm2d(out_features, affine=True)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        activated = F.relu(self.norm(self.conv(x)))
        return self.pool(activated)
class DownBlock3d(nn.Module):
    """
    Downsampling block for use in encoder: conv3d -> batch norm -> relu ->
    average pool with kernel (1, 2, 2), i.e. the first spatial dim is kept.
    """

    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super().__init__()
        # NOTE: a strided conv (stride=(1, 2, 2)) was once tried here instead of
        # conv + average pooling; the pooling variant is the one in use.
        self.conv = nn.Conv3d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=groups,
        )
        self.norm = nn.BatchNorm3d(out_features, affine=True)
        self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))

    def forward(self, x):
        activated = F.relu(self.norm(self.conv(x)))
        return self.pool(activated)
class SameBlock2d(nn.Module):
    """
    Simple block: conv2d -> batch norm -> activation (LeakyReLU when
    ``lrelu`` is set, ReLU otherwise). Spatial resolution is preserved with
    the default kernel_size/padding.
    """

    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=groups,
        )
        self.norm = nn.BatchNorm2d(out_features, affine=True)
        self.ac = nn.LeakyReLU() if lrelu else nn.ReLU()

    def forward(self, x):
        return self.ac(self.norm(self.conv(x)))
class Encoder(nn.Module):
    """
    Hourglass Encoder: a chain of DownBlock3d whose widths double per level,
    capped at ``max_features``.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super().__init__()

        def width(level):
            # channel count at a given depth, capped at max_features
            return min(max_features, block_expansion * (2 ** level))

        blocks = [
            DownBlock3d(in_features if i == 0 else width(i), width(i + 1), kernel_size=3, padding=1)
            for i in range(num_blocks)
        ]
        self.down_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Return [x, out_1, ..., out_n]: the input followed by every block output."""
        features = [x]
        current = x
        for block in self.down_blocks:
            current = block(current)
            features.append(current)
        return features
class Decoder(nn.Module):
    """
    Hourglass Decoder: UpBlock3d chain with skip connections, followed by a
    conv3d -> batch norm -> relu head over ``block_expansion + in_features``
    channels.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super().__init__()

        def width(level):
            # channel count at a given depth, capped at max_features
            return min(max_features, block_expansion * (2 ** level))

        blocks = []
        for i in reversed(range(num_blocks)):
            # every block except the deepest one also receives the skip channels
            multiplier = 1 if i == num_blocks - 1 else 2
            blocks.append(UpBlock3d(multiplier * width(i + 1), width(i), kernel_size=3, padding=1))
        self.up_blocks = nn.ModuleList(blocks)

        self.out_filters = block_expansion + in_features
        self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm3d(self.out_filters, affine=True)

    def forward(self, x):
        """Decode the encoder feature list ``x``.

        NOTE: ``x`` is consumed in place -- features are popped from its end.
        """
        out = x.pop()
        for block in self.up_blocks:
            # upsample first, then pop the matching skip feature
            out = torch.cat([block(out), x.pop()], dim=1)
        return F.relu(self.norm(self.conv(out)))
class Hourglass(nn.Module):
    """
    Hourglass architecture: an Encoder followed by a Decoder built with the
    same hyper-parameters. ``out_filters`` mirrors the decoder's output width.
    """

    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super().__init__()
        shared_args = (block_expansion, in_features, num_blocks, max_features)
        self.encoder = Encoder(*shared_args)
        self.decoder = Decoder(*shared_args)
        self.out_filters = self.decoder.out_filters

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)
class SPADE(nn.Module):
    """
    Spatially-adaptive normalization: ``x`` is instance-normalized without
    learned affine parameters, then scaled by (1 + gamma) and shifted by beta,
    both predicted per pixel from ``segmap``.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        hidden_channels = 128
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(hidden_channels, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(hidden_channels, norm_nc, kernel_size=3, padding=1)

    def forward(self, x, segmap):
        normalized = self.param_free_norm(x)

        # bring the conditioning map to x's spatial size before predicting gamma/beta
        resized = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        shared = self.mlp_shared(resized)
        gamma = self.mlp_gamma(shared)
        beta = self.mlp_beta(shared)

        return normalized * (1 + gamma) + beta
class SPADEResnetBlock(nn.Module):
    """
    Residual block whose normalizations are SPADE layers conditioned on a map.

    A learned 1x1 shortcut is used when ``fin != fout``. Spectral norm is
    applied to the convs when ``norm_G`` contains 'spectral'. ``use_se`` is
    only stored on the instance.
    """

    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
        self.use_se = use_se

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in norm_G:
            conv_names = ['conv_0', 'conv_1']
            if self.learned_shortcut:
                conv_names.append('conv_s')
            for conv_name in conv_names:
                setattr(self, conv_name, spectral_norm(getattr(self, conv_name)))

        # define normalization layers
        self.norm_0 = SPADE(fin, label_nc)
        self.norm_1 = SPADE(fmiddle, label_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, label_nc)

    def forward(self, x, seg1):
        skip = self.shortcut(x, seg1)
        residual = self.conv_0(self.actvn(self.norm_0(x, seg1)))
        residual = self.conv_1(self.actvn(self.norm_1(residual, seg1)))
        return skip + residual

    def shortcut(self, x, seg1):
        """Identity, or SPADE norm + 1x1 conv when the channel count changes."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg1))

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)
def filter_state_dict(state_dict, remove_name='fc'):
    """Return a new dict without the entries whose key contains ``remove_name``.

    Key order is preserved and the input mapping is left untouched.
    """
    return {
        key: value
        for key, value in state_dict.items()
        if remove_name not in key
    }
class GRN(nn.Module):
    """ GRN (Global Response Normalization) layer

    gamma and beta have shape (1, 1, 1, dim) and start at zero, so a freshly
    built layer returns its input unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(*param_shape))
        self.beta = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        # L2 norm over dims 1 and 2, then divided by its mean over the last dim
        global_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative_norm = global_norm / (global_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x
class LayerNorm(nn.Module):
    r""" LayerNorm supporting two data formats: channels_last (default) or channels_first.

    channels_last expects the normalized axis to be the last one, e.g.
    (batch_size, height, width, channels); channels_first expects it at
    dim 1, e.g. (batch_size, channels, height, width).
    Any other ``data_format`` raises NotImplementedError.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # manual layer norm over dim 1 (biased variance)
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a normal(mean, std) truncated to [a, b].

    Runs under ``torch.no_grad()`` and returns the same tensor. A warning is
    emitted when ``mean`` lies more than 2 std outside [a, b].
    """
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # cdf values of the two truncation bounds
        cdf_low = norm_cdf((a - mean) / std)
        cdf_high = norm_cdf((b - mean) / std)

        # sample uniformly in [2*cdf_low - 1, 2*cdf_high - 1], map through the
        # inverse error function to a truncated standard normal, then rescale
        # and shift to the requested mean/std
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # clamp to make sure every value is inside [a, b]
        tensor.clamp_(min=a, max=b)
        return tensor
def drop_path(x, drop_prob=0., training=False, scale_by_keep=True):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample (index along dim 0) is kept with probability ``1 - drop_prob``
    and zeroed otherwise; kept samples are divided by the keep probability
    when ``scale_by_keep`` is set. ``x`` is returned untouched when
    ``drop_prob`` is 0 or ``training`` is False.
    """
    if not training or drop_prob == 0.:
        return x

    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
class DropPath(nn.Module):
    """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of dropping a sample; ``None`` (the default)
            disables dropping entirely.
        scale_by_keep: forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # With the default drop_prob=None, drop_path would evaluate `1 - None`
        # in training mode and raise TypeError; None means "never drop".
        if self.drop_prob is None:
            return x
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place from a normal(mean, std) truncated to [a, b] and return it."""
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
# coding: utf-8
"""
Warping field estimator(W) defined in the paper, which generates a warping field using the implicit
keypoint representations x_s and x_d, and employs this flow field to warp the source feature volume f_s.
"""
from torch import nn
import torch.nn.functional as F
from .util import SameBlock2d
from .dense_motion import DenseMotionNetwork
class WarpingNetwork(nn.Module):
    """Warp a 3d source feature volume according to source/driving keypoints.

    A DenseMotionNetwork (built from ``dense_motion_params``) predicts the
    deformation and, optionally, an occlusion map; the warped volume is then
    flattened along depth and passed through two 2d conv stages.

    Recognised extra kwargs: ``upscale`` (stored only) and
    ``flag_use_occlusion_map`` (default True).
    """

    def __init__(
        self,
        num_kp,
        block_expansion,
        max_features,
        num_down_blocks,
        reshape_channel,
        estimate_occlusion_map=False,
        dense_motion_params=None,
        **kwargs
    ):
        super().__init__()

        self.upscale = kwargs.get('upscale', 1)
        self.flag_use_occlusion_map = kwargs.get('flag_use_occlusion_map', True)

        if dense_motion_params is not None:
            self.dense_motion_network = DenseMotionNetwork(
                num_kp=num_kp,
                feature_channel=reshape_channel,
                estimate_occlusion_map=estimate_occlusion_map,
                **dense_motion_params
            )
        else:
            self.dense_motion_network = None

        out_channels = block_expansion * (2 ** num_down_blocks)
        self.third = SameBlock2d(max_features, out_channels, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
        self.fourth = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, stride=1)

        self.estimate_occlusion_map = estimate_occlusion_map

    def deform_input(self, inp, deformation):
        """Sample ``inp`` at the locations given by ``deformation``."""
        return F.grid_sample(inp, deformation, align_corners=False)

    def forward(self, feature_3d, kp_driving, kp_source):
        """Warp ``feature_3d`` and return a dict with 'occlusion_map', 'deformation' and 'out'.

        'occlusion_map' is None when the dense motion network does not
        produce one. Shape comments are the example shapes noted by the
        original authors.

        Raises:
            RuntimeError: if the module was built without ``dense_motion_params``.
        """
        if self.dense_motion_network is None:
            # previously this path died on undefined locals when building the result
            raise RuntimeError(
                'WarpingNetwork was built without dense_motion_params, so no deformation can be predicted'
            )

        # Feature warper, transforming the feature representation according to deformation and occlusion
        dense_motion = self.dense_motion_network(
            feature=feature_3d, kp_driving=kp_driving, kp_source=kp_source
        )
        if 'occlusion_map' in dense_motion:
            occlusion_map = dense_motion['occlusion_map']  # Bx1x64x64
        else:
            occlusion_map = None

        deformation = dense_motion['deformation']  # Bx16x64x64x3
        out = self.deform_input(feature_3d, deformation)  # Bx32x16x64x64

        # fold the depth axis into channels for the 2d convs
        bs, c, d, h, w = out.shape  # Bx32x16x64x64
        out = out.view(bs, c * d, h, w)  # -> Bx512x64x64
        out = self.third(out)  # -> Bx256x64x64
        out = self.fourth(out)  # -> Bx256x64x64

        if self.flag_use_occlusion_map and (occlusion_map is not None):
            out = out * occlusion_map

        ret_dct = {
            'occlusion_map': occlusion_map,
            'deformation': deformation,
            'out': out,
        }

        return ret_dct
# coding: utf-8
"""
functions for processing and transforming 3D facial keypoints
"""
import numpy as np
import torch
import torch.nn.functional as F
# Shorthand for pi, used for the degree -> radian conversion in get_rotation_matrix.
PI = np.pi
def headpose_pred_to_degree(pred):
    """
    Convert a 66-bin head pose prediction into degrees.

    pred: (bs, 66) or (bs, 1) or others
    Only a tensor with more than one dim and 66 entries along dim 1 is
    converted: softmax over the bins, expectation of the bin index,
    times 3 minus 97.5. Any other input is returned unchanged.
    """
    has_bins = pred.ndim > 1 and pred.shape[1] == 66
    if not has_bins:
        return pred

    # NOTE: note that the average is modified to 97.5
    bin_index = torch.FloatTensor(list(range(0, 66))).to(pred.device)
    probs = F.softmax(pred, dim=1)
    return torch.sum(probs * bin_index, axis=1) * 3 - 97.5
def get_rotation_matrix(pitch_, yaw_, roll_):
    """ Build per-sample 3x3 rotation matrices from euler angles given in degree.

    Inputs may be 1-d (bs,) or column shaped (bs, 1). The result is the
    transpose of rot_z @ rot_y @ rot_x, with shape (bs, 3, 3).
    """
    def _to_radian_column(angle_deg):
        # degree -> radian, then make sure the angle is a (bs, 1) column
        angle = angle_deg / 180 * PI
        return angle.unsqueeze(1) if angle.ndim == 1 else angle

    pitch = _to_radian_column(pitch_)
    yaw = _to_radian_column(yaw_)
    roll = _to_radian_column(roll_)

    device = pitch.device
    bs = pitch.shape[0]
    ones = torch.ones([bs, 1]).to(device)
    zeros = torch.zeros([bs, 1]).to(device)

    def _as_matrix(*entries):
        # nine (bs, 1) columns in row-major order -> (bs, 3, 3)
        return torch.cat(entries, dim=1).reshape([bs, 3, 3])

    cos_x, sin_x = torch.cos(pitch), torch.sin(pitch)
    cos_y, sin_y = torch.cos(yaw), torch.sin(yaw)
    cos_z, sin_z = torch.cos(roll), torch.sin(roll)

    rot_x = _as_matrix(
        ones, zeros, zeros,
        zeros, cos_x, -sin_x,
        zeros, sin_x, cos_x,
    )
    rot_y = _as_matrix(
        cos_y, zeros, sin_y,
        zeros, ones, zeros,
        -sin_y, zeros, cos_y,
    )
    rot_z = _as_matrix(
        cos_z, -sin_z, zeros,
        sin_z, cos_z, zeros,
        zeros, zeros, ones,
    )

    rot = rot_z @ rot_y @ rot_x
    return rot.permute(0, 2, 1)  # transpose
# coding: utf-8
"""
cropping function and the related preprocess functions for cropping
"""
import numpy as np
import os.path as osp
from math import sin, cos, acos, degrees
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) # NOTE: enforce single thread
from .rprint import rprint as print
# dtype used for the centers, boxes and affine matrices built in this module
DTYPE = np.float32
# default interpolation flag of _transform_img
CV2_INTERP = cv2.INTER_LINEAR
def make_abs_path(fn):
    """Join ``fn`` onto the directory that contains this module (symlinks resolved)."""
    here = osp.dirname(osp.realpath(__file__))
    return osp.join(here, fn)
def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
    """ conduct similarity or affine transformation to the image, do not do border operation!
    img: image passed to cv2.warpAffine
    M: 2x3 matrix or 3x3 matrix (only the first two rows are used)
    dsize: target shape (width, height), or a single number for a square output
    borderMode: when given, it is forwarded together with a black border value
    """
    _dsize = tuple(dsize) if isinstance(dsize, (tuple, list)) else (dsize, dsize)
    affine = M[:2, :]

    if borderMode is None:
        return cv2.warpAffine(img, affine, dsize=_dsize, flags=flags)
    return cv2.warpAffine(img, affine, dsize=_dsize, flags=flags, borderMode=borderMode, borderValue=(0, 0, 0))
def _transform_pts(pts, M):
    """ conduct similarity or affine transformation to the pts
    pts: Nx2 ndarray
    M: 2x3 matrix or 3x3 matrix (only the first two rows are used)
    return: Nx2
    """
    linear = M[:2, :2]  # rotation / scale part
    offset = M[:2, 2]   # translation part
    return pts @ linear.T + offset
def parse_pt2_from_pt101(pt101, use_lip=True):
    """
    parsing the 2 points according to the 101 points, which cancels the roll

    return: 2x2 array -- (eye midpoint, lip midpoint) when use_lip is set,
    otherwise (left eye center, right eye center)
    """
    # the former version used the eye center, which is not robust; each eye
    # center is now the mean of four landmarks
    eye_left = np.mean(pt101[[39, 42, 45, 48]], axis=0)
    eye_right = np.mean(pt101[[51, 54, 57, 60]], axis=0)

    if not use_lip:
        return np.stack([eye_left, eye_right], axis=0)

    eye_mid = (eye_left + eye_right) / 2
    lip_mid = (pt101[75] + pt101[81]) / 2
    return np.stack([eye_mid, lip_mid], axis=0)
def parse_pt2_from_pt106(pt106, use_lip=True):
    """
    parsing the 2 points according to the 106 points, which cancels the roll

    return: 2x2 array -- (eye midpoint, lip midpoint) when use_lip is set,
    otherwise (left eye center, right eye center)
    """
    # each eye center is the mean of four landmarks
    eye_left = np.mean(pt106[[33, 35, 40, 39]], axis=0)
    eye_right = np.mean(pt106[[87, 89, 94, 93]], axis=0)

    if not use_lip:
        return np.stack([eye_left, eye_right], axis=0)

    eye_mid = (eye_left + eye_right) / 2
    lip_mid = (pt106[52] + pt106[61]) / 2
    return np.stack([eye_mid, lip_mid], axis=0)
def parse_pt2_from_pt203(pt203, use_lip=True):
    """
    parsing the 2 points according to the 203 points, which cancels the roll

    return: 2x2 array -- (eye midpoint, lip midpoint) when use_lip is set,
    otherwise (left eye center, right eye center)
    """
    # each eye center is the mean of four landmarks
    eye_left = np.mean(pt203[[0, 6, 12, 18]], axis=0)
    eye_right = np.mean(pt203[[24, 30, 36, 42]], axis=0)

    if not use_lip:
        return np.stack([eye_left, eye_right], axis=0)

    eye_mid = (eye_left + eye_right) / 2
    lip_mid = (pt203[48] + pt203[66]) / 2
    return np.stack([eye_mid, lip_mid], axis=0)
def parse_pt2_from_pt68(pt68, use_lip=True):
    """
    parsing the 2 points according to the 68 points, which cancels the roll

    return: 2x2 array -- (eye midpoint, lip midpoint) when use_lip is set,
    otherwise (left eye, right eye)
    """
    # 1-based landmark ids converted to 0-based indices:
    # nose, two left-eye points, two right-eye points, two lip points
    lm_idx = np.array([31, 37, 40, 43, 46, 49, 55], dtype=np.int32) - 1

    eye_left = np.mean(pt68[lm_idx[[1, 2]], :], 0)
    eye_right = np.mean(pt68[lm_idx[[3, 4]], :], 0)

    if not use_lip:
        return np.stack([eye_left, eye_right], axis=0)

    lip_a = pt68[lm_idx[5], :]
    lip_b = pt68[lm_idx[6], :]
    return np.stack([(eye_left + eye_right) / 2, (lip_a + lip_b) / 2], axis=0)
def parse_pt2_from_pt5(pt5, use_lip=True):
    """
    parsing the 2 points according to the 5 points, which cancels the roll

    return: 2x2 array -- (midpoint of points 0 and 1, midpoint of points 3
    and 4) when use_lip is set, otherwise points 0 and 1 themselves
    """
    if not use_lip:
        return np.stack([pt5[0], pt5[1]], axis=0)

    upper_mid = (pt5[0] + pt5[1]) / 2
    lower_mid = (pt5[3] + pt5[4]) / 2
    return np.stack([upper_mid, lower_mid], axis=0)
def parse_pt2_from_pt_x(pts, use_lip=True):
    """
    Pick the 2-point parser matching the number of landmarks in ``pts``.

    Supported counts are 101, 106, 68, 5 and 203; any other count above 101
    is parsed from its first 101 points. Everything else raises an Exception.
    When use_lip is False the second point is rotated around the first one so
    that the result is compatible with the downstream code.
    """
    num_pts = pts.shape[0]
    exact_sizes = (101, 106, 68, 5, 203)
    if num_pts not in exact_sizes and not num_pts > 101:
        raise Exception(f'Unknow shape: {pts.shape}')

    parsers = {
        101: parse_pt2_from_pt101,
        106: parse_pt2_from_pt106,
        68: parse_pt2_from_pt68,
        5: parse_pt2_from_pt5,
        203: parse_pt2_from_pt203,
    }
    if num_pts in parsers:
        pt2 = parsers[num_pts](pts, use_lip=use_lip)
    else:
        # take the first 101 points
        pt2 = parse_pt2_from_pt101(pts[:101], use_lip=use_lip)

    if not use_lip:
        # NOTE: to be compatible with the later code, pt2 has to be rotated 90 degrees clockwise manually
        v = pt2[1] - pt2[0]
        pt2[1, 0] = pt2[0, 0] - v[1]
        pt2[1, 1] = pt2[0, 1] + v[0]

    return pt2
def parse_rect_from_landmark(
    pts,
    scale=1.5,
    need_square=True,
    vx_ratio=0,
    vy_ratio=0,
    use_deg_flag=False,
    **kwargs
):
    """parsing center, size, angle from 101/68/5/x landmarks
    pts: Nx2 landmarks, the parser is chosen from pts.shape
    scale: factor applied to the box size
    need_square: make both box sides equal to the larger one
    vx_ratio: the offset ratio along the pupil axis x-axis, multiplied by size
    vy_ratio: the offset ratio along the pupil axis y-axis, multiplied by size, which is used to contain more forehead area
    use_deg_flag: return the angle in degree instead of radian
    kwargs: 'use_lip' (default True) is forwarded to the 2-point parser
    return: (center, size, angle)
    """
    pt2 = parse_pt2_from_pt_x(pts, use_lip=kwargs.get('use_lip', True))

    # unit vector from the first to the second parsed point (falls back to +y
    # when both points nearly coincide), and its perpendicular
    uy = pt2[1] - pt2[0]
    length = np.linalg.norm(uy)
    if length <= 1e-3:
        uy = np.array([0, 1], dtype=DTYPE)
    else:
        uy /= length
    ux = np.array((uy[1], -uy[0]), dtype=DTYPE)

    # the rotation degree of the x-axis, the clockwise is positive, the counterclockwise is negative (image coordinate system)
    # clip guards acos against rounding that pushes the cosine marginally outside [-1, 1]
    angle = acos(np.clip(ux[0], -1.0, 1.0))
    if ux[1] < 0:
        angle = -angle

    # rotation matrix
    M = np.array([ux, uy])

    # calculate the size which contains the angle degree of the bbox, and the center
    center0 = np.mean(pts, axis=0)
    rpts = (pts - center0) @ M.T  # (M @ P.T).T = P @ M.T
    lt_pt = np.min(rpts, axis=0)
    rb_pt = np.max(rpts, axis=0)
    center1 = (lt_pt + rb_pt) / 2

    size = rb_pt - lt_pt
    if need_square:
        m = max(size[0], size[1])
        size[0] = m
        size[1] = m

    size *= scale  # scale size
    center = center0 + ux * center1[0] + uy * center1[1]  # counterclockwise rotation, equivalent to M.T @ center1.T
    center = center + ux * (vx_ratio * size) + uy * \
        (vy_ratio * size)  # considering the offset in vx and vy direction

    if use_deg_flag:
        angle = degrees(angle)

    return center, size, angle
def parse_bbox_from_landmark(pts, **kwargs):
    """Compute the axis-aligned and the rotated face box from landmarks.

    All kwargs are forwarded to parse_rect_from_landmark. Returns a dict with
    'center', 'size', 'angle', 'bbox' (4x2 corners before rotation) and
    'bbox_rot' (the same corners rotated by ``angle`` around the center).
    """
    center, size, angle = parse_rect_from_landmark(pts, **kwargs)
    cx, cy = center
    w, h = size
    half_w, half_h = w / 2, h / 2

    # vertex positions before rotation
    bbox = np.array([
        [cx - half_w, cy - half_h],  # left, top
        [cx + half_w, cy - half_h],
        [cx + half_w, cy + half_h],  # right, bottom
        [cx - half_w, cy + half_h]
    ], dtype=DTYPE)

    # rotation matrix for ``angle``
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    R = np.array([
        [cos_a, -sin_a],
        [sin_a, cos_a]
    ], dtype=DTYPE)

    # move the vertices relative to the rotation center, rotate them, then move them back
    bbox_rot = (bbox - center) @ R.T + center

    return {
        'center': center,  # 2-vector
        'size': size,  # as returned by parse_rect_from_landmark
        'angle': angle,  # rad, counterclockwise
        'bbox': bbox,  # 4x2
        'bbox_rot': bbox_rot,  # 4x2
    }
def crop_image_by_bbox(img, bbox, lmk=None, dsize=512, angle=None, flag_rot=False, **kwargs):
    """Crop ``img`` to ``bbox`` and resize the crop to dsize x dsize.

    bbox: (left, top, right, bottom); the width is used as the crop size
    lmk: optional landmarks, transformed into the crop when given
    angle, flag_rot: the crop is rotated only when flag_rot is set and an angle is given
    kwargs: 'borderMode' is forwarded to the image warp
    return: dict with 'img_crop', 'lmk_crop', 'M_o2c' (3x3, original -> crop)
    and 'M_c2o' (its inverse)
    """
    left, top, right, bot = bbox
    if int(right - left) != int(bot - top):
        print(f'right-left {right-left} != bot-top {bot-top}')
    size = right - left

    src_center = np.array([(left + right) / 2, (top + bot) / 2], dtype=DTYPE)
    tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE)

    s = dsize / size  # scale
    if flag_rot and angle is not None:
        costheta, sintheta = cos(angle), sin(angle)
        cx, cy = src_center[0], src_center[1]  # ori center
        tcx, tcy = tgt_center[0], tgt_center[1]  # target center
        # rotate + scale around the source center, then move it onto the target center
        rows = [
            [s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
            [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)],
        ]
    else:
        # pure scale + translation
        rows = [
            [s, 0, tgt_center[0] - s * src_center[0]],
            [0, s, tgt_center[1] - s * src_center[1]],
        ]
    M_o2c = np.array(rows, dtype=DTYPE)

    img_crop = _transform_img(img, M_o2c, dsize=dsize, borderMode=kwargs.get('borderMode', None))
    lmk_crop = None if lmk is None else _transform_pts(lmk, M_o2c)

    # homogeneous 3x3 form and its inverse
    M_o2c = np.vstack([M_o2c, np.array([0, 0, 1], dtype=DTYPE)])
    M_c2o = np.linalg.inv(M_o2c)

    return {
        'img_crop': img_crop,
        'lmk_crop': lmk_crop,
        'M_o2c': M_o2c,
        'M_c2o': M_c2o,
    }
def _estimate_similar_transform_from_pts(
    pts,
    dsize,
    scale=1.5,
    vx_ratio=0,
    vy_ratio=-0.1,
    flag_do_rot=True,
    **kwargs
):
    """ calculate the affine matrix of the cropped image from sparse points, the original image to the cropped image, the inverse is the cropped image to the original image
    pts: landmark, 101 or 68 points or other points, Nx2
    dsize: side length of the square crop
    scale: the larger scale factor, the smaller face ratio
    vx_ratio: x shift
    vy_ratio: y shift, the smaller the y shift, the lower the face region
    flag_do_rot: if it is true, conduct correction
    kwargs: 'use_lip' (default True) is forwarded to parse_rect_from_landmark
    return: (M_INV, M) -- 2x3 original -> crop, and 2x3 crop -> original
    """
    center, size, angle = parse_rect_from_landmark(
        pts, scale=scale, vx_ratio=vx_ratio, vy_ratio=vy_ratio,
        use_lip=kwargs.get('use_lip', True)
    )

    s = dsize / size[0]  # scale
    tgt_center = np.array([dsize / 2, dsize / 2], dtype=DTYPE)  # center of dsize

    if flag_do_rot:
        costheta, sintheta = cos(angle), sin(angle)
        cx, cy = center[0], center[1]  # ori center
        tcx, tcy = tgt_center[0], tgt_center[1]  # target center
        # rotate + scale around the face center, then move it onto the crop center
        rows = [
            [s * costheta, s * sintheta, tcx - s * (costheta * cx + sintheta * cy)],
            [-s * sintheta, s * costheta, tcy - s * (-sintheta * cx + costheta * cy)],
        ]
    else:
        # pure scale + translation
        rows = [
            [s, 0, tgt_center[0] - s * center[0]],
            [0, s, tgt_center[1] - s * center[1]],
        ]
    M_INV = np.array(rows, dtype=DTYPE)

    # M_INV is from the original image to the cropped image, M is from the cropped image to the original image
    M_INV_H = np.vstack([M_INV, np.array([0, 0, 1])])
    M = np.linalg.inv(M_INV_H)

    return M_INV, M[:2, ...]
def crop_image(img, pts: np.ndarray, **kwargs):
    """Crop ``img`` around the face described by the landmarks ``pts``.

    Recognised kwargs (with defaults): dsize (224), scale (1.5),
    vx_ratio (0), vy_ratio (-0.1) and flag_do_rot (True).
    return: dict with 'M_o2c' / 'M_c2o' (3x3 transforms between the original
    and the cropped image), 'img_crop' and 'pt_crop' (landmarks in the crop)
    """
    dsize = kwargs.get('dsize', 224)
    scale = kwargs.get('scale', 1.5)  # 1.5 | 1.6
    # vx_ratio used to be dropped here, so a horizontal offset requested by
    # the caller never reached the transform estimation
    vx_ratio = kwargs.get('vx_ratio', 0)
    vy_ratio = kwargs.get('vy_ratio', -0.1)  # -0.0625 | -0.1

    M_INV, _ = _estimate_similar_transform_from_pts(
        pts,
        dsize=dsize,
        scale=scale,
        vx_ratio=vx_ratio,
        vy_ratio=vy_ratio,
        flag_do_rot=kwargs.get('flag_do_rot', True),
    )

    img_crop = _transform_img(img, M_INV, dsize)  # origin to crop
    pt_crop = _transform_pts(pts, M_INV)

    # homogeneous 3x3 form and its inverse
    M_o2c = np.vstack([M_INV, np.array([0, 0, 1], dtype=DTYPE)])
    M_c2o = np.linalg.inv(M_o2c)

    ret_dct = {
        'M_o2c': M_o2c,  # from the original image to the cropped image 3x3
        'M_c2o': M_c2o,  # from the cropped image to the original image 3x3
        'img_crop': img_crop,  # the cropped image
        'pt_crop': pt_crop,  # the landmarks of the cropped image
    }

    return ret_dct
def average_bbox_lst(bbox_lst):
    """Return the element-wise mean of the boxes as a list, or None for an empty input."""
    if len(bbox_lst) == 0:
        return None
    stacked = np.array(bbox_lst)
    averaged = np.mean(stacked, axis=0)
    return averaged.tolist()
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
    """Build the mask later consumed when pasting a crop back.

    ``mask_crop`` is passed through ``_transform_img`` with the crop-to-original
    matrix ``crop_M_c2o`` and the target size ``dsize``; the result is converted
    to float32 and divided by 255.
    """
    warped_mask = _transform_img(mask_crop, crop_M_c2o, dsize)
    return warped_mask.astype(np.float32) / 255.
def paste_back(img_crop, M_c2o, img_ori, mask_ori):
    """Blend the warped crop into ``img_ori`` using ``mask_ori`` as the weight.

    ``img_crop`` is transformed with ``M_c2o`` to the size of ``img_ori``, then
    mixed as ``mask * warped + (1 - mask) * original``. The result is clipped to
    [0, 255] and returned as uint8.
    """
    height, width = img_ori.shape[0], img_ori.shape[1]
    # The transform helper takes the size as (width, height).
    warped = _transform_img(img_crop, M_c2o, dsize=(width, height))
    blended = mask_ori * warped + (1 - mask_ori) * img_ori
    return np.clip(blended, 0, 255).astype(np.uint8)
# coding: utf-8
import os.path as osp
from dataclasses import dataclass, field
from typing import List, Tuple, Union
import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False)
import numpy as np
import torch
from ..config.crop_config import CropConfig
from .crop import (
average_bbox_lst,
crop_image,
crop_image_by_bbox,
parse_bbox_from_landmark,
)
from .io import contiguous
from .rprint import rlog as log
from .face_analysis_diy import FaceAnalysisDIY
from .landmark_runner import LandmarkRunner
def make_abs_path(fn):
    """Join ``fn`` onto the directory that contains this module (symlinks resolved)."""
    module_dir = osp.dirname(osp.realpath(__file__))
    return osp.join(module_dir, fn)
@dataclass
class Trajectory:
    """Per-video accumulator used while tracking a face across frames.

    ``start``/``end`` default to -1; every list field gets its own fresh empty
    list per instance.
    """

    # index of the first tracked frame
    start: int = -1
    # index of the last tracked frame
    end: int = -1
    # landmarks, one entry per tracked frame
    lmk_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)
    # bounding boxes, one entry per tracked frame
    bbox_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)
    # crop-to-original matrices
    M_c2o_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)
    # RGB frames
    frame_rgb_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)
    # landmarks stored alongside the cropped frames
    lmk_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)
    # cropped RGB frames
    frame_rgb_crop_lst: Union[Tuple, List, np.ndarray] = field(default_factory=list)
class Cropper(object):
    """Face cropping for source images/videos and driving videos.

    Faces come from ``face_analysis_wrapper.get`` (which exposes
    ``landmark_2d_106``); ``landmark_runner.run`` is then called on the frame
    with those landmarks, and on later frames with the previous frame's
    landmarks.
    """

    def __init__(self, **kwargs) -> None:
        """Create the landmark runner and the face-analysis wrapper.

        Keyword Args:
            crop_cfg: CropConfig supplying ``landmark_ckpt_path`` and
                ``insightface_root``; both are resolved relative to this module.
            device_id: device index handed to both backends (default 0).
            flag_force_cpu: run everything on CPU when true (default False).
        """
        self.crop_cfg: CropConfig = kwargs.get("crop_cfg", None)

        device_id = kwargs.get("device_id", 0)
        flag_force_cpu = kwargs.get("flag_force_cpu", False)
        if flag_force_cpu:
            device = "cpu"
            face_analysis_wrapper_provider = ["CPUExecutionProvider"]
        elif torch.backends.mps.is_available():
            # Shape inference currently fails with CoreMLExecutionProvider
            # for the retinaface model
            device = "mps"
            face_analysis_wrapper_provider = ["CPUExecutionProvider"]
        else:
            device = "cuda"
            face_analysis_wrapper_provider = ["CUDAExecutionProvider"]

        self.landmark_runner = LandmarkRunner(
            ckpt_path=make_abs_path(self.crop_cfg.landmark_ckpt_path),
            onnx_provider=device,
            device_id=device_id,
        )
        self.landmark_runner.warmup()

        self.face_analysis_wrapper = FaceAnalysisDIY(
            name="buffalo_l",
            root=make_abs_path(self.crop_cfg.insightface_root),
            providers=face_analysis_wrapper_provider,
        )
        self.face_analysis_wrapper.prepare(ctx_id=device_id, det_size=(512, 512))
        self.face_analysis_wrapper.warmup()

    def update_config(self, user_args):
        """Copy each ``user_args`` entry onto ``self.crop_cfg`` when the config already has that attribute."""
        for k, v in user_args.items():
            if hasattr(self.crop_cfg, k):
                setattr(self.crop_cfg, k, v)

    def crop_source_image(self, img_rgb_: np.ndarray, crop_cfg: CropConfig):
        """Crop a source image and collect the necessary information.

        Args:
            img_rgb_: RGB image; it is copied before use.
            crop_cfg: crop options (``direction``, ``max_face_num``, ``dsize``,
                ``scale``, ``vx_ratio``, ``vy_ratio``, ``flag_do_rot``).

        Returns:
            The dict produced by ``crop_image`` extended with ``lmk_crop``,
            ``img_crop_256x256`` and ``lmk_crop_256x256``, or ``None`` when no
            face is detected.
        """
        img_rgb = img_rgb_.copy()  # copy it
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        src_face = self.face_analysis_wrapper.get(
            img_bgr,
            flag_do_landmark_2d_106=True,
            direction=crop_cfg.direction,
            max_face_num=crop_cfg.max_face_num,
        )

        if len(src_face) == 0:
            log("No face detected in the source image.")
            return None
        elif len(src_face) > 1:
            log(f"More than one face detected in the image, only pick one face by rule {crop_cfg.direction}.")

        # NOTE: temporarily only pick the first face, to support multiple face in the future
        src_face = src_face[0]
        lmk = src_face.landmark_2d_106  # this is the 106 landmarks from insightface

        # crop the face
        ret_dct = crop_image(
            img_rgb,  # ndarray
            lmk,  # 106x2 or Nx2
            dsize=crop_cfg.dsize,
            scale=crop_cfg.scale,
            vx_ratio=crop_cfg.vx_ratio,
            vy_ratio=crop_cfg.vy_ratio,
            flag_do_rot=crop_cfg.flag_do_rot,
        )

        # NOTE(review): "lmk_crop" is the runner output for the un-cropped image,
        # not crop_image's "pt_crop" -- confirm that consumers expect this.
        lmk = self.landmark_runner.run(img_rgb, lmk)
        ret_dct["lmk_crop"] = lmk

        # update a 256x256 version for network input
        ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
        ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize

        return ret_dct

    def crop_source_video(self, source_rgb_lst, crop_cfg: CropConfig, **kwargs):
        """Tracking based landmarks/alignment and cropping.

        Face detection runs until a face is found; after that each frame's
        landmarks are seeded from the previous frame. Frames without a
        detected face before tracking starts are skipped, so the returned
        lists can be shorter than ``source_rgb_lst``.

        Args:
            source_rgb_lst: iterable of RGB frames.
            crop_cfg: crop options, same fields as in ``crop_source_image``.
            **kwargs: accepted for backward compatibility; not used.

        Returns:
            dict with ``frame_crop_lst`` (256x256 crops), ``lmk_crop_lst`` and
            ``M_c2o_lst``.
        """
        trajectory = Trajectory()

        for idx, frame_rgb in enumerate(source_rgb_lst):
            if idx == 0 or trajectory.start == -1:
                src_face = self.face_analysis_wrapper.get(
                    contiguous(frame_rgb[..., ::-1]),  # convert to BGR
                    flag_do_landmark_2d_106=True,
                    direction=crop_cfg.direction,
                    max_face_num=crop_cfg.max_face_num,
                )
                if len(src_face) == 0:
                    log(f"No face detected in the frame #{idx}")
                    continue
                elif len(src_face) > 1:
                    # Report the rule that was actually used for the selection above.
                    log(f"More than one face detected in the source frame_{idx}, only pick one face by rule {crop_cfg.direction}.")
                src_face = src_face[0]
                lmk = src_face.landmark_2d_106
                lmk = self.landmark_runner.run(frame_rgb, lmk)
                trajectory.start, trajectory.end = idx, idx
            else:
                lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
                trajectory.end = idx

            trajectory.lmk_lst.append(lmk)

            # crop the face
            ret_dct = crop_image(
                frame_rgb,  # ndarray
                lmk,  # 106x2 or Nx2
                dsize=crop_cfg.dsize,
                scale=crop_cfg.scale,
                vx_ratio=crop_cfg.vx_ratio,
                vy_ratio=crop_cfg.vy_ratio,
                flag_do_rot=crop_cfg.flag_do_rot,
            )
            lmk = self.landmark_runner.run(frame_rgb, lmk)
            ret_dct["lmk_crop"] = lmk

            # update a 256x256 version for network input
            ret_dct["img_crop_256x256"] = cv2.resize(ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA)
            ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / crop_cfg.dsize

            trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop_256x256"])
            trajectory.lmk_crop_lst.append(ret_dct["lmk_crop_256x256"])
            trajectory.M_c2o_lst.append(ret_dct['M_c2o'])

        return {
            "frame_crop_lst": trajectory.frame_rgb_crop_lst,
            "lmk_crop_lst": trajectory.lmk_crop_lst,
            "M_c2o_lst": trajectory.M_c2o_lst,
        }

    def crop_driving_video(self, driving_rgb_lst, **kwargs):
        """Tracking based landmarks/alignment and cropping.

        A bounding box is derived from each tracked frame's landmarks; the
        boxes are averaged into one global box and every tracked frame is then
        cropped with that same box (no rotation, black border).

        Args:
            driving_rgb_lst: iterable of RGB frames.

        Keyword Args:
            direction: face selection rule (default "large-small").
            dsize: crop size passed to ``crop_image_by_bbox`` (default 512).

        Returns:
            dict with ``frame_crop_lst`` and ``lmk_crop_lst``.
        """
        trajectory = Trajectory()
        direction = kwargs.get("direction", "large-small")
        for idx, frame_rgb in enumerate(driving_rgb_lst):
            if idx == 0 or trajectory.start == -1:
                src_face = self.face_analysis_wrapper.get(
                    contiguous(frame_rgb[..., ::-1]),  # convert to BGR
                    flag_do_landmark_2d_106=True,
                    direction=direction,
                )
                if len(src_face) == 0:
                    log(f"No face detected in the frame #{idx}")
                    continue
                elif len(src_face) > 1:
                    log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
                src_face = src_face[0]
                lmk = src_face.landmark_2d_106
                lmk = self.landmark_runner.run(frame_rgb, lmk)
                trajectory.start, trajectory.end = idx, idx
            else:
                lmk = self.landmark_runner.run(frame_rgb, trajectory.lmk_lst[-1])
                trajectory.end = idx

            trajectory.lmk_lst.append(lmk)
            # The x offset must go in as `vx_ratio`, parallel to `scale` and
            # `vy_ratio`; the previous keyword `vx_ratio_crop_driving_video`
            # meant the configured horizontal offset was never applied.
            ret_bbox = parse_bbox_from_landmark(
                lmk,
                scale=self.crop_cfg.scale_crop_driving_video,
                vx_ratio=self.crop_cfg.vx_ratio_crop_driving_video,
                vy_ratio=self.crop_cfg.vy_ratio_crop_driving_video,
            )["bbox"]
            # Keep rows 0 and 2 of the returned bbox as a flat 4-element box.
            bbox = [
                ret_bbox[0, 0],
                ret_bbox[0, 1],
                ret_bbox[2, 0],
                ret_bbox[2, 1],
            ]  # 4,
            trajectory.bbox_lst.append(bbox)  # bbox
            trajectory.frame_rgb_lst.append(frame_rgb)

        global_bbox = average_bbox_lst(trajectory.bbox_lst)

        dsize = kwargs.get("dsize", 512)
        for frame_rgb, lmk in zip(trajectory.frame_rgb_lst, trajectory.lmk_lst):
            ret_dct = crop_image_by_bbox(
                frame_rgb,
                global_bbox,
                lmk=lmk,
                dsize=dsize,
                flag_rot=False,
                borderValue=(0, 0, 0),
            )
            trajectory.frame_rgb_crop_lst.append(ret_dct["img_crop"])
            trajectory.lmk_crop_lst.append(ret_dct["lmk_crop"])

        return {
            "frame_crop_lst": trajectory.frame_rgb_crop_lst,
            "lmk_crop_lst": trajectory.lmk_crop_lst,
        }

    def calc_lmks_from_cropped_video(self, driving_rgb_crop_lst, **kwargs):
        """Tracking based landmarks/alignment.

        Args:
            driving_rgb_crop_lst: iterable of already-cropped RGB frames.

        Keyword Args:
            direction: face selection rule (default "large-small").

        Returns:
            List with one landmark entry per frame.

        Raises:
            Exception: when no face is detected on a frame that needs detection.
        """
        trajectory = Trajectory()
        direction = kwargs.get("direction", "large-small")

        for idx, frame_rgb_crop in enumerate(driving_rgb_crop_lst):
            if idx == 0 or trajectory.start == -1:
                src_face = self.face_analysis_wrapper.get(
                    contiguous(frame_rgb_crop[..., ::-1]),  # convert to BGR
                    flag_do_landmark_2d_106=True,
                    direction=direction,
                )
                if len(src_face) == 0:
                    log(f"No face detected in the frame #{idx}")
                    raise Exception(f"No face detected in the frame #{idx}")
                elif len(src_face) > 1:
                    log(f"More than one face detected in the driving frame_{idx}, only pick one face by rule {direction}.")
                src_face = src_face[0]
                lmk = src_face.landmark_2d_106
                lmk = self.landmark_runner.run(frame_rgb_crop, lmk)
                trajectory.start, trajectory.end = idx, idx
            else:
                lmk = self.landmark_runner.run(frame_rgb_crop, trajectory.lmk_lst[-1])
                trajectory.end = idx

            trajectory.lmk_lst.append(lmk)
        return trajectory.lmk_lst
\ No newline at end of file
# coding: utf-8
# pylint: disable=wrong-import-position
"""InsightFace: A Face Analysis Toolkit."""
from __future__ import absolute_import
try:
#import mxnet as mx
import onnxruntime
except ImportError:
raise ImportError(
"Unable to import dependency onnxruntime. "
)
__version__ = '0.7.3'
from . import model_zoo
from . import utils
from . import app
from . import data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment