".github/vscode:/vscode.git/clone" did not exist on "8546dd3d72867be55c7439ecc58bc0790a0f556b"
Commit 57463d8d authored by suily's avatar suily
Browse files

init

parents
Pipeline #1918 canceled with stages
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "M74Gs_TjYl_B"
},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"### SadTalker:Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \n",
"\n",
"[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\n",
"\n",
"Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n",
"\n",
"Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n",
"\n",
"CVPR 2023\n",
"\n",
"TL;DR: A realistic and stylized talking head video generation method from a single image and audio\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "kA89DV-sKS4i"
},
"source": [
"Installation (around 5 mins)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qJ4CplXsYl_E"
},
"outputs": [],
"source": [
"### make sure that CUDA is available in Edit -> Nootbook settings -> GPU\n",
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mdq6j4E5KQAR"
},
"outputs": [],
"source": [
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\n",
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\n",
"!sudo apt install python3.8\n",
"\n",
"!sudo apt-get install python3.8-distutils\n",
"\n",
"!python --version\n",
"\n",
"!apt-get update\n",
"\n",
"!apt install software-properties-common\n",
"\n",
"!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n",
"\n",
"!apt-get install python3-pip\n",
"\n",
"print('Git clone project and install requirements...')\n",
"!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n",
"%cd SadTalker\n",
"!export PYTHONPATH=/content/SadTalker:$PYTHONPATH\n",
"!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n",
"!apt update\n",
"!apt install ffmpeg &> /dev/null\n",
"!python3.8 -m pip install -r requirements.txt"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "DddcKB_nKsnk"
},
"source": [
"Download models (1 mins)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eDw3_UN8K2xa"
},
"outputs": [],
"source": [
"print('Download pre-trained models...')\n",
"!rm -rf checkpoints\n",
"!bash scripts/download_models.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kK7DYeo7Yl_H"
},
"outputs": [],
"source": [
"# borrow from makeittalk\n",
"import ipywidgets as widgets\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n",
"img_list = glob.glob1('examples/source_image', '*.png')\n",
"img_list.sort()\n",
"img_list = [item.split('.')[0] for item in img_list]\n",
"default_head_name = widgets.Dropdown(options=img_list, value='full3')\n",
"def on_change(change):\n",
" if change['type'] == 'change' and change['name'] == 'value':\n",
" plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
" plt.axis('off')\n",
" plt.show()\n",
"default_head_name.observe(on_change)\n",
"display(default_head_name)\n",
"plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "-khNZcnGK4UK"
},
"source": [
"Animation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ToBlDusjK5sS"
},
"outputs": [],
"source": [
"# selected audio from exmaple/driven_audio\n",
"img = 'examples/source_image/{}.png'.format(default_head_name.value)\n",
"print(img)\n",
"!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n",
" --source_image {img} \\\n",
" --result_dir ./results --still --preprocess full --enhancer gfpgan"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fAjwGmKKYl_I"
},
"outputs": [],
"source": [
"# visualize code from makeittalk\n",
"from IPython.display import HTML\n",
"from base64 import b64encode\n",
"import os, sys\n",
"\n",
"# get the last from results\n",
"\n",
"results = sorted(os.listdir('./results/'))\n",
"\n",
"mp4_name = glob.glob('./results/*.mp4')[0]\n",
"\n",
"mp4 = open('{}'.format(mp4_name),'rb').read()\n",
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"\n",
"print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
"display(HTML(\"\"\"\n",
" <video width=256 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\" % data_url))\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.7"
},
"vscode": {
"interpreter": {
"hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
llvmlite==0.38.1
numpy==1.21.6
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.10.0.post2
numba==0.55.1
resampy==0.3.1
pydub==0.25.1
scipy==1.10.1
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
av
safetensors
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.10.1
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
gradio
gfpgan
av
safetensors
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.3.0
trimesh==3.9.20
gradio
gfpgan
safetensors
mkdir -p ./checkpoints

# legacy download links
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
# unzip -n ./checkpoints/hub.zip -d ./checkpoints/
#### download the models from the new links
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors -O ./checkpoints/SadTalker_V0.0.2_256.safetensors
wget -nc https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors -O ./checkpoints/SadTalker_V0.0.2_512.safetensors
# wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
# unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/
### enhancer
mkdir -p ./gfpgan/weights
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth
wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth
import os, sys
from pathlib import Path
import tempfile
import gradio as gr
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
from modules.shared import opts, OptionInfo
from modules import shared, paths, script_callbacks
import launch
import glob
from huggingface_hub import snapshot_download
def check_all_files_safetensor(current_dir):
kv = {
"SadTalker_V0.0.2_256.safetensors": "sadtalker-256",
"SadTalker_V0.0.2_512.safetensors": "sadtalker-512",
"mapping_00109-model.pth.tar" : "mapping-109" ,
"mapping_00229-model.pth.tar" : "mapping-229" ,
}
if not os.path.isdir(current_dir):
return False
dirs = os.listdir(current_dir)
for f in dirs:
if f in kv.keys():
del kv[f]
return len(kv.keys()) == 0
def check_all_files(current_dir):
kv = {
"auido2exp_00300-model.pth": "audio2exp",
"auido2pose_00140-model.pth": "audio2pose",
"epoch_20.pth": "face_recon",
"facevid2vid_00189-model.pth.tar": "face-render",
"mapping_00109-model.pth.tar" : "mapping-109" ,
"mapping_00229-model.pth.tar" : "mapping-229" ,
"wav2lip.pth": "wav2lip",
"shape_predictor_68_face_landmarks.dat": "dlib",
}
if not os.path.isdir(current_dir):
return False
dirs = os.listdir(current_dir)
for f in dirs:
if f in kv.keys():
del kv[f]
return len(kv.keys()) == 0
def download_model(local_dir='./checkpoints'):
REPO_ID = 'vinthony/SadTalker'
snapshot_download(repo_id=REPO_ID, local_dir=local_dir, local_dir_use_symlinks=False)
def get_source_image(image):
return image
def get_img_from_txt2img(x):
talker_path = Path(paths.script_path) / "outputs"
imgs_from_txt_dir = str(talker_path / "txt2img-images/")
imgs = glob.glob(imgs_from_txt_dir+'/*/*.png')
imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_txt_dir, x)))
img_from_txt_path = os.path.join(imgs_from_txt_dir, imgs[-1])
return img_from_txt_path, img_from_txt_path
def get_img_from_img2img(x):
talker_path = Path(paths.script_path) / "outputs"
imgs_from_img_dir = str(talker_path / "img2img-images/")
imgs = glob.glob(imgs_from_img_dir+'/*/*.png')
imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_img_dir, x)))
img_from_img_path = os.path.join(imgs_from_img_dir, imgs[-1])
return img_from_img_path, img_from_img_path
def get_default_checkpoint_path():
# check the path of models/checkpoints and extensions/
checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker"
extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints"
if check_all_files_safetensor(checkpoint_path):
# print('founding sadtalker checkpoint in ' + str(checkpoint_path))
return checkpoint_path
if check_all_files_safetensor(extension_checkpoint_path):
# print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path))
return extension_checkpoint_path
if check_all_files(checkpoint_path):
# print('founding sadtalker checkpoint in ' + str(checkpoint_path))
return checkpoint_path
if check_all_files(extension_checkpoint_path):
# print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path))
return extension_checkpoint_path
return None
def install():
kv = {
"face_alignment": "face-alignment==1.3.5",
"imageio": "imageio==2.19.3",
"imageio_ffmpeg": "imageio-ffmpeg==0.4.7",
"librosa":"librosa==0.8.0",
"pydub":"pydub==0.25.1",
"scipy":"scipy==1.8.1",
"tqdm": "tqdm",
"yacs":"yacs==0.1.8",
"yaml": "pyyaml",
"av":"av",
"gfpgan": "gfpgan",
}
# # dlib is not necessary currently
# if 'darwin' in sys.platform:
# kv['dlib'] = "dlib"
# else:
# kv['dlib'] = 'dlib-bin'
# #### we need to have a newer version of imageio for our method.
# launch.run_pip("install imageio==2.19.3", "requirements for SadTalker")
for k,v in kv.items():
if not launch.is_installed(k):
print(k, launch.is_installed(k))
launch.run_pip("install "+ v, "requirements for SadTalker")
if os.getenv('SADTALKER_CHECKPOINTS'):
print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS'))
elif get_default_checkpoint_path() is not None:
os.environ['SADTALKER_CHECKPOINTS'] = str(get_default_checkpoint_path())
else:
print(
"""
SadTalker will not automatically download all the checkpoint files from Hugging Face, since that can take a long time.
Please manually set SADTALKER_CHECKPOINTS in `webui_user.bat` (Windows) or `webui_user.sh` (Linux).
"""
)
# python = sys.executable
# launch.run(f'"{python}" -m pip uninstall -y huggingface_hub', live=True)
# launch.run(f'"{python}" -m pip install --upgrade git+https://github.com/huggingface/huggingface_hub@main', live=True)
# ### run the scripts to download models to the correct location.
# # print('download models for SadTalker')
# # launch.run("cd " + paths.script_path+"/extensions/SadTalker && bash ./scripts/download_models.sh", live=True)
# # print('SadTalker is successfully installed!')
# download_model(paths.script_path+'/extensions/SadTalker/checkpoints')
def on_ui_tabs():
install()
sys.path.extend([paths.script_path+'/extensions/SadTalker'])
repo_dir = paths.script_path+'/extensions/SadTalker/'
result_dir = opts.sadtalker_result_dir
os.makedirs(result_dir, exist_ok=True)
from app_sadtalker import sadtalker_demo
if os.getenv('SADTALKER_CHECKPOINTS'):
checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
else:
checkpoint_path = repo_dir+'checkpoints/'
audio_to_video = sadtalker_demo(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', warpfn = wrap_queued_call)
return [(audio_to_video, "SadTalker", "extension")]
def on_ui_settings():
talker_path = Path(paths.script_path) / "outputs"
section = ('extension', "SadTalker")
opts.add_option("sadtalker_result_dir", OptionInfo(str(talker_path / "SadTalker/"), "Path to save results of sadtalker", section=section))
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)
# ### some test commands to run before commit.
# python inference.py --preprocess crop --size 256
# python inference.py --preprocess crop --size 512
# python inference.py --preprocess extcrop --size 256
# python inference.py --preprocess extcrop --size 512
# python inference.py --preprocess resize --size 256
# python inference.py --preprocess resize --size 512
# python inference.py --preprocess full --size 256
# python inference.py --preprocess full --size 512
# python inference.py --preprocess extfull --size 256
# python inference.py --preprocess extfull --size 512
python inference.py --preprocess full --size 256 --enhancer gfpgan
python inference.py --preprocess full --size 512 --enhancer gfpgan
python inference.py --preprocess full --size 256 --enhancer gfpgan --still
python inference.py --preprocess full --size 512 --enhancer gfpgan --still
from tqdm import tqdm
import torch
from torch import nn
class Audio2Exp(nn.Module):
def __init__(self, netG, cfg, device, prepare_training_loss=False):
super(Audio2Exp, self).__init__()
self.cfg = cfg
self.device = device
self.netG = netG.to(device)
def test(self, batch):
mel_input = batch['indiv_mels'] # bs T 1 80 16
bs = mel_input.shape[0]
T = mel_input.shape[1]
exp_coeff_pred = []
for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
current_mel_input = mel_input[:,i:i+10]
#ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
ref = batch['ref'][:, :, :64][:, i:i+10]
ratio = batch['ratio_gt'][:, i:i+10] #bs T
audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
exp_coeff_pred += [curr_exp_coeff_pred]
# BS x T x 64
results_dict = {
'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
}
return results_dict
import torch
import torch.nn.functional as F
from torch import nn
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
self.use_act = use_act
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
if self.use_act:
return self.act(out)
else:
return out
class SimpleWrapperV2(nn.Module):
def __init__(self) -> None:
super().__init__()
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
)
#### load the pre-trained audio_encoder
#self.audio_encoder = self.audio_encoder.to(device)
'''
wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
state_dict = self.audio_encoder.state_dict()
for k,v in wav2lip_state_dict.items():
if 'audio_encoder' in k:
print('init:', k)
state_dict[k.replace('module.audio_encoder.', '')] = v
self.audio_encoder.load_state_dict(state_dict)
'''
self.mapping1 = nn.Linear(512+64+1, 64)
#self.mapping2 = nn.Linear(30, 64)
#nn.init.constant_(self.mapping1.weight, 0.)
nn.init.constant_(self.mapping1.bias, 0.)
def forward(self, x, ref, ratio):
x = self.audio_encoder(x).view(x.size(0), -1)
ref_reshape = ref.reshape(x.size(0), -1)
ratio = ratio.reshape(x.size(0), -1)
y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
return out
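# --- Illustrative usage sketch (not part of the repository) ---
# A minimal shape check, assuming only the two classes above: it wires
# SimpleWrapperV2 into Audio2Exp and runs test() on random tensors carrying
# the batch keys that method expects. Weights and inputs are dummy values.
import torch

netG = SimpleWrapperV2()
audio2exp = Audio2Exp(netG, cfg=None, device='cpu')   # cfg is only stored, not used by test()

bs, T = 1, 25                                         # T need not be a multiple of 10
batch = {
    'indiv_mels': torch.randn(bs, T, 1, 80, 16),      # per-frame mel chunks
    'ref': torch.randn(bs, T, 70),                    # first 64 dims = reference expression coeffs
    'ratio_gt': torch.rand(bs, T, 1),                 # eye-blink ratio per frame
}
with torch.no_grad():
    out = audio2exp.test(batch)
print(out['exp_coeff_pred'].shape)                    # expected: torch.Size([1, 25, 64])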
import torch
from torch import nn
from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder
class Audio2Pose(nn.Module):
def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
super().__init__()
self.cfg = cfg
self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
self.device = device
self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
self.audio_encoder.eval()
for param in self.audio_encoder.parameters():
param.requires_grad = False
self.netG = CVAE(cfg)
self.netD_motion = PoseSequenceDiscriminator(cfg)
def forward(self, x):
batch = {}
coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
batch['class'] = x['class'].squeeze(0).cuda() # bs
indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
# forward
audio_emb_list = []
audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
batch['audio_emb'] = audio_emb
batch = self.netG(batch)
pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
batch['pose_pred'] = pose_pred
batch['pose_gt'] = pose_gt
return batch
def test(self, x):
batch = {}
ref = x['ref'] #bs 1 70
batch['ref'] = x['ref'][:,0,-6:]
batch['class'] = x['class']
bs = ref.shape[0]
indiv_mels= x['indiv_mels'] # bs T 1 80 16
indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
num_frames = x['num_frames']
num_frames = int(num_frames) - 1
#
div = num_frames//self.seq_len
re = num_frames%self.seq_len
audio_emb_list = []
pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
device=batch['ref'].device)]
for i in range(div):
z = torch.randn(bs, self.latent_dim).to(ref.device)
batch['z'] = z
audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
batch['audio_emb'] = audio_emb
batch = self.netG.test(batch)
pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
if re != 0:
z = torch.randn(bs, self.latent_dim).to(ref.device)
batch['z'] = z
audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
if audio_emb.shape[1] != self.seq_len:
pad_dim = self.seq_len-audio_emb.shape[1]
pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
batch['audio_emb'] = audio_emb
batch = self.netG.test(batch)
pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
batch['pose_motion_pred'] = pose_motion_pred
pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
batch['pose_pred'] = pose_pred
return batch
import torch
from torch import nn
from torch.nn import functional as F
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
return self.act(out)
class AudioEncoder(nn.Module):
def __init__(self, wav2lip_checkpoint, device):
super(AudioEncoder, self).__init__()
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
#### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
# wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
# state_dict = self.audio_encoder.state_dict()
# for k,v in wav2lip_state_dict.items():
# if 'audio_encoder' in k:
# state_dict[k.replace('module.audio_encoder.', '')] = v
# self.audio_encoder.load_state_dict(state_dict)
def forward(self, audio_sequences):
# audio_sequences = (B, T, 1, 80, 16)
B = audio_sequences.size(0)
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
dim = audio_embedding.shape[1]
audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
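# --- Illustrative shape check (not part of the repository) ---
# The wav2lip-style AudioEncoder maps (B, T, 1, 80, 16) mel chunks to one
# 512-d embedding per frame. Since the checkpoint loading above is commented
# out, the weights here stay random; this only confirms tensor shapes.
import torch

enc = AudioEncoder(wav2lip_checkpoint=None, device='cpu')
mels = torch.randn(2, 32, 1, 80, 16)           # (B, T, 1, 80, 16)
with torch.no_grad():
    emb = enc(mels)
print(emb.shape)                               # expected: torch.Size([2, 32, 512])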
import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet
def class2onehot(idx, class_num):
assert torch.max(idx).item() < class_num
onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
onehot.scatter_(1, idx, 1)
return onehot
class CVAE(nn.Module):
def __init__(self, cfg):
super().__init__()
encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
latent_size = cfg.MODEL.CVAE.LATENT_SIZE
num_classes = cfg.DATASET.NUM_CLASSES
audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
seq_len = cfg.MODEL.CVAE.SEQ_LEN
self.latent_size = latent_size
self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len)
self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, batch):
batch = self.encoder(batch)
mu = batch['mu']
logvar = batch['logvar']
z = self.reparameterize(mu, logvar)
batch['z'] = z
return self.decoder(batch)
def test(self, batch):
'''
class_id = batch['class']
z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
batch['z'] = z
'''
return self.decoder(batch)
class ENCODER(nn.Module):
def __init__(self, layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len):
super().__init__()
self.resunet = ResUnet()
self.num_classes = num_classes
self.seq_len = seq_len
self.MLP = nn.Sequential()
layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
self.MLP.add_module(
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
def forward(self, batch):
class_id = batch['class']
pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
ref = batch['ref'] #bs 6
bs = pose_motion_gt.shape[0]
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
#pose encode
pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
#audio mapping
# print(audio_in.shape)  # debug
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
audio_out = audio_out.reshape(bs, -1)
class_bias = self.classbias[class_id] #bs latent_size
x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
x_out = self.MLP(x_in)
mu = self.linear_means(x_out)
logvar = self.linear_logvar(x_out) #bs latent_size
batch.update({'mu':mu, 'logvar':logvar})
return batch
class DECODER(nn.Module):
def __init__(self, layer_sizes, latent_size, num_classes,
audio_emb_in_size, audio_emb_out_size, seq_len):
super().__init__()
self.resunet = ResUnet()
self.num_classes = num_classes
self.seq_len = seq_len
self.MLP = nn.Sequential()
input_size = latent_size + seq_len*audio_emb_out_size + 6
for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
self.MLP.add_module(
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
if i+1 < len(layer_sizes):
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
else:
self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
self.pose_linear = nn.Linear(6, 6)
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
def forward(self, batch):
z = batch['z'] #bs latent_size
bs = z.shape[0]
class_id = batch['class']
ref = batch['ref'] #bs 6
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
#print('audio_in: ', audio_in[:, :, :10])
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
#print('audio_out: ', audio_out[:, :, :10])
audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
class_bias = self.classbias[class_id] #bs latent_size
z = z + class_bias
x_in = torch.cat([ref, z, audio_out], dim=-1)
x_out = self.MLP(x_in) # bs layer_sizes[-1]
x_out = x_out.reshape((bs, self.seq_len, -1))
#print('x_out: ', x_out)
pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
batch.update({'pose_motion_pred':pose_motion_pred})
return batch
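# --- Illustrative decoder-only sampling sketch (not part of the repository) ---
# CVAE.test() skips the encoder and decodes a latent into a pose-motion
# sequence. The yacs config below is hand-built and mirrors the audio2pose
# config later in this commit; the point is only the batch keys and shapes.
import torch
from yacs.config import CfgNode as CN

cfg = CN()
cfg.DATASET = CN(); cfg.DATASET.NUM_CLASSES = 46
cfg.MODEL = CN(); cfg.MODEL.CVAE = CN()
cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE = 512
cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE = 6
cfg.MODEL.CVAE.SEQ_LEN = 32
cfg.MODEL.CVAE.LATENT_SIZE = 64
cfg.MODEL.CVAE.ENCODER_LAYER_SIZES = [192, 128]
cfg.MODEL.CVAE.DECODER_LAYER_SIZES = [128, 192]

model = CVAE(cfg)
batch = {
    'z': torch.randn(2, 64),                      # latent sample
    'class': torch.zeros(2, dtype=torch.long),    # speaker/style id
    'ref': torch.randn(2, 6),                     # first-frame pose coefficients
    'audio_emb': torch.randn(2, 32, 512),         # per-frame audio embeddings
}
with torch.no_grad():
    out = model.test(batch)
print(out['pose_motion_pred'].shape)              # expected: torch.Size([2, 32, 6])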
import torch
import torch.nn.functional as F
from torch import nn
class ConvNormRelu(nn.Module):
def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
super().__init__()
if kernel_size is None:
if downsample:
kernel_size, stride, padding = 4, 2, 1
else:
kernel_size, stride, padding = 3, 1, 1
if conv_type == '2d':
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
)
if norm == 'BN':
self.norm = nn.BatchNorm2d(out_channels)
elif norm == 'IN':
self.norm = nn.InstanceNorm2d(out_channels)
else:
raise NotImplementedError
elif conv_type == '1d':
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
)
if norm == 'BN':
self.norm = nn.BatchNorm1d(out_channels)
elif norm == 'IN':
self.norm = nn.InstanceNorm1d(out_channels)
else:
raise NotImplementedError
nn.init.kaiming_normal_(self.conv.weight)
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
if isinstance(self.norm, nn.InstanceNorm1d):
x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
else:
x = self.norm(x)
x = self.act(x)
return x
class PoseSequenceDiscriminator(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
self.seq = nn.Sequential(
ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
)
def forward(self, x):
x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
x = self.seq(x)
x = x.squeeze(1)
return x
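# --- Illustrative shape check (not part of the repository) ---
# PoseSequenceDiscriminator only reads the two DISCRIMINATOR fields from the
# config, so a minimal yacs node is built by hand here. A 32-frame, 6-d pose
# sequence is downsampled twice along time, giving 8 scores per sample.
import torch
from yacs.config import CfgNode as CN

cfg = CN(); cfg.MODEL = CN(); cfg.MODEL.DISCRIMINATOR = CN()
cfg.MODEL.DISCRIMINATOR.LEAKY_RELU = False
cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS = 6

netD = PoseSequenceDiscriminator(cfg)
poses = torch.randn(2, 32, 6)                  # (B, seq_len, 6)
scores = netD(poses)
print(scores.shape)                            # expected: torch.Size([2, 8])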
import torch.nn as nn
import torch
class ResidualConv(nn.Module):
def __init__(self, input_dim, output_dim, stride, padding):
super(ResidualConv, self).__init__()
self.conv_block = nn.Sequential(
nn.BatchNorm2d(input_dim),
nn.ReLU(),
nn.Conv2d(
input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
),
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
)
self.conv_skip = nn.Sequential(
nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
nn.BatchNorm2d(output_dim),
)
def forward(self, x):
return self.conv_block(x) + self.conv_skip(x)
class Upsample(nn.Module):
def __init__(self, input_dim, output_dim, kernel, stride):
super(Upsample, self).__init__()
self.upsample = nn.ConvTranspose2d(
input_dim, output_dim, kernel_size=kernel, stride=stride
)
def forward(self, x):
return self.upsample(x)
class Squeeze_Excite_Block(nn.Module):
def __init__(self, channel, reduction=16):
super(Squeeze_Excite_Block, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class ASPP(nn.Module):
def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
super(ASPP, self).__init__()
self.aspp_block1 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.aspp_block2 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.aspp_block3 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
self._init_weights()
def forward(self, x):
x1 = self.aspp_block1(x)
x2 = self.aspp_block2(x)
x3 = self.aspp_block3(x)
out = torch.cat([x1, x2, x3], dim=1)
return self.output(out)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class Upsample_(nn.Module):
def __init__(self, scale=2):
super(Upsample_, self).__init__()
self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
def forward(self, x):
return self.upsample(x)
class AttentionBlock(nn.Module):
def __init__(self, input_encoder, input_decoder, output_dim):
super(AttentionBlock, self).__init__()
self.conv_encoder = nn.Sequential(
nn.BatchNorm2d(input_encoder),
nn.ReLU(),
nn.Conv2d(input_encoder, output_dim, 3, padding=1),
nn.MaxPool2d(2, 2),
)
self.conv_decoder = nn.Sequential(
nn.BatchNorm2d(input_decoder),
nn.ReLU(),
nn.Conv2d(input_decoder, output_dim, 3, padding=1),
)
self.conv_attn = nn.Sequential(
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, 1, 1),
)
def forward(self, x1, x2):
out = self.conv_encoder(x1) + self.conv_decoder(x2)
out = self.conv_attn(out)
return out * x2
import torch
import torch.nn as nn
from src.audio2pose_models.networks import ResidualConv, Upsample
class ResUnet(nn.Module):
def __init__(self, channel=1, filters=[32, 64, 128, 256]):
super(ResUnet, self).__init__()
self.input_layer = nn.Sequential(
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
nn.BatchNorm2d(filters[0]),
nn.ReLU(),
nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
)
self.input_skip = nn.Sequential(
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
)
self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
self.output_layer = nn.Sequential(
nn.Conv2d(filters[0], 1, 1, 1),
nn.Sigmoid(),
)
def forward(self, x):
# Encode
x1 = self.input_layer(x) + self.input_skip(x)
x2 = self.residual_conv_1(x1)
x3 = self.residual_conv_2(x2)
# Bridge
x4 = self.bridge(x3)
# Decode
x4 = self.upsample_1(x4)
x5 = torch.cat([x4, x3], dim=1)
x6 = self.up_residual_conv1(x5)
x6 = self.upsample_2(x6)
x7 = torch.cat([x6, x2], dim=1)
x8 = self.up_residual_conv2(x7)
x8 = self.upsample_3(x8)
x9 = torch.cat([x8, x1], dim=1)
x10 = self.up_residual_conv3(x9)
output = self.output_layer(x10)
return output
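# --- Illustrative shape check (not part of the repository) ---
# ResUnet downsamples only the sequence axis (stride (2,1)) three times and
# then upsamples it back, so a (bs, 1, seq_len, 6) pose-motion grid keeps its
# shape whenever seq_len is divisible by 8 (e.g. SEQ_LEN = 32 in the configs).
import torch

net = ResUnet(channel=1)
pose = torch.randn(4, 1, 32, 6)                # (bs, 1, seq_len, 6)
with torch.no_grad():
    out = net(pose)
print(out.shape)                               # expected: torch.Size([4, 1, 32, 6])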
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
  TRAIN_BATCH_SIZE: 32
  EVAL_BATCH_SIZE: 32
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
  LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True
  NUM_REPEATS: 2
  T: 40

MODEL:
  FRAMEWORK: V2
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 128
    SEQ_LEN: 32
    LATENT_SIZE: 256
    ENCODER_LAYER_SIZES: [192, 1024]
    DECODER_LAYER_SIZES: [1024, 192]

TRAIN:
  MAX_EPOCH: 300
  GENERATOR:
    LR: 2.0e-5
  DISCRIMINATOR:
    LR: 1.0e-5
  LOSS:
    W_FEAT: 0
    W_COEFF_EXP: 2
    W_LM: 1.0e-2
    W_LM_MOUTH: 0
    W_REG: 0
    W_SYNC: 0
    W_COLOR: 0
    W_EXPRESSION: 0
    W_LIPREADING: 0.01
    W_LIPREADING_VV: 0
    W_EYE_BLINK: 4

TAG:
  NAME: small_dataset
DATASET:
  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
  TRAIN_BATCH_SIZE: 64
  EVAL_BATCH_SIZE: 1
  EXP: True
  EXP_DIM: 64
  FRAME_LEN: 32
  COEFF_LEN: 73
  NUM_CLASSES: 46
  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
  DEBUG: True

MODEL:
  AUDIOENCODER:
    LEAKY_RELU: True
    NORM: 'IN'
  DISCRIMINATOR:
    LEAKY_RELU: False
    INPUT_CHANNELS: 6
  CVAE:
    AUDIO_EMB_IN_SIZE: 512
    AUDIO_EMB_OUT_SIZE: 6
    SEQ_LEN: 32
    LATENT_SIZE: 64
    ENCODER_LAYER_SIZES: [192, 128]
    DECODER_LAYER_SIZES: [128, 192]

TRAIN:
  MAX_EPOCH: 150
  GENERATOR:
    LR: 1.0e-4
  DISCRIMINATOR:
    LR: 1.0e-4
  LOSS:
    LAMBDA_REG: 1
    LAMBDA_LANDMARKS: 0
    LAMBDA_VERTICES: 0
    LAMBDA_GAN_MOTION: 0.7
    LAMBDA_GAN_COEFF: 0
    LAMBDA_KL: 1

TAG:
  NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
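# --- Illustrative config-loading sketch (not part of the repository) ---
# The audio2pose config above is consumed as a yacs CfgNode. Assuming it is
# saved as audio2pose.yaml (the filename and path are assumptions), it can be
# loaded and handed to Audio2Pose roughly like this; the wav2lip checkpoint
# can be omitted here because its loading is commented out in AudioEncoder
# above, so the audio encoder simply keeps random weights.
from yacs.config import CfgNode as CN
from src.audio2pose_models.audio2pose import Audio2Pose

with open('audio2pose.yaml') as f:             # assumed location of the YAML above
    cfg_pose = CN.load_cfg(f)
cfg_pose.freeze()

model = Audio2Pose(cfg_pose, wav2lip_checkpoint=None, device='cpu')
print(cfg_pose.MODEL.CVAE.SEQ_LEN, cfg_pose.MODEL.CVAE.LATENT_SIZE)   # 32 64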
model_params:
  common_params:
    num_kp: 15
    image_channel: 3
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25         # 0.25
    num_blocks: 5
    reshape_channel: 16384     # 16384 = 1024 * 16
    reshape_depth: 16
  he_estimator_params:
    block_expansion: 64
    max_features: 2048
    num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16          # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
    coeff_nc: 70
    descriptor_nc: 1024
    layer: 3
    num_kp: 15
    num_bins: 66
model_params:
  common_params:
    num_kp: 15
    image_channel: 3
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25         # 0.25
    num_blocks: 5
    reshape_channel: 16384     # 16384 = 1024 * 16
    reshape_depth: 16
  he_estimator_params:
    block_expansion: 64
    max_features: 2048
    num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16          # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
    coeff_nc: 73
    descriptor_nc: 1024
    layer: 3
    num_kp: 15
    num_bins: 66