*/__pycache__/
*.pyc
.vscode/
data/
*.mp4
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py37-latest
COPY . /tmp/
WORKDIR /tmp
RUN pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple
MIT License
Copyright (c) 2023 Meta-Portrait
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# MetaPortrait
## Paper
MetaPortrait: Identity-Preserving Talking Head Generation with Fast Personalized Adaptation
https://browse.arxiv.org/pdf/2212.08062.pdf
## Model Architecture
Overall pipeline:
![Alt text](imgs/image.png)
(a) $I_s$ is the input source image and $I_d$ is the driving image (a frame of the video); $I_s^{ldmk}$ and $I_d^{ldmk}$ are their dense landmarks. (b) $x_{in} = \mathrm{Concat}(I_s, I_s^{ldmk}, I_d^{ldmk})$, i.e. the input $I_s$ from stage (a) together with the two landmark outputs, and $E_w$ is a CNN encoder (a toy sketch of this input assembly follows the figures below). (c) $E_{id}$ is a pretrained face recognition model, FiLM stands for Feature-wise Linear Modulation, and AdaIN is a style-transfer normalization method.
warping network
![Alt text](imgs/image-1.png)
$F_r$
![Alt text](imgs/image-2.png)
$F_{3d}$
![Alt text](imgs/image-3.png)
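A toy sketch of the stage-(b) input assembly (illustrative only; the 9-channel concatenation matches the input width of `LadderEncoder` in this repo, but the encoder below is a stand-in, not the repo's module):
```python
# Illustrative sketch: build x_in = Concat(I_s, I_s_ldmk, I_d_ldmk) and encode it with a small CNN.
import torch
import torch.nn as nn

I_s = torch.rand(1, 3, 256, 256)       # source image
I_s_ldmk = torch.rand(1, 3, 256, 256)  # dense-landmark image of the source
I_d_ldmk = torch.rand(1, 3, 256, 256)  # dense-landmark image of the driving frame

x_in = torch.cat([I_s, I_s_ldmk, I_d_ldmk], dim=1)  # 9 channels = 3 + 2 * 3

E_w = nn.Sequential(                                 # toy CNN encoder standing in for E_w
    nn.Conv2d(9, 32, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.LeakyReLU(0.2),
)
print(E_w(x_in).shape)  # -> torch.Size([1, 64, 64, 64])
```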
## Algorithm
Purpose: the algorithm generates one-shot (single-image) talking-head videos.
Principles:
1. Dense landmarks give a geometry-aware estimate of the warping field, and the source identity is adaptively fused so that key characteristics of the portrait are better preserved.
2. Meta-learning speeds up model fine-tuning (personalized adaptation); see the sketch after this list.
![Alt text](imgs/image-4.png)
3. A temporally consistent super-resolution network raises the output resolution.
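A minimal, illustrative sketch of the meta-learning idea in item 2 (a Reptile-style inner/outer loop; this is not the repo's meta-training code, and `model` / `sample_person_batch` are hypothetical placeholders):
```python
# Reptile-style meta step: adapt a copy of the model to one identity, then move
# the meta-weights toward the adapted weights so later personalization is fast.
import copy
import torch
import torch.nn as nn

def meta_step(model, sample_person_batch, inner_lr=2e-4, meta_step_size=0.1, inner_steps=4):
    loss_fn = nn.L1Loss()
    fast = copy.deepcopy(model)                              # personalized copy of the meta-model
    opt = torch.optim.Adam(fast.parameters(), lr=inner_lr)
    for _ in range(inner_steps):                             # inner loop: adapt to one person
        source, driving = sample_person_batch()              # frames of a single identity
        opt.zero_grad()
        loss = loss_fn(fast(source), driving)
        loss.backward()
        opt.step()
    with torch.no_grad():                                    # outer loop: meta update
        for p_meta, p_fast in zip(model.parameters(), fast.parameters()):
            p_meta.add_(meta_step_size * (p_fast - p_meta))
    return loss.item()
```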
## Environment Setup
### Docker (Method 1)
docker build --no-cache -t metaportrait:latest .
docker run --rm --shm-size 10g --network=host --name=metaportrait --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -it <Image ID> bash
cd sr_model/Basicsr
pip uninstall basicsr
python setup.py develop
pip install urllib3==1.26.15
# If installing the environment via the Dockerfile takes too long, comment out the pip install inside it and install the Python packages after the container starts: pip install -r requirements.txt
### Docker (Method 2)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py37-latest
docker run --rm --shm-size 10g --network=host --name=metaportrait --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to the project>:/home/ -it <Image ID> bash
pip install -r requirements.txt
cd sr_model/Basicsr
pip uninstall basicsr
python setup.py develop
pip install urllib3==1.26.15
### Anaconda (Method 3)
1. The special deep-learning libraries required by this project for DCU GPUs can be downloaded from the HPC developer community:
https://developer.hpccube.com/tool/
DTK driver: dtk23.04
python:python3.7
torch:1.13.1
torchvision:0.14.1
torchaudio:0.13.1
deepspeed:0.9.2
apex:0.1
2. Create and activate the virtual environment
conda create -n meta_portrait_base python=3.7
conda activate meta_portrait_base
pip install -r requirements.txt
cd sr_model/Basicsr
pip uninstall basicsr
python setup.py develop
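Optionally verify the environment before moving on (a minimal check):
```python
# Quick sanity check: torch imports and the accelerator is visible.
import torch
print(torch.__version__, torch.cuda.is_available())
```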
## Dataset
Download link:
https://drive.google.com/file/d/166eNbabM6TeJVy7hxol2gL1kUGKHi3Do/view?usp=share_link
```
base_model
data
├── 0
│ ├── imgs
│ │ ├── 00000000.png
│ │ ├── ...
│ ├── ldmks
│ │ ├── 00000000_ldmk.npy
│ │ ├── ...
│ └── thetas
│ ├── 00000000_theta.npy
│ ├── ...
├── src_0_id.npy # identity embedding; can be obtained with a face recognition model
├── src_0_ldmk.npy # landmarks
├── src_0.png
├── src_0_theta.npy # transformation matrix that aligns the face to the image center
└── src_map_dict.pkl
```
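A minimal sketch for inspecting the per-source files above before running inference (paths assume the archive was unpacked to `data/`; array shapes are not guaranteed):
```python
# Load the source-side files shipped with the example data and print basic info.
import pickle
import numpy as np

src_id = np.load("data/src_0_id.npy")        # identity embedding (from a face-recognition model)
src_ldmk = np.load("data/src_0_ldmk.npy")    # dense landmarks of the source image
src_theta = np.load("data/src_0_theta.npy")  # transform aligning the face to the image center
with open("data/src_map_dict.pkl", "rb") as f:
    src_map = pickle.load(f)

print(src_id.shape, src_ldmk.shape, src_theta.shape, type(src_map))
```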
Download links:
(model) https://github.com/Meta-Portrait/MetaPortrait/releases/download/v0.0.1/temporal_gfpgan.pth
(model) https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
(dataset) https://hkustconnect-my.sharepoint.com/personal/cqiaa_connect_ust_hk/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fcqiaa%5Fconnect%5Fust%5Fhk%2FDocuments%2Ftalking%20head%2Frelease%2Fdata%2FHDTF%5Fwarprefine&ga=1
```
sr_model
pretrained_ckpt
├── temporal_gfpgan.pth
├── GFPGANv1.3.pth
...
data
├── HDTF_warprefine
│ ├── gt
│ ├── lq
│ ├── ...
```
## Training
1. Train the warping network
cd base_model
CUDA_VISIBLE_DEVICES=0 python main.py --config config/meta_portrait_256_pretrain_warp.yaml --fp16 --stage Warp --task Pretrain
2. Jointly train the warping network and the refinement network; first update `warp_ckpt` in config/meta_portrait_256_pretrain_full.yaml
CUDA_VISIBLE_DEVICES=0 python main.py --config config/meta_portrait_256_pretrain_full.yaml --fp16 --stage Full --task Pretrain
3. Train the SR model
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 Experimental_root/train.py -opt options/train/train_sr_hdtf.yml --launcher pytorch
## Inference
1. Generate 256x256 images
cd base_model
CUDA_VISIBLE_DEVICES=0 python inference.py --save_dir result --config config/meta_portrait_256_eval.yaml --ckpt checkpoint/ckpt_base.pth.tar
2. Upscale the generated images (super-resolution)
cd sr_model
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 Experimental_root/test.py -opt options/test/same_id_demo.yml --launcher pytorch
## Results
1.
![Alt text](imgs/2500.png)
2.
![Alt text](imgs/image-9.png)
## Metrics
|PSNR|LPIPS|$E_{warp}$|
|:---:|:---:|:---:|
|26.916|0.1514|0.0244|
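A hedged sketch for computing PSNR on a single frame (not the repo's evaluation script; example paths follow the inference output layout above, and LPIPS would additionally need the `lpips` package):
```python
# PSNR between one generated frame and its ground-truth counterpart.
import cv2
import numpy as np

def psnr(img_a, img_b):
    mse = np.mean((img_a.astype(np.float64) - img_b.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else 10 * np.log10(255.0 ** 2 / mse)

pred = cv2.imread("result/output_256/00000000.png")  # example path from the inference step
gt = cv2.imread("data/0/imgs/00000000.png")          # corresponding ground-truth frame
print(psnr(pred, gt))
```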
## Application Scenarios
### Algorithm Category
`Computer Vision`
### Key Application Industries
`Face recognition, anti-fraud, beauty effects`
## Source Repository and Issue Feedback
https://developer.hpccube.com/codes/modelzoo/metaportrait_pytorch
## References
https://github.com/Meta-Portrait/MetaPortrait
https://github.com/Meta-Portrait/MetaPortrait/issues/4
https://github.com/1adrianb/face-alignment
https://datahacker.rs/010-how-to-align-faces-with-opencv-in-python/
# MetaPortrait
![Teaser](./docs/Teaser.png)
This repo is the official implementation of "MetaPortrait: Identity-Preserving Talking Head Generation with Fast Personalized Adaptation" (CVPR 2023).
By [Bowen Zhang](http://home.ustc.edu.cn/~zhangbowen)\*, [Chenyang Qi](https://chenyangqiqi.github.io)\*, [Pan Zhang](https://panzhang0212.github.io), [Bo Zhang](https://bo-zhang.me/), [HsiangTao Wu](https://dl.acm.org/profile/81487650131), [Dong Chen](http://www.dongchen.pro/), [Qifeng Chen](https://cqf.io), [Yong Wang](http://en.auto.ustc.edu.cn/2021/0616/c26828a513186/page.htm) and [Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/).
[Paper](https://arxiv.org/abs/2212.08062) | [Project Page](https://meta-portrait.github.io/) | [Code](https://github.com/Meta-Portrait/MetaPortrait)
## Abstract
> In this work, we propose an ID-preserving talking head generation framework, which advances previous methods in two aspects. First, as opposed to interpolating from sparse flow, we claim that dense landmarks are crucial to achieving accurate geometry-aware flow fields. Second, inspired by face-swapping methods, we adaptively fuse the source identity during synthesis, so that the network better preserves the key characteristics of the image portrait. Although the proposed model surpasses prior generation fidelity on established benchmarks, to further make the talking head generation qualified for real usage, personalized fine-tuning is usually needed. However, this process is rather computationally demanding that is unaffordable to standard users. To solve this, we propose a fast adaptation model using a meta-learning approach. The learned model can be adapted to a high-quality personalized model as fast as 30 seconds. Last but not the least, a spatial-temporal enhancement module is proposed to improve the fine details while ensuring temporal coherency. Extensive experiments prove the significant superiority of our approach over the state of the arts in both one-shot and personalized settings.
## Todo
- [x] Release the inference code of base model and temporal super-resolution model
- [x] Release the training code of base model
- [x] Release the training code of super-resolution model
## Setup Environment
```bash
git clone https://github.com/Meta-Portrait/MetaPortrait.git
cd MetaPortrait
conda env create -f environment.yml
conda activate meta_portrait_base
# if you use a GPU that only supports CUDA 11, you may reinstall the torch build for cu116
# pip uninstall torch
# pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
```
## Base Model
### Inference Base Model
Download the [checkpoint of base model](https://drive.google.com/file/d/1Kmdv3w6N_we7W7lIt6LBzqRHwwy1dBxD/view?usp=share_link) and put it to `base_model/checkpoint`. We provide [preprocessed example data for inference](https://drive.google.com/file/d/166eNbabM6TeJVy7hxol2gL1kUGKHi3Do/view?usp=share_link); you could download the data, unzip it, and put it to `data`. The directory structure should look like this:
```
data
├── 0
│ ├── imgs
│ │ ├── 00000000.png
│ │ ├── ...
│ ├── ldmks
│ │ ├── 00000000_ldmk.npy
│ │ ├── ...
│ └── thetas
│ ├── 00000000_theta.npy
│ ├── ...
├── src_0_id.npy
├── src_0_ldmk.npy
├── src_0.png
├── src_0_theta.npy
└── src_map_dict.pkl
```
You could generate results of self reconstruction on 256x256 resolution by running:
```bash
cd base_model
python inference.py --save_dir result --config config/meta_portrait_256_eval.yaml --ckpt checkpoint/ckpt_base.pth.tar
```
### Train Base Model from Scratch
Train the warping network first using the following command:
```bash
cd base_model
python main.py --config config/meta_portrait_256_pretrain_warp.yaml --fp16 --stage Warp --task Pretrain
```
Then, modify the path of `warp_ckpt` in `config/meta_portrait_256_pretrain_full.yaml` and joint train the warping network and refinement network using the following command:
```bash
python main.py --config config/meta_portrait_256_pretrain_full.yaml --fp16 --stage Full --task Pretrain
```
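If you prefer to set `warp_ckpt` programmatically instead of editing the YAML by hand, here is a small optional sketch (assumes PyYAML is installed and that the key layout matches the configs shipped in `base_model/config/`; the checkpoint path is only an example):
```python
# Point model.warp_ckpt at the checkpoint produced by the warp pretraining stage.
import yaml

cfg_path = "config/meta_portrait_256_pretrain_full.yaml"
with open(cfg_path) as f:
    conf = yaml.safe_load(f)

conf["model"]["warp_ckpt"] = "results/ckpt/meta_portrait_base/ckpt_60_2023-10-09-10-50-59.pth.tar"  # example path

with open(cfg_path, "w") as f:
    yaml.safe_dump(conf, f, sort_keys=False)
```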
### Meta Training for Faster Personalization of Base Model
You could start from the standard pretrained checkpoint and further optimize the personalized adaptation speed of the model by utilizing meta-learning using the following command:
```bash
python main.py --config config/meta_portrait_256_pretrain_meta_train.yaml --fp16 --stage Full --task Meta --remove_sn --ckpt /path/to/standard_pretrain_ckpt
```
## Temporal Super-resolution Model
Set the root path to [sr_model](sr_model)
### Data and checkpoint
Download the [checkpoint](https://github.com/Meta-Portrait/MetaPortrait/releases/download/v0.0.1/temporal_gfpgan.pth) using the bash command
```bash
cd sr_model
bash download_sr.sh
```
Unzip the package and keep the file structure as follows:
```
pretrained_ckpt
├── temporal_gfpgan.pth
├── GFPGANv1.3.pth
...
data
├── HDTF_warprefine
│ ├── gt
│ ├── lq
│ ├── ...
Basicsr
Experimental_root
options
```
### Installation Bash command
```bash
# Install a modified basicsr - https://github.com/xinntao/BasicSR
cd Basicsr
pip install -r requirements.txt
python setup.py develop
# Install facexlib - https://github.com/xinntao/facexlib
# We use face detection and face restoration helper in the facexlib package
pip install facexlib
cd ..
pip install -r requirements.txt
# python setup.py develop
```
### Quick Inference
Checkpoint for inference: `pretrained_ckpt/temporal_gfpgan.pth`
<!--
Example code to conduct face temporal super-resolution:
```bash
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 Experimental_root/test.py -opt options/test/same_id.yml --launcher pytorch
``` -->
Enhance the result from our base model without calculating the metrics:
```bash
python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 Experimental_root/test.py -opt options/test/same_id_demo.yml --launcher pytorch
```
You may adjust `nproc_per_node` to the number of GPUs on your own machine.
Finally, check the result at `results/temporal_gfpgan_same_id_temporal_super_resolution`.
### Demo training
For the results in the paper, we train on the training split of the HDTF dataset. Here we first provide demo training code that trains on the small demo dataset:
```bash
CUDA_VISIBLE_DEVICES=1 python -m torch.distributed.launch --nproc_per_node=1 --master_port=4321 Experimental_root/train.py -opt options/train/train_sr_hdtf.yml --launcher pytorch
```
The intermediate results can be checked at `/home/cqiaa/talkinghead/MetaPortrait/sr_model/experiments/train_sr_hdtf/visualization/00000001/hdtf_random`
## Citing MetaPortrait
```
@misc{zhang2022metaportrait,
title={MetaPortrait: Identity-Preserving Talking Head Generation with Fast Personalized Adaptation},
author={Bowen Zhang and Chenyang Qi and Pan Zhang and Bo Zhang and HsiangTao Wu and Dong Chen and Qifeng Chen and Yong Wang and Fang Wen},
year={2022},
eprint={2212.08062},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
## Acknowledgements
This code borrows heavily from [FOMM](https://github.com/AliaksandrSiarohin/first-order-model), thanks the authors for sharing their code and models.
## Maintenance
This is the codebase for our research work. Please open a GitHub issue for any help. If you have any questions regarding the technical details, feel free to contact [zhangbowen@mail.ustc.edu.cn](mailto:zhangbowen@mail.ustc.edu.cn) or [cqiaa@connect.ust.hk](mailto:cqiaa@connect.ust.hk).
*/__pycache__/
*.pyc
.vscode/
data/
checkpoint/
output/
result/
general:
exp_name: meta_portrait_base
random_seed: 365
dataset:
frame_shape: [256, 256, 3]
eye_enhance: True
mouth_enhance: True
ldmkimg: True
ldmk_idx: [521, 505, 338, 398, 347, 35, 191, 30, 32, 207, 630, 629, 319, 4, 541, 61, 637, 660, 638, 587, 273, 590, 269, 432, 118,327, 12, 373, 58, 619, 466, 469, 464, 308, 152, 305, 150, 411, 635, 634, 564, 250, 443, 129, 364, 322, 49, 7, 361, 105, 434, 120, 500, 186, 575, 261, 636, 74]
train_data: [personalized]
train_data_weight: [1]
personalized:
root: ../data/
crop_expand: 1.3
crop_offset_y: 0.2
static_bbox: True
model:
arch: 'SPADEID'
common:
num_channels: 3
kp_detector:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator:
block_expansion: 64
max_features: 512
with_gaze_htmap: True
with_mouth_line: True
with_ldmk_line: True
use_IN: True
ladder:
need_feat: False
use_mask: False
label_nc: 0
z_dim: 512
dense_motion_params:
label_nc: 0
ldmkimg: True
occlusion: True
block_expansion: 64
max_features: 1024
num_blocks: 5
dec_lease: 2
Lwarp: True
AdaINc: 512
discriminator:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
use_kp: False
train:
epochs: 60
batch_size: 48
dataset_repeats: 1
epoch_milestones: [45]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
scales: [1, 0.5, 0.25, 0.125]
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
id: 20
eye_enhance: 50
mouth_enhance: 50
tensorboard: True
event_save_path: ./results/events/
event_save_freq: 500
ckpt_save_path: ./results/ckpt/
ckpt_save_iter_freq: 5000
ckpt_save_freq: 1
print_freq: 1000
eval_freq: 10000
general:
exp_name: meta_portrait_meta_train
random_seed: 365
dataset:
frame_shape: [256, 256, 3]
eye_enhance: True
mouth_enhance: True
ldmkimg: True
ldmk_idx: [521, 505, 338, 398, 347, 35, 191, 30, 32, 207, 630, 629, 319, 4, 541, 61, 637, 660, 638, 587, 273, 590, 269, 432, 118,327, 12, 373, 58, 619, 466, 469, 464, 308, 152, 305, 150, 411, 635, 634, 564, 250, 443, 129, 364, 322, 49, 7, 361, 105, 434, 120, 500, 186, 575, 261, 636, 74]
train_data: [meta]
train_data_weight: [1]
meta:
root: ../data/
crop_expand: 1.3
crop_offset_y: 0.2
model:
arch: 'SPADEID'
# warp_ckpt: /mnt/blob/projects/IMmeeting/amlt-results/Meeting_exp_25/Orig_RegMotion_Ladder256_VoxCeleb2_Warp_SPADEInit_FeatureNorm_Bs48_Baseline_15eps_256/results/ckpt/spade/ckpt_15_2022-08-27-00-10-08.pth.tar
common:
num_channels: 3
kp_detector:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator:
block_expansion: 64
max_features: 512
with_gaze_htmap: True
with_mouth_line: True
with_ldmk_line: True
use_IN: True
ladder:
need_feat: False
use_mask: False
label_nc: 0
z_dim: 512
dense_motion_params:
label_nc: 0
ldmkimg: True
occlusion: True
block_expansion: 64
max_features: 1024
num_blocks: 5
dec_lease: 2
Lwarp: True
AdaINc: 512
discriminator:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
use_kp: False
train:
epochs: 5
batch_size: 0
dataset_repeats: 1
epoch_milestones: [2]
lr_generator: 2.0e-5
lr_discriminator: 2.0e-5
lr_kp_detector: 2.0e-5
warplr_tune: 0.1
outer_beta_1: 0.5
outer_beta_2: 0.999
inner_lr_generator: 2.0e-4
inner_lr_discriminator: 2.0e-4
inner_warplr_tune: 0.1
inner_beta_1: 0.5
inner_beta_2: 0.999
scales: [1, 0.5, 0.25, 0.125]
loss_weights:
generator_gan: 0
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
id: 20
eye_enhance: 50
mouth_enhance: 50
tensorboard: True
event_save_path: ./results/events/
event_save_freq: 50
ckpt_save_path: ./results/ckpt/
ckpt_save_iter_freq: 500
ckpt_save_freq: 1
print_freq: 50
general:
exp_name: meta_portrait_base
random_seed: 365
dataset:
frame_shape: [256, 256, 3]
eye_enhance: True
mouth_enhance: True
ldmkimg: True
ldmk_idx: [521, 505, 338, 398, 347, 35, 191, 30, 32, 207, 630, 629, 319, 4, 541, 61, 637, 660, 638, 587, 273, 590, 269, 432, 118,327, 12, 373, 58, 619, 466, 469, 464, 308, 152, 305, 150, 411, 635, 634, 564, 250, 443, 129, 364, 322, 49, 7, 361, 105, 434, 120, 500, 186, 575, 261, 636, 74]
train_data: [personalized]
train_data_weight: [1]
personalized:
root: ../data/
crop_expand: 1.3
crop_offset_y: 0.2
static_bbox: True
model:
arch: 'SPADEID'
warp_ckpt: results/ckpt/meta_portrait_base/ckpt_60_2023-10-09-10-50-59.pth.tar
common:
num_channels: 3
kp_detector:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator:
block_expansion: 64
max_features: 512
with_gaze_htmap: True
with_mouth_line: True
with_ldmk_line: True
use_IN: True
ladder:
need_feat: False
use_mask: False
label_nc: 0
z_dim: 512
dense_motion_params:
label_nc: 0
ldmkimg: True
occlusion: True
block_expansion: 64
max_features: 1024
num_blocks: 5
dec_lease: 2
Lwarp: True
AdaINc: 512
discriminator:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
use_kp: False
train:
epochs: 60
batch_size: 2
dataset_repeats: 1
epoch_milestones: [45]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
warplr_tune: 0.1
scales: [1, 0.5, 0.25, 0.125]
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
id: 20
eye_enhance: 50
mouth_enhance: 50
tensorboard: True
event_save_path: ./results/events/
event_save_freq: 500
ckpt_save_path: ./results/ckpt/
ckpt_save_iter_freq: 5000
ckpt_save_freq: 10
print_freq: 1000
general:
exp_name: meta_portrait_base
random_seed: 365
dataset:
frame_shape: [256, 256, 3]
eye_enhance: True
mouth_enhance: True
ldmkimg: True
ldmk_idx: [521, 505, 338, 398, 347, 35, 191, 30, 32, 207, 630, 629, 319, 4, 541, 61, 637, 660, 638, 587, 273, 590, 269, 432, 118,327, 12, 373, 58, 619, 466, 469, 464, 308, 152, 305, 150, 411, 635, 634, 564, 250, 443, 129, 364, 322, 49, 7, 361, 105, 434, 120, 500, 186, 575, 261, 636, 74]
train_data: [personalized]
train_data_weight: [1]
personalized:
root: ../data/
crop_expand: 1.3
crop_offset_y: 0.2
static_bbox: True
model:
arch: 'SPADEID'
common:
num_channels: 3
kp_detector:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator:
block_expansion: 64
max_features: 512
with_gaze_htmap: True
with_mouth_line: True
with_ldmk_line: True
use_IN: True
ladder:
need_feat: False
use_mask: False
label_nc: 0
z_dim: 512
dense_motion_params:
label_nc: 0
ldmkimg: True
occlusion: True
block_expansion: 64
max_features: 1024
num_blocks: 5
dec_lease: 2
Lwarp: True
AdaINc: 512
discriminator:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
use_kp: False
train:
epochs: 60
batch_size: 2
dataset_repeats: 1
epoch_milestones: [45]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
warplr_tune: 0.1
scales: [1, 0.5, 0.25, 0.125]
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
id: 20
eye_enhance: 50
mouth_enhance: 50
tensorboard: True
event_save_path: ./results/events/
event_save_freq: 500
ckpt_save_path: ./results/ckpt/
ckpt_save_iter_freq: 5000
ckpt_save_freq: 10
print_freq: 1000
eval_freq: 10000
import argparse
import os
import cv2
import numpy as np
import torch
import yaml
from moviepy.editor import ImageSequenceClip
from torch.utils.data import DataLoader
from tqdm import tqdm
import utils
from dataset import PersonalDataset
from modules.discriminator import MultiScaleDiscriminator
from modules.generator import Generator
from modules.model import GeneratorFullModel
def build_model(args, conf):
utils.set_random_seed(conf['general']['random_seed'])
G = Generator(conf['model'].get('arch', None), **conf['model']['generator'], **conf['model']['common'])
utils.load_ckpt(args['ckpt'], {'generator': G}, device=args['device'], strict=True)
G.eval()
D = MultiScaleDiscriminator(**conf['model']['discriminator'], **conf['model']['common'])
G_full = GeneratorFullModel(None, G.cuda(), D.cuda(), conf['train'], conf['model'].get('arch', None), conf=conf)
return G_full
def build_data_loader(conf, name='personal'):
dataset = PersonalDataset(conf, name, is_train=False)
sampler = torch.utils.data.SequentialSampler(dataset)
test_dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=2, drop_last=False, pin_memory=True)
return test_dataloader
def model_forward(G_full, data):
for key, value in data.items():
if isinstance(value, list):
if isinstance(value[0], (str, list)):
continue
data[key] = [v.cuda() for v in value]
elif isinstance(value, str):
continue
else:
data[key] = value.cuda()
generated = G_full(data, stage="Full", inference=True)
return generated
def save_images(args, generated, data):
for j in range(len(generated['prediction'])):
final = np.transpose(generated['prediction'][j].data.cpu().numpy(), [1, 2, 0])
final = np.clip(final * 255, 0, 255).astype(np.uint8)[:, :, ::-1]
cv2.imwrite(os.path.join(args["save_dir"], "output_256", str(data['driving_name'][j])), final)
def save_to_video(args, gt_path):
img_list = []
for file in tqdm(gt_path):
img_name = file.split("/")[-1]
final_img_path = os.path.join(args["save_dir"], 'output_256', img_name)
final_img = cv2.resize(cv2.cvtColor(cv2.imread(final_img_path), cv2.COLOR_BGR2RGB), (256, 256))
img_list.append(final_img)
imgseqclip = ImageSequenceClip(img_list, 23.98)
imgseqclip.write_videofile(os.path.join(args["save_dir"], "out_256.mp4"), logger=None)
def evaluation(args, conf):
os.makedirs(os.path.join(args["save_dir"], "output_256"), exist_ok=True)
G_full = build_model(args, conf)
name = conf["dataset"]["train_data"][0]
test_dataloader = build_data_loader(conf, name)
print("Evaluation using {} images.".format(len(test_dataloader.dataset)))
for data in tqdm(test_dataloader):
with torch.inference_mode():
generated = model_forward(G_full, data)
save_images(args, generated, data)
save_to_video(args, test_dataloader.dataset.data['imgs'])
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Image Animation for Immersive Meeting Evaluation Scripts',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='../../', help='image save dir')
parser.add_argument('--ckpt', type=str, help='load checkpoint path')
parser.add_argument("--config", type=str, default="config/test.yaml", help="path to config")
args = vars(parser.parse_args())
args['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
with open(args['config']) as f:
conf = yaml.safe_load(f)
evaluation(args, conf)
import argparse
import os
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import yaml
from torch.utils.data import DataLoader
import utils
from dataset import PersonalDataset, PersonalMetaDataset
from modules.discriminator import MultiScaleDiscriminator
from modules.generator import Generator
from train_ddp import train_ddp
def get_dataset(conf, name, is_train=True):
if name == 'personalized':
return PersonalDataset(conf, name, is_train=is_train)
elif name == 'meta':
return PersonalMetaDataset(conf, name, is_train=is_train)
else:
raise Exception("Unsupported dataset type: {}".format(name))
def get_params():
parser = argparse.ArgumentParser(description='Image Animation for Immersive Meeting Training Scripts',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--ckpt', type=str, help='load checkpoint path')
parser.add_argument("--config", type=str, default="config/test.yaml", help="path to config")
parser.add_argument('--remove_sn', action='store_true')
parser.add_argument("--fp16", action='store_true', help="Whether to use fp16")
parser.add_argument("--stage", type=str, default="Full", help="Full | Warp")
parser.add_argument("--task", type=str, default="Meta", help="Meta | Pretrain | Eval")
parser.add_argument("--port", type=int, default=23456, help="Running port for DDP")
args = parser.parse_args()
return args
def main(rank, args):
args = vars(args)
args['local_rank'] = rank
dist.init_process_group(
backend='nccl',
rank=rank,
world_size=args['ngpus'],
init_method="tcp://localhost:{}".format(args['port']),
)
torch.cuda.set_device(rank)
args['device'] = rank if torch.cuda.is_available() else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
with open(args['config']) as f:
conf = yaml.safe_load(f)
if rank == 0:
print("Saving checkpoints to {}".format(conf['train']['ckpt_save_path']))
utils.set_random_seed(conf['general']['random_seed'])
conf['dataset']['ngpus'] = 1
G = Generator(conf['model'].get('arch', None), **conf['model']['generator'], **conf['model']['common'])
D = MultiScaleDiscriminator(**conf['model']['discriminator'], **conf['model']['common'])
G = G.to(args['device'])
D = D.to(args['device'])
train_data_list = []
for name in conf['dataset']['train_data']:
data = get_dataset(conf, name)
print("Dataset length: {}".format(len(data)))
train_data_list.append(data)
train_data = data
sampler = torch.utils.data.distributed.DistributedSampler(train_data, num_replicas=args['ngpus'], rank=rank,)
batch_size = int(conf['train']['batch_size'] // args['ngpus']) if conf['train']['batch_size'] > 0 else None
print("Batch Size", batch_size)
dataset_train = DataLoader(train_data, batch_size=batch_size, num_workers=1,
drop_last=False, pin_memory=True, sampler=sampler)
models = {'generator': G, 'discriminator': D}
datasets = {'dataset_train': dataset_train}
train_ddp(args, conf, models, datasets)
if __name__ == '__main__':
params = get_params()
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
params.ngpus = torch.cuda.device_count()
mp.spawn(main, nprocs=params.ngpus, args=(params,))
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
def get_nonspade_norm_layer(norm_type="instance"):
# helper function to get # output channels of the previous layer
def get_out_channel(layer):
if hasattr(layer, "out_channels"):
return getattr(layer, "out_channels")
return layer.weight.size(0)
# this function will be returned
def add_norm_layer(layer):
nonlocal norm_type
subnorm_type = norm_type
if norm_type.startswith("spectral"):
layer = spectral_norm(layer)
subnorm_type = norm_type[len("spectral") :]
if subnorm_type == "none" or len(subnorm_type) == 0:
return layer
# remove bias in the previous layer, which is meaningless
# since it has no effect after normalization
if getattr(layer, "bias", None) is not None:
delattr(layer, "bias")
layer.register_parameter("bias", None)
if subnorm_type == "batch":
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
elif subnorm_type == "sync_batch":
norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
elif subnorm_type == "instance":
norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
else:
raise ValueError("normalization layer %s is not recognized" % subnorm_type)
return nn.Sequential(layer, norm_layer)
return add_norm_layer
class LadderEncoder(nn.Module):
""" Same architecture as the image discriminator """
def __init__(self, need_feat=False, use_mask=False, label_nc=0, z_dim=512, norm_type="spectralinstance"):
super().__init__()
self.need_feat = need_feat
ldmk_img_nc = 3
nif = 3 + label_nc + 2 * ldmk_img_nc
kw = 3
pw = int(np.ceil((kw - 1.0) / 2))
nef = 64
norm_layer = get_nonspade_norm_layer(norm_type)
self.layer1 = norm_layer(nn.Conv2d(nif, nef, kw, stride=2, padding=pw))
self.layer2 = norm_layer(nn.Conv2d(nef * 1, nef * 2, kw, stride=2, padding=pw))
self.layer3 = norm_layer(nn.Conv2d(nef * 2, nef * 4, kw, stride=2, padding=pw))
self.layer4 = norm_layer(nn.Conv2d(nef * 4, nef * 8, kw, stride=2, padding=pw))
self.layer5 = norm_layer(nn.Conv2d(nef * 8, nef * 8, kw, stride=2, padding=pw))
self.layer6 = norm_layer(nn.Conv2d(nef * 8, nef * 8, kw, stride=2, padding=pw))
if need_feat:
self.up_layer2 = norm_layer(
nn.Conv2d(nef * 2, nef * 2, kw, stride=1, padding=pw)
)
self.up_layer3 = nn.Sequential(
norm_layer(nn.Conv2d(nef * 4, nef * 2, kw, stride=1, padding=pw)),
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
)
self.up_layer4 = nn.Sequential(
norm_layer(nn.Conv2d(nef * 8, nef * 2, kw, stride=1, padding=pw)),
nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
)
self.up_layer5 = nn.Sequential(
norm_layer(nn.Conv2d(nef * 8, nef * 2, kw, stride=1, padding=pw)),
nn.Upsample(scale_factor=8, mode="bilinear", align_corners=True),
)
self.up_layer6 = nn.Sequential(
norm_layer(nn.Conv2d(nef * 8, nef * 2, kw, stride=1, padding=pw)),
nn.Upsample(scale_factor=16, mode="bilinear", align_corners=True),
)
self.actvn = nn.LeakyReLU(0.2, False)
self.so = s0 = 4
self.fc = nn.Linear(nef * 8 * s0 * s0, z_dim)
def forward(self, x):
features = None
if x.size(2) != 256 or x.size(3) != 256:
x = F.interpolate(x, size=(256, 256), mode="bilinear")
x = self.layer1(x)
x = self.layer2(self.actvn(x))
if self.need_feat:
features = self.up_layer2(x)
x = self.layer3(self.actvn(x))
if self.need_feat:
features = self.up_layer3(x) + features
x = self.layer4(self.actvn(x))
if self.need_feat:
features = self.up_layer4(x) + features
x = self.layer5(self.actvn(x))
if self.need_feat:
features = self.up_layer5(x) + features
x = self.layer6(self.actvn(x))
if self.need_feat:
features = self.up_layer6(x) + features
x = self.actvn(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
x = x / (x.norm(dim=-1, p=2, keepdim=True) + 1e-5)
return x, features
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# Created by: algohunt
# Microsoft Research & Peking University
# lilingzhi@pku.edu.cn
# Copyright (c) 2019
import torch
from torch.nn import init
import torch.nn.functional as F
from torch import nn
from math import sqrt
def init_linear(linear):
init.xavier_normal(linear.weight)
linear.bias.data.zero_()
def init_conv(conv, glu=True):
init.kaiming_normal(conv.weight)
if conv.bias is not None:
conv.bias.data.zero_()
class EqualLR:
def __init__(self, name):
self.name = name
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
fan_in = weight.data.size(1) * weight.data[0][0].numel()
return weight * sqrt(2 / fan_in)
@staticmethod
def apply(module, name):
fn = EqualLR(name)
weight = getattr(module, name)
del module._parameters[name]
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
module.register_forward_pre_hook(fn)
return fn
def __call__(self, module, input):
weight = self.compute_weight(module)
setattr(module, self.name, weight)
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
class Blur(nn.Module):
def __init__(self):
super().__init__()
weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
weight = weight.view(1, 1, 3, 3)
weight = weight / weight.sum()
self.register_buffer('weight', weight)
def forward(self, input):
return F.conv2d(
input,
self.weight.repeat(input.shape[1], 1, 1, 1),
padding=1,
groups=input.shape[1],
)
class EqualConv2d(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
conv = nn.Conv2d(*args, **kwargs)
conv.weight.data.normal_()
conv.bias.data.zero_()
self.conv = equal_lr(conv)
def forward(self, input):
return self.conv(input)
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
linear = nn.Linear(in_dim, out_dim)
linear.weight.data.normal_()
linear.bias.data.zero_()
self.linear = equal_lr(linear)
def forward(self, input):
return self.linear(input)
class AdaptiveInstanceNorm(nn.Module):
def __init__(self, in_channel, style_dim):
super().__init__()
self.norm = nn.InstanceNorm2d(in_channel)
self.style = EqualLinear(style_dim, in_channel * 2)
self.style.linear.bias.data[:in_channel] = 1
self.style.linear.bias.data[in_channel:] = 0
def forward(self, input, style):
style = self.style(style).unsqueeze(2).unsqueeze(3)
gamma, beta = style.chunk(2, 1)
out = self.norm(input)
out = gamma * out + beta
return out
class NoiseInjection(nn.Module):
def __init__(self, channel):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channel, 1, 1))
def forward(self, image, noise):
return image + self.weight * noise
class ConstantInput(nn.Module):
def __init__(self, channel, size=4):
super().__init__()
self.input = nn.Parameter(torch.randn(1, channel, size, size))
def forward(self, input):
batch = input.shape[0]
out = self.input.repeat(batch, 1, 1, 1)
return out
from torch import nn
import torch
import functools
from modules.util import (
Hourglass,
make_coordinate_grid,
LayerNorm2d,
)
class DenseMotionNetworkReg(nn.Module):
def __init__(
self,
block_expansion,
num_blocks,
max_features,
Lwarp=False,
AdaINc=0,
dec_lease=0,
label_nc=0,
ldmkimg=False,
occlusion=False,
):
super(DenseMotionNetworkReg, self).__init__()
in_c = 3 + label_nc + 2 * 3 if ldmkimg else 3 + label_nc
self.hourglass = Hourglass(
block_expansion=block_expansion,
in_features=in_c,
max_features=max_features,
num_blocks=num_blocks,
Lwarp=Lwarp,
AdaINc=AdaINc,
dec_lease=dec_lease,
)
self.occlusion = occlusion
if dec_lease > 0:
norm_layer = functools.partial(LayerNorm2d, affine=True)
self.reger = nn.Sequential(
norm_layer(self.hourglass.out_filters),
nn.LeakyReLU(0.1),
nn.Conv2d(
self.hourglass.out_filters, 2, kernel_size=7, stride=1, padding=3
),
)
if occlusion:
self.occlusion_net = nn.Sequential(
norm_layer(self.hourglass.out_filters),
nn.LeakyReLU(0.1),
nn.Conv2d(
self.hourglass.out_filters,
1,
kernel_size=7,
stride=1,
padding=3,
),
)
else:
self.reger = nn.Conv2d(
self.hourglass.out_filters, 2, kernel_size=(7, 7), padding=(3, 3)
)
def forward(self, source_image, drv_deca):
prediction = self.hourglass(source_image, drv_exp=drv_deca)
out_dict = {}
flow = self.reger(prediction)
bs, _, h, w = flow.shape
flow_norm = 2 * torch.cat(
[flow[:, :1, ...] / (w - 1), flow[:, 1:, ...] / (h - 1)], 1
)
out_dict["flow"] = flow_norm
grid = make_coordinate_grid((h, w), type=torch.FloatTensor).to(flow_norm.device)
deformation = grid + flow_norm.permute(0, 2, 3, 1)
out_dict["deformation"] = deformation
if self.occlusion:
occlusion_map = torch.sigmoid(self.occlusion_net(prediction))
_, _, h_old, w_old = occlusion_map.shape
_, _, h, w = source_image.shape
if h_old != h or w_old != w:
occlusion_map = torch.nn.functional.interpolate(
occlusion_map, size=(h, w), mode="bilinear", align_corners=False
)
out_dict["occlusion_map"] = occlusion_map
return out_dict
from torch import nn
import torch.nn.functional as F
from modules.util import kp2gaussian
import torch
from torch.nn.utils import spectral_norm
class DownBlock2d(nn.Module):
"""
Simple block for processing video (encoder).
"""
def __init__(
self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False
):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(
in_channels=in_features, out_channels=out_features, kernel_size=kernel_size
)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
if norm:
self.norm = nn.InstanceNorm2d(out_features, affine=True)
else:
self.norm = None
self.pool = pool
def forward(self, x):
out = x
out = self.conv(out)
if self.norm:
out = self.norm(out)
out = F.leaky_relu(out, 0.2)
if self.pool:
out = F.avg_pool2d(out, (2, 2))
return out
class Discriminator(nn.Module):
"""
Discriminator similar to Pix2Pix
"""
def __init__(
self,
num_channels=3,
block_expansion=64,
num_blocks=4,
max_features=512,
sn=False,
use_kp=False,
num_kp=10,
kp_variance=0.01,
AdaINc=0,
**kwargs
):
super(Discriminator, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(
DownBlock2d(
num_channels + num_kp * use_kp
if i == 0
else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
norm=(i != 0),
kernel_size=4,
pool=(i != num_blocks - 1),
sn=sn,
)
)
self.down_blocks = nn.ModuleList(down_blocks)
self.conv = nn.Conv2d(
self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1
)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
self.use_kp = use_kp
self.kp_variance = kp_variance
self.AdaINc = AdaINc
if AdaINc > 0:
self.to_exp = nn.Sequential(
nn.Linear(block_expansion * (2 ** num_blocks), 256),
nn.LeakyReLU(256),
nn.Linear(256, AdaINc),
)
def forward(self, x, kp=None):
feature_maps = []
out = x
if self.use_kp:
heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
out = torch.cat([out, heatmap], dim=1)
for down_block in self.down_blocks:
feature_maps.append(down_block(out))
out = feature_maps[-1]
prediction_map = self.conv(out)
if self.AdaINc > 0:
feat = F.adaptive_avg_pool2d(out, 1)
exp = self.to_exp(feat.squeeze(-1).squeeze(-1))
else:
exp = None
return feature_maps, prediction_map, exp
class MultiScaleDiscriminator(nn.Module):
"""
Multi-scale (scale) discriminator
"""
def __init__(self, scales=(), **kwargs):
super(MultiScaleDiscriminator, self).__init__()
self.scales = scales
discs = {}
self.use_kp = kwargs["use_kp"]
for scale in scales:
discs[str(scale).replace(".", "-")] = Discriminator(**kwargs)
self.discs = nn.ModuleDict(discs)
self.apply(self._init_weights)
def _init_weights(self, m):
gain = 0.02
if isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
if m.weight is not None:
nn.init.xavier_normal_(m.weight, gain=gain)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
if hasattr(m, "weight") and m.weight is not None:
nn.init.xavier_normal_(m.weight, gain=gain)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, kp=None):
out_dict = {}
for scale, disc in self.discs.items():
scale = str(scale).replace("-", ".")
key = "prediction_" + scale
feature_maps, prediction_map, exp = disc(x[key], kp)
out_dict["feature_maps_" + scale] = feature_maps
out_dict["prediction_map_" + scale] = prediction_map
out_dict["exp_" + scale] = exp
return out_dict
import torch
import torch.nn.functional as F
from torch import nn
from modules.dense_motion import DenseMotionNetworkReg
from modules.LadderEncoder import LadderEncoder
from modules.spade import SPADEGenerator
from modules.util import Hourglass, kp2gaussian
def Generator(arch, **kwarg):
return OcclusionAwareSPADEGenerator(**kwarg, hasid=True)
class OcclusionAwareSPADEGenerator(nn.Module):
"""
Generator that given source image, source ldmk image and driving ldmk image try to transform image according to movement trajectories
according to the ldmks.
"""
def __init__(
self,
num_channels,
block_expansion,
max_features,
dense_motion_params=None,
with_warp_im=False,
hasid=False,
with_gaze_htmap=False,
with_ldmk_line=False,
with_mouth_line=False,
with_ht=False,
ladder=None,
use_IN=False,
use_SN=True
):
super(OcclusionAwareSPADEGenerator, self).__init__()
self.with_warp_im = with_warp_im
self.with_gaze_htmap = with_gaze_htmap
self.with_ldmk_line = with_ldmk_line
self.with_mouth_line = with_mouth_line
self.with_ht = with_ht
self.ladder = ladder
self.use_IN = use_IN
self.use_SN = use_SN
ladder_norm_type = "spectralinstance" if use_SN else "instance"
self.ladder_network = LadderEncoder(**ladder, norm_type=ladder_norm_type)
self.dense_motion_network = DenseMotionNetworkReg(
**dense_motion_params
)
num_blocks = 3
self.feature_encoder = Hourglass(
block_expansion=block_expansion,
in_features=3,
max_features=max_features,
num_blocks=num_blocks,
Lwarp=False,
AdaINc=0,
dec_lease=0,
use_IN=use_IN
)
self.fuse_high_res = nn.Conv2d(block_expansion + 3, block_expansion, kernel_size=(3, 3), padding=(1, 1))
norm = "spectral" if self.use_SN else ""
norm += "spadeinstance3x3" if self.use_IN else "spadebatch3x3"
if hasid:
norm += "id"
class_dim = 256
label_nc_offset = 0 # if with_warp_im else 256
label_nc_offset = label_nc_offset + 8 if with_gaze_htmap else label_nc_offset
label_nc_offset = label_nc_offset + 6 if with_ldmk_line else label_nc_offset
label_nc_offset = label_nc_offset + 3 if with_mouth_line else label_nc_offset
label_nc_offset = label_nc_offset + 59 if with_ht else label_nc_offset
label_nc_offset = label_nc_offset + 1 # For occlusion map
label_nc_list = [512, 512, 512, 512, 256, 128, 64]
label_nc_list = [ln + label_nc_offset for ln in label_nc_list]
self.SPDAE_G = SPADEGenerator(
conv_dim=32,
label_nc=label_nc_list,
norm_G=norm,
class_dim=class_dim,
)
self.num_channels = num_channels
self.apply(self._init_weights)
def _init_weights(self, m):
gain = 0.02
if isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
if m.weight is not None:
nn.init.xavier_normal_(m.weight, gain=gain)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
if hasattr(m, "weight") and m.weight is not None:
nn.init.xavier_normal_(m.weight, gain=gain)
if hasattr(m, "bias") and m.bias is not None:
nn.init.constant_(m.bias, 0)
def deform_input(self, inp, deformation):
_, h_old, w_old, _ = deformation.shape
_, _, h, w = inp.shape
if h_old != h or w_old != w:
deformation = deformation.permute(0, 3, 1, 2)
deformation = F.interpolate(deformation, size=(h, w), mode="bilinear")
deformation = deformation.permute(0, 2, 3, 1)
return F.grid_sample(
inp.to(deformation.dtype), deformation, padding_mode="reflection"
)
def get_gaze_ht(self, source_image, kp_driving):
spatial_size = source_image.shape[2:]
gaussian_driving = kp2gaussian(
kp_driving, spatial_size=spatial_size, kp_variance=0.005
)
return gaussian_driving[:, 29:37]
def forward_warp(
self,
source_image,
ldmk_line=None
):
output_dict = {}
input_t = (
source_image
if ldmk_line is None
else torch.cat((source_image, ldmk_line), dim=1)
)
style_feat, _ = self.ladder_network(input_t)
drv_exp = style_feat
dense_motion = self.dense_motion_network(input_t, drv_exp)
output_dict["deformation"] = dense_motion["deformation"]
output_dict["deformed"] = self.deform_input(
source_image, dense_motion["deformation"]
)
output_dict["occlusion_map"] = dense_motion["occlusion_map"]
output_dict["prediction"] = output_dict["deformed"]
output_dict["flow"] = dense_motion["flow"]
return output_dict
def foward_refine(
self,
source_image,
src_id,
ldmk_line,
mouth_line,
warp_out,
kp_driving=None,
):
_, out_list = self.feature_encoder(source_image, return_all=True)
out_list[-1] = self.fuse_high_res(out_list[-1])
out_list = [self.deform_input(out, warp_out["deformation"]) for out in out_list]
feature_list = []
for out in out_list:
if self.with_gaze_htmap:
gaze_htmap = self.get_gaze_ht(out, kp_driving)
inputs = out
if out.shape[2] != warp_out["occlusion_map"].shape[2] or out.shape[3] != warp_out["occlusion_map"].shape[3]:
occlusion_map = F.interpolate(
warp_out["occlusion_map"], size=out.shape[2:], mode="bilinear"
)
else:
occlusion_map = warp_out["occlusion_map"]
inputs = torch.cat((inputs, occlusion_map), dim=1)
if self.with_gaze_htmap:
inputs = torch.cat((inputs, gaze_htmap), dim=1)
if self.with_ldmk_line:
ldmk_line = F.interpolate(ldmk_line, size=inputs.shape[2:], mode="bilinear")
inputs = torch.cat((inputs, ldmk_line), dim=1)
if self.with_mouth_line:
mouth_line = F.interpolate(
mouth_line, size=inputs.shape[2:], mode="bilinear"
)
inputs = torch.cat((inputs, mouth_line), dim=1)
feature_list.append(inputs)
outs = self.SPDAE_G(feature_list, class_emb=src_id)
warp_out["prediction"] = outs
return warp_out
def forward(
self,
source_image,
kp_driving=None,
src_id=None,
stage=None,
ldmk_line=None,
mouth_line=None,
warp_out=None,
):
if stage == "Warp":
return self.forward_warp(source_image, ldmk_line)
elif stage == "Refine":
return self.foward_refine(
source_image,
src_id,
ldmk_line,
mouth_line,
warp_out,
kp_driving,
)
else:
raise Exception("Unknown stage.")
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from modules.util import AntiAliasInterpolation2d
class Vgg19(torch.nn.Module):
"""
Vgg19 network for perceptual loss. See Sec 3.3.
"""
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
requires_grad=False)
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
requires_grad=False)
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
X = (X - self.mean) / self.std
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
class ImagePyramide(torch.nn.Module):
"""
Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
"""
def __init__(self, scales, num_channels):
super(ImagePyramide, self).__init__()
downs = {}
for scale in scales:
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
self.downs = nn.ModuleDict(downs)
def forward(self, x):
out_dict = {}
for scale, down_module in self.downs.items():
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
return out_dict
def detach_kp(kp):
return {key: value.detach() for key, value in kp.items()}
class GeneratorFullModel(torch.nn.Module):
"""
Merge all generator related updates into single model for better multi-gpu usage
"""
def __init__(self, kp_extractor, generator, discriminator, train_params, arch=None, rank=0, conf=None):
super(GeneratorFullModel, self).__init__()
self.arch = arch
self.conf = conf
self.kp_extractor = kp_extractor
self.generator = generator
self.discriminator = discriminator
self.train_params = train_params
self.scales = train_params['scales']
if conf['model']['discriminator'].get('type', 'MultiPatchGan') == 'MultiPatchGan':
self.disc_scales = self.discriminator.scales
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
if torch.cuda.is_available():
self.pyramid = self.pyramid.to(rank)
self.loss_weights = train_params['loss_weights']
if sum(self.loss_weights['perceptual']) != 0:
self.vgg = Vgg19()
if torch.cuda.is_available():
self.vgg = self.vgg.to(rank)
self.vgg.eval()
if self.loss_weights.get('warp_ce', 0) > 0:
self.ce_loss = nn.CrossEntropyLoss().to(rank)
if self.loss_weights.get('l1', 0) > 0:
self.l1_loss = nn.L1Loss()
def nist_prec(self, x):
x = (x.clone() - 0.5) * 2 # -1 ~ 1
x = x[:, :, 25:256, 25:256]
x = torch.flip(x,[1]) # RGB -> BGR
return x
def forward_warp(self, x, cal_loss=True):
if self.conf['dataset'].get('ldmkimg', False):
ldmk_line = torch.cat((x['source_line'], x['driving_line']), dim=1)
else:
ldmk_line = None
generated = self.generator(x['source'], ldmk_line=ldmk_line, stage='Warp')
loss_values = {}
if cal_loss:
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['deformed'])
value_total = 0
for scale in self.scales:
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
for i, weight in enumerate(self.loss_weights['perceptual']):
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
value_total += self.loss_weights['perceptual'][i] * value
loss_values['warp_perceptual'] = value_total
return loss_values, generated
def forward_refine(self, x, warp_out, loss_values, inference=False):
kp_driving = {'value': x['driving_ldmk']}
embed_id = x['source_id']
ldmk_line = torch.cat((x['source_line'], x['driving_line']), dim=1) if self.conf['dataset'].get('ldmkimg', False) else None
if self.loss_weights.get('mouth_enhance', 0) > 0:
mouth_line = x['driving_line'] * x['mouth_mask']
else:
mouth_line = None
generated = self.generator(x['source'], kp_driving=kp_driving, src_id=embed_id, ldmk_line=ldmk_line, mouth_line=mouth_line, warp_out=warp_out, stage='Refine')
if inference:
return generated
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'])
if sum(self.loss_weights['perceptual']) != 0:
value_total = 0
eye_total = 0
mouth_total = 0
for scale in self.scales:
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
for i, weight in enumerate(self.loss_weights['perceptual']):
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
value_total += self.loss_weights['perceptual'][i] * value
if self.loss_weights.get('eye_enhance', 0) > 0:
eye_scale = F.interpolate(x['eye_mask'], size=pyramide_generated['prediction_' + str(scale)].shape[2:], mode='nearest')
eye_total += ((pyramide_generated['prediction_' + str(scale)] - pyramide_real['prediction_' + str(scale)]) ** 2 * eye_scale).sum() / (eye_scale.sum() + 1e-6)
if self.loss_weights.get('mouth_enhance', 0) > 0:
mouth_scale = F.interpolate(x['mouth_mask'], size=pyramide_generated['prediction_' + str(scale)].shape[2:], mode='nearest')
mouth_total += ((pyramide_generated['prediction_' + str(scale)] - pyramide_real['prediction_' + str(scale)]) ** 2 * mouth_scale).sum() / (mouth_scale.sum() + 1e-6)
loss_values['perceptual'] = value_total
if self.loss_weights.get('eye_enhance', 0) > 0:
loss_values['eye'] = eye_total * self.loss_weights['eye_enhance']
if self.loss_weights.get('mouth_enhance', 0) > 0:
loss_values['mouth'] = mouth_total * self.loss_weights['mouth_enhance']
if self.loss_weights.get('l1', 0) > 0:
loss_values['l1'] = self.l1_loss(generated['prediction'], x['driving']) * self.loss_weights['l1']
# if self.loss_weights.get('id', 0) > 0:
# gen_grid = F.affine_grid(x['driving_theta'], [x['driving_theta'].shape[0], 3, 256,256], align_corners=True)
# gen_nist = F.grid_sample(F.interpolate(generated['prediction'], (256, 256), mode='bilinear'), gen_grid, align_corners=True)
# gen_id = self.id_classifier(self.nist_prec(gen_nist))
# gen_id = F.normalize(gen_id, dim=1)
# tgt_id = F.normalize(embed_id, dim=1)
# loss_values['id'] = (1 - (gen_id * tgt_id).sum(1).mean()) * self.loss_weights['id']
if self.loss_weights['generator_gan'] != 0:
if self.conf['model']['discriminator'].get('type', 'MultiPatchGan') == 'MultiPatchGan':
if self.conf['model']['discriminator'].get('use_kp', False):
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
else:
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=x['driving_line'])
discriminator_maps_real = self.discriminator(pyramide_real, kp=x['driving_line'])
value_total = 0
for scale in self.disc_scales:
key = 'prediction_map_%s' % scale
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
value_total += self.loss_weights['generator_gan'] * value
loss_values['gen_gan'] = value_total
if sum(self.loss_weights['feature_matching']) != 0:
value_total = 0
for scale in self.disc_scales:
key = 'feature_maps_%s' % scale
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
if self.loss_weights['feature_matching'][i] == 0:
continue
value = torch.abs(a - b).mean()
value_total += self.loss_weights['feature_matching'][i] * value
loss_values['feature_matching'] = value_total
else:
discriminator_maps_generated = self.discriminator(pyramide_generated['prediction_1'])
value = ((1 - discriminator_maps_generated) ** 2).mean()
loss_values['gen_gan'] = self.loss_weights['generator_gan'] * value
return loss_values, generated
def forward(self, x, stage=None, inference=False):
if stage == 'Warp':
return self.forward_warp(x, cal_loss=not inference)
elif stage == 'Full':
warp_loss, warp_out = self.forward_warp(x)
return self.forward_refine(x, warp_out, warp_loss, inference=inference)
else:
raise Exception("Unknown stage.")
def get_gaze_loss(self, deformation, gaze):
mask = (gaze != 0).detach().float()
up_deform = F.interpolate(deformation.permute(0,3,1,2), size=gaze.shape[1:3], mode='bilinear').permute(0,2,3,1)
gaze_loss = (torch.abs(up_deform - gaze) * mask).sum() / (mask.sum() + 1e-6)
return gaze_loss
def get_ldmk_loss(self, mask, ldmk_gt):
pred = F.interpolate(mask, size=ldmk_gt.shape[1:], mode='bilinear')
ldmk_loss = F.cross_entropy(pred, ldmk_gt, ignore_index=0)
return ldmk_loss
class DiscriminatorFullModel(torch.nn.Module):
"""
Merge all discriminator related updates into single model for better multi-gpu usage
"""
def __init__(self, kp_extractor, generator, discriminator, train_params):
super(DiscriminatorFullModel, self).__init__()
self.kp_extractor = kp_extractor
self.generator = generator
self.discriminator = discriminator
self.train_params = train_params
self.scales = self.discriminator.scales
self.use_kp = discriminator.use_kp
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
if torch.cuda.is_available():
self.pyramid = self.pyramid.cuda()
self.loss_weights = train_params['loss_weights']
def forward(self, x, generated):
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'].detach())
if self.use_kp:
kp_driving = generated['kp_driving']
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
else:
kp_driving = x['driving_line']
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=kp_driving)
discriminator_maps_real = self.discriminator(pyramide_real, kp=kp_driving)
loss_values = {}
value_total = 0
for scale in self.scales:
key = 'prediction_map_%s' % scale
value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
value_total += self.loss_weights['discriminator_gan'] * value.mean()
loss_values['disc_gan'] = value_total
if self.loss_weights.get('D_exp', 0) > 0:
loss_values['exp'] = F.mse_loss(discriminator_maps_real['exp_1'], x['driving_exp']) * self.loss_weights['D_exp'] + \
F.mse_loss(discriminator_maps_generated['exp_1'], x['driving_exp']) * self.loss_weights['D_exp']
return loss_values
import re
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
from modules.adain import AdaptiveInstanceNorm
class SPADE(nn.Module):
def __init__(self, config_text, norm_nc, label_nc, style_nc):
super().__init__()
assert config_text.startswith("spade")
parsed = re.search("spade(\D+)(\d)x\d(\D*)", config_text)
param_free_norm_type = str(parsed.group(1))
ks = int(parsed.group(2))
self.hasid = parsed.group(3) == "id"
if param_free_norm_type == "instance":
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
elif "batch" in param_free_norm_type:
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
else:
raise ValueError(
"%s is not a recognized param-free norm type in SPADE"
% param_free_norm_type
)
# The dimension of the intermediate embedding space. Yes, hardcoded.
nhidden = 128
pw = ks // 2
self.label_nc = label_nc
self.mlp_shared = nn.Sequential(
nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
)
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
if self.hasid:
self.mlp_attention = nn.Sequential(
nn.Conv2d(norm_nc, 1, kernel_size=ks, padding=pw), nn.Sigmoid(),
)
self.adain = AdaptiveInstanceNorm(norm_nc, style_nc)
def forward(self, x, attr_map, id_emb):
# Part 1. generate parameter-free normalized activations
normalized = self.param_free_norm(x)
# Part 2. produce scaling and bias conditioned on semantic map
# segmap = F.interpolate(segmap, size=x.size()[2:], mode='bilinear')
actv = self.mlp_shared(attr_map)
gamma = self.mlp_gamma(actv)
beta = self.mlp_beta(actv)
# apply scale and bias
spade_out = normalized * (1 + gamma) + beta
if self.hasid:
attention = self.mlp_attention(x)
adain_out = self.adain(x, id_emb)
out = attention * spade_out + (1 - attention) * adain_out
else:
out = spade_out
return out
class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, semantic_nc, style_nc, norm_G):
super().__init__()
# Attributes
self.learned_shortcut = fin != fout
fmiddle = min(fin, fout)
# create conv layers
self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
if self.learned_shortcut:
self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
# apply spectral norm if specified
if "spectral" in norm_G:
self.conv_0 = spectral_norm(self.conv_0)
self.conv_1 = spectral_norm(self.conv_1)
if self.learned_shortcut:
self.conv_s = spectral_norm(self.conv_s)
# define normalization layers
spade_config_str = norm_G.replace("spectral", "")
self.norm_0 = SPADE(spade_config_str, fin, semantic_nc, style_nc)
self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc, style_nc)
if self.learned_shortcut:
self.norm_s = SPADE(spade_config_str, fin, semantic_nc, style_nc)
# note the resnet block with SPADE also takes in |seg|,
# the semantic segmentation map as input
def forward(self, x, seg, class_emb):
x_s = self.shortcut(x, seg, class_emb)
dx = self.conv_0(self.actvn(self.norm_0(x, seg, class_emb)))
dx = self.conv_1(self.actvn(self.norm_1(dx, seg, class_emb)))
out = x_s + dx
return out
def shortcut(self, x, seg, class_emb):
if self.learned_shortcut:
x_s = self.conv_s(self.norm_s(x, seg, class_emb))
else:
x_s = x
return x_s
def actvn(self, x):
return F.leaky_relu(x, 2e-1, inplace=True)
class SPADEGenerator(nn.Module):
def __init__(
self,
label_nc=256,
class_dim=256,
conv_dim=64,
norm_G="spectralspadebatch3x3",
):
super().__init__()
nf = conv_dim
self.nf = conv_dim
self.norm_G = norm_G
self.conv1 = spectral_norm(nn.ConvTranspose2d(class_dim, nf * 16, 4)) if "spectral" in norm_G else nn.ConvTranspose2d(class_dim, nf * 16, 4)
self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, label_nc[0], class_dim, norm_G)
self.G_middle_0 = SPADEResnetBlock(
16 * nf, 16 * nf, label_nc[1], class_dim, norm_G
)
self.G_middle_1 = SPADEResnetBlock(
16 * nf, 16 * nf, label_nc[2], class_dim, norm_G
)
self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, label_nc[3], class_dim, norm_G)
self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, label_nc[4], class_dim, norm_G)
self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, label_nc[5], class_dim, norm_G)
final_nc = nf
self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, label_nc[6], class_dim, norm_G)
self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
def forward(self, attr_pyramid, class_emb=None):
if class_emb is None:
x = torch.randn(
(attr_pyramid[0].size(0), 256, 1, 1), device=attr_pyramid[0].device
)
else:
x = class_emb.view(class_emb.size(0), class_emb.size(1), 1, 1)
x = self.conv1(x)
style4 = F.interpolate(attr_pyramid[0], size=x.shape[2:], mode="bilinear")
x = self.head_0(x, style4, class_emb)
x = F.interpolate(x, scale_factor=2, mode="bilinear")
style8 = F.interpolate(attr_pyramid[0], size=x.shape[2:], mode="bilinear")
x = self.G_middle_0(x, style8, class_emb)
x = F.interpolate(x, scale_factor=2, mode="bilinear")
style16 = F.interpolate(attr_pyramid[0], size=x.shape[2:], mode="bilinear")
x = self.G_middle_1(x, style16, class_emb)
x = F.interpolate(x, scale_factor=2, mode="bilinear")
style32 = F.interpolate(attr_pyramid[0], size=x.shape[2:], mode="bilinear")
x = self.up_0(x, style32, class_emb)
x = F.interpolate(x, scale_factor=2, mode="bilinear")
style64 = F.interpolate(attr_pyramid[1], size=x.shape[2:], mode="bilinear")
x = self.up_1(x, style64, class_emb)
x = F.interpolate(x, scale_factor=2, mode="bilinear")
style128 = F.interpolate(attr_pyramid[2], size=x.shape[2:], mode="bilinear")
x = self.up_2(x, style128, class_emb)
x = F.interpolate(x, scale_factor=2, mode="bilinear")
style256 = F.interpolate(attr_pyramid[3], size=x.shape[2:], mode="bilinear")
x = self.up_3(x, style256, class_emb)
x = F.leaky_relu(x, 2e-1, inplace=True)
x = self.conv_img(x)
x = torch.tanh(x)
return x