".github/vscode:/vscode.git/clone" did not exist on "6e4539405167da85107fd46fce1db6d874067fc4"
Commit 57463d8d authored by suily's avatar suily
Browse files

init

parents
Pipeline #1918 canceled with stages
import logging
import os
import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter
class PartialFC(Module):
"""
Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
Partial FC: Training 10 Million Identities on a Single Machine
See the original paper:
https://arxiv.org/abs/2010.05222
"""
@torch.no_grad()
def __init__(self, rank, local_rank, world_size, batch_size, resume,
margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
"""
rank: int
Unique process (GPU) ID, from 0 to world_size - 1.
local_rank: int
Unique process (GPU) ID within the server, from 0 to 7.
world_size: int
Number of GPUs.
batch_size: int
Batch size on the current rank (GPU).
resume: bool
Whether to restore the softmax weights from a checkpoint.
margin_softmax: callable
A margin softmax function, e.g. cosface, arcface.
num_classes: int
Total number of classes; each rank stores roughly num_classes // world_size class centers. Required.
sample_rate: float
The partial fc sampling rate. When the number of classes grows beyond about 2 million, sampling can greatly speed up training and reduce GPU memory usage. Default is 1.0.
embedding_size: int
The feature dimension. Default is 512.
prefix: str
Path for saving checkpoints. Default is './'.
"""
super(PartialFC, self).__init__()
#
self.num_classes: int = num_classes
self.rank: int = rank
self.local_rank: int = local_rank
self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
self.world_size: int = world_size
self.batch_size: int = batch_size
self.margin_softmax: callable = margin_softmax
self.sample_rate: float = sample_rate
self.embedding_size: int = embedding_size
self.prefix: str = prefix
self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
self.num_sample: int = int(self.sample_rate * self.num_local)
self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
if resume:
try:
self.weight: torch.Tensor = torch.load(self.weight_name)
self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
raise IndexError
logging.info("softmax weight resume successfully!")
logging.info("softmax weight mom resume successfully!")
except (FileNotFoundError, KeyError, IndexError):
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
logging.info("softmax weight init!")
logging.info("softmax weight mom init!")
else:
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
logging.info("softmax weight init successfully!")
logging.info("softmax weight mom init successfully!")
self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
self.index = None
if int(self.sample_rate) == 1:
self.update = lambda: 0
self.sub_weight = Parameter(self.weight)
self.sub_weight_mom = self.weight_mom
else:
self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
def save_params(self):
""" Save softmax weight for each rank on prefix
"""
torch.save(self.weight.data, self.weight_name)
torch.save(self.weight_mom, self.weight_mom_name)
@torch.no_grad()
def sample(self, total_label):
"""
Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
`num_sample`.
total_label: tensor
Label after all gather, which cross all GPUs.
"""
index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
total_label[~index_positive] = -1
total_label[index_positive] -= self.class_start
if int(self.sample_rate) != 1:
positive = torch.unique(total_label[index_positive], sorted=True)
if self.num_sample - positive.size(0) >= 0:
perm = torch.rand(size=[self.num_local], device=self.device)
perm[positive] = 2.0
index = torch.topk(perm, k=self.num_sample)[1]
index = index.sort()[0]
else:
index = positive
self.index = index
total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
self.sub_weight = Parameter(self.weight[index])
self.sub_weight_mom = self.weight_mom[index]
def forward(self, total_features, norm_weight):
""" Partial fc forward, `logits = X * sample(W)`
"""
torch.cuda.current_stream().wait_stream(self.stream)
logits = linear(total_features, norm_weight)
return logits
@torch.no_grad()
def update(self):
""" Set updated weight and weight_mom to memory bank.
"""
self.weight_mom[self.index] = self.sub_weight_mom
self.weight[self.index] = self.sub_weight
def prepare(self, label, optimizer):
"""
get sampled class centers for cal softmax.
label: tensor
Label tensor on each rank.
optimizer: opt
Optimizer for partial fc, which need to get weight mom.
"""
with torch.cuda.stream(self.stream):
total_label = torch.zeros(
size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
self.sample(total_label)
optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
optimizer.param_groups[-1]['params'][0] = self.sub_weight
optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
norm_weight = normalize(self.sub_weight)
return total_label, norm_weight
def forward_backward(self, label, features, optimizer):
"""
Partial fc forward and backward with model parallel
label: tensor
Label tensor on each rank(GPU)
features: tensor
Features tensor on each rank(GPU)
optimizer: optimizer
Optimizer for partial fc
Returns:
--------
x_grad: tensor
The gradient of features.
loss_v: tensor
Loss value for cross entropy.
"""
total_label, norm_weight = self.prepare(label, optimizer)
total_features = torch.zeros(
size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
total_features.requires_grad = True
logits = self.forward(total_features, norm_weight)
logits = self.margin_softmax(logits, total_label)
with torch.no_grad():
max_fc = torch.max(logits, dim=1, keepdim=True)[0]
dist.all_reduce(max_fc, dist.ReduceOp.MAX)
# calculate exp(logits) and all-reduce
logits_exp = torch.exp(logits - max_fc)
logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
# calculate prob
logits_exp.div_(logits_sum_exp)
# get one-hot
grad = logits_exp
index = torch.where(total_label != -1)[0]
one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
one_hot.scatter_(1, total_label[index, None], 1)
# calculate loss
loss = torch.zeros(grad.size()[0], 1, device=grad.device)
loss[index] = grad[index].gather(1, total_label[index, None])
dist.all_reduce(loss, dist.ReduceOp.SUM)
loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
# calculate grad
grad[index] -= one_hot
grad.div_(self.batch_size * self.world_size)
logits.backward(grad)
if total_features.grad is not None:
total_features.grad.detach_()
x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
# feature gradient all-reduce
dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
x_grad = x_grad * self.world_size
# backward backbone
return x_grad, loss_v
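# A minimal single-process sketch of the numerically stable softmax cross-entropy
# that forward_backward() above assembles with all-reduces; the `logits`/`labels`
# names and sizes are illustrative, and the result is checked against F.cross_entropy.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # (batch, num_classes)
labels = torch.randint(0, 10, (4,))   # ground-truth class ids

max_fc = logits.max(dim=1, keepdim=True)[0]      # subtract the row max for stability
prob = torch.exp(logits - max_fc)
prob = prob / prob.sum(dim=1, keepdim=True)
loss_manual = -prob.gather(1, labels[:, None]).clamp_min(1e-30).log().mean()
assert torch.allclose(loss_manual, F.cross_entropy(logits, labels), atol=1e-6)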
tensorboard
easydict
mxnet
onnx
scikit-learn
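# Launch single-node training with one process per GPU via torch.distributed.launch;
# the positional argument selects the config module under configs/.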
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
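# Force-kill any leftover training processes from a previous run (use with care).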
ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
import numpy as np
import onnx
import torch
def convert_onnx(net, path_module, output, opset=11, simplify=False):
assert isinstance(net, torch.nn.Module)
img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
img = img.astype(np.float32)  # np.float was removed from NumPy; use an explicit dtype
img = (img / 255. - 0.5) / 0.5 # torch style norm
img = img.transpose((2, 0, 1))
img = torch.from_numpy(img).unsqueeze(0).float()
weight = torch.load(path_module)
net.load_state_dict(weight)
net.eval()
torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
model = onnx.load(output)
graph = model.graph
graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
if simplify:
from onnxsim import simplify
model, check = simplify(model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model, output)
if __name__ == '__main__':
import os
import argparse
from backbones import get_model
parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
parser.add_argument('input', type=str, help='input backbone.pth file or path')
parser.add_argument('--output', type=str, default=None, help='output onnx path')
parser.add_argument('--network', type=str, default=None, help='backbone network')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
args = parser.parse_args()
input_file = args.input
if os.path.isdir(input_file):
input_file = os.path.join(input_file, "backbone.pth")
assert os.path.exists(input_file)
model_name = os.path.basename(os.path.dirname(input_file)).lower()
params = model_name.split("_")
if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
if args.network is None:
args.network = params[2]
assert args.network is not None
print(args)
backbone_onnx = get_model(args.network, dropout=0)
output_path = args.output
if output_path is None:
output_path = os.path.join(os.path.dirname(__file__), 'onnx')
if not os.path.exists(output_path):
os.makedirs(output_path)
assert os.path.isdir(output_path)
output_file = os.path.join(output_path, "%s.onnx" % model_name)
convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
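# A minimal sketch of verifying an exported model, assuming onnxruntime is
# installed (it is not listed in the requirements above); the model path and
# the (1, 512) embedding size are illustrative assumptions.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx/ms1mv3_arcface_r50.onnx")  # hypothetical export path
x = np.random.randn(1, 3, 112, 112).astype(np.float32)      # stands in for a normalized face crop
feat = sess.run(None, {sess.get_inputs()[0].name: x})[0]
print(feat.shape)  # expected (1, 512) for embedding_size=512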
import argparse
import logging
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.data.distributed
from torch.nn.utils import clip_grad_norm_
import losses
from backbones import get_model
from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
from partial_fc import PartialFC
from utils.utils_amp import MaxClipGradScaler
from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
from utils.utils_config import get_config
from utils.utils_logging import AverageMeter, init_logging
def main(args):
cfg = get_config(args.config)
try:
world_size = int(os.environ['WORLD_SIZE'])
rank = int(os.environ['RANK'])
dist.init_process_group('nccl')
except KeyError:
world_size = 1
rank = 0
dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
local_rank = args.local_rank
torch.cuda.set_device(local_rank)
os.makedirs(cfg.output, exist_ok=True)
init_logging(rank, cfg.output)
if cfg.rec == "synthetic":
train_set = SyntheticDataset(local_rank=local_rank)
else:
train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
train_loader = DataLoaderX(
local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
if cfg.resume:
try:
backbone_pth = os.path.join(cfg.output, "backbone.pth")
backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
if rank == 0:
logging.info("backbone resume successfully!")
except (FileNotFoundError, KeyError, IndexError, RuntimeError):
if rank == 0:
logging.info("resume fail, backbone init successfully!")
backbone = torch.nn.parallel.DistributedDataParallel(
module=backbone, broadcast_buffers=False, device_ids=[local_rank])
backbone.train()
margin_softmax = losses.get_loss(cfg.loss)
module_partial_fc = PartialFC(
rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
opt_backbone = torch.optim.SGD(
params=[{'params': backbone.parameters()}],
lr=cfg.lr / 512 * cfg.batch_size * world_size,
momentum=0.9, weight_decay=cfg.weight_decay)
opt_pfc = torch.optim.SGD(
params=[{'params': module_partial_fc.parameters()}],
lr=cfg.lr / 512 * cfg.batch_size * world_size,
momentum=0.9, weight_decay=cfg.weight_decay)
num_image = len(train_set)
total_batch_size = cfg.batch_size * world_size
cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
cfg.total_step = num_image // total_batch_size * cfg.num_epoch
def lr_step_func(current_step):
cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
if current_step < cfg.warmup_step:
return current_step / cfg.warmup_step
else:
return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])
scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
optimizer=opt_backbone, lr_lambda=lr_step_func)
scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
optimizer=opt_pfc, lr_lambda=lr_step_func)
for key, value in cfg.items():
num_space = 25 - len(key)
logging.info(": " + key + " " * num_space + str(value))
val_target = cfg.val_targets
callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
loss = AverageMeter()
start_epoch = 0
global_step = 0
grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
for epoch in range(start_epoch, cfg.num_epoch):
train_sampler.set_epoch(epoch)
for step, (img, label) in enumerate(train_loader):
global_step += 1
features = F.normalize(backbone(img))
x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
if cfg.fp16:
features.backward(grad_amp.scale(x_grad))
grad_amp.unscale_(opt_backbone)
clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
grad_amp.step(opt_backbone)
grad_amp.update()
else:
features.backward(x_grad)
clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
opt_backbone.step()
opt_pfc.step()
module_partial_fc.update()
opt_backbone.zero_grad()
opt_pfc.zero_grad()
loss.update(loss_v, 1)
callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
callback_verification(global_step, backbone)
scheduler_backbone.step()
scheduler_pfc.step()
callback_checkpoint(global_step, backbone, module_partial_fc)
dist.destroy_process_group()
if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
parser.add_argument('config', type=str, help='py config file')
parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
main(parser.parse_args())
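# A standalone sketch of the warmup + step-decay multiplier implemented by
# lr_step_func above; the step counts here are illustrative assumptions.
warmup_step, decay_step = 1000, [8000, 14000]

def lr_mult(step):
    if step < warmup_step:
        return step / warmup_step                             # linear warmup
    return 0.1 ** len([m for m in decay_step if m <= step])   # 10x decay per milestone passed

assert lr_mult(500) == 0.5
assert lr_mult(9000) == 0.1
assert abs(lr_mult(15000) - 0.01) < 1e-12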
# coding: utf-8
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
from prettytable import PrettyTable
from sklearn.metrics import roc_curve, auc
image_path = "/data/anxiang/IJB_release/IJBC"
files = [
"./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
]
def read_template_pair_list(path):
pairs = pd.read_csv(path, sep=' ', header=None).values
t1 = pairs[:, 0].astype(int)  # np.int was removed from NumPy; use the builtin
t2 = pairs[:, 1].astype(int)
label = pairs[:, 2].astype(int)
return t1, t2, label
p1, p2, label = read_template_pair_list(
os.path.join('%s/meta' % image_path,
'%s_template_pair_label.txt' % 'ijbc'))
methods = []
scores = []
for file in files:
methods.append(file.split('/')[-2])
scores.append(np.load(file))
methods = np.array(methods)
scores = dict(zip(methods, scores))
colours = dict(
zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
fig = plt.figure()
for method in methods:
fpr, tpr, _ = roc_curve(label, scores[method])
roc_auc = auc(fpr, tpr)
fpr = np.flipud(fpr)
tpr = np.flipud(tpr) # select largest tpr at same fpr
plt.plot(fpr,
tpr,
color=colours[method],
lw=1,
label=('[%s (AUC = %0.4f %%)]' %
(method.split('-')[-1], roc_auc * 100)))
tpr_fpr_row = []
tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
for fpr_iter in np.arange(len(x_labels)):
_, min_index = min(
list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
tpr_fpr_table.add_row(tpr_fpr_row)
plt.xlim([10 ** -6, 0.1])
plt.ylim([0.3, 1.0])
plt.grid(linestyle='--', linewidth=1)
plt.xticks(x_labels)
plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
plt.xscale('log')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC on IJB')
plt.legend(loc="lower right")
print(tpr_fpr_table)
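# A minimal sketch of how the table loop above reads TPR at a target FPR from
# sklearn's roc_curve; the labels and scores here are toy values.
import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr, _ = roc_curve(y_true, y_score)
idx = np.argmin(np.abs(fpr - 1e-1))   # closest operating point to FPR = 0.1
print('TPR@FPR=0.1: %.2f' % tpr[idx])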
from typing import Dict, List
import torch
try:
    # torch._six.container_abcs was removed in newer PyTorch releases; the old
    # string comparison against '1.9' also misfires for versions like '1.10'.
    Iterable = torch._six.container_abcs.Iterable
except AttributeError:
    import collections.abc
    Iterable = collections.abc.Iterable
from torch.cuda.amp import GradScaler
class _MultiDeviceReplicator(object):
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
class MaxClipGradScaler(GradScaler):
def __init__(self, init_scale, max_scale: float, growth_interval=100):
GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
self.max_scale = max_scale
def scale_clip(self):
if self.get_scale() == self.max_scale:
self.set_growth_factor(1)
elif self.get_scale() < self.max_scale:
self.set_growth_factor(2)
elif self.get_scale() > self.max_scale:
self._scale.fill_(self.max_scale)
self.set_growth_factor(1)
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Arguments:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
self.scale_clip()
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
assert outputs.is_cuda
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val):
if isinstance(val, torch.Tensor):
assert val.is_cuda
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, Iterable):
iterable = map(apply_scale, val)
if isinstance(val, list) or isinstance(val, tuple):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)
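# A minimal sketch of driving MaxClipGradScaler above in a mixed-precision step,
# mirroring the fp16 branch of train.py; the toy model is an assumption and a
# CUDA device is required by GradScaler.
if torch.cuda.is_available():
    model = torch.nn.Linear(8, 2).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    amp = MaxClipGradScaler(init_scale=64, max_scale=1024)
    loss = model(torch.randn(4, 8, device='cuda')).sum()
    amp.scale(loss).backward()  # scale() also caps the growth of the loss scale
    amp.step(opt)
    amp.update()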
import logging
import os
import time
from typing import List
import torch
from eval import verification
from utils.utils_logging import AverageMeter
class CallBackVerification(object):
def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
self.frequent: int = frequent
self.rank: int = rank
self.highest_acc: float = 0.0
self.highest_acc_list: List[float] = [0.0] * len(val_targets)
self.ver_list: List[object] = []
self.ver_name_list: List[str] = []
if self.rank == 0:
self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
def ver_test(self, backbone: torch.nn.Module, global_step: int):
results = []
for i in range(len(self.ver_list)):
acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
self.ver_list[i], backbone, 10, 10)
logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
if acc2 > self.highest_acc_list[i]:
self.highest_acc_list[i] = acc2
logging.info(
'[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
results.append(acc2)
def init_dataset(self, val_targets, data_dir, image_size):
for name in val_targets:
path = os.path.join(data_dir, name + ".bin")
if os.path.exists(path):
data_set = verification.load_bin(path, image_size)
self.ver_list.append(data_set)
self.ver_name_list.append(name)
def __call__(self, num_update, backbone: torch.nn.Module):
if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
backbone.eval()
self.ver_test(backbone, num_update)
backbone.train()
class CallBackLogging(object):
def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
self.frequent: int = frequent
self.rank: int = rank
self.time_start = time.time()
self.total_step: int = total_step
self.batch_size: int = batch_size
self.world_size: int = world_size
self.writer = writer
self.init = False
self.tic = 0
def __call__(self,
global_step: int,
loss: AverageMeter,
epoch: int,
fp16: bool,
learning_rate: float,
grad_scaler: torch.cuda.amp.GradScaler):
if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
if self.init:
try:
speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
speed_total = speed * self.world_size
except ZeroDivisionError:
speed_total = float('inf')
time_now = (time.time() - self.time_start) / 3600
time_total = time_now / ((global_step + 1) / self.total_step)
time_for_end = time_total - time_now
if self.writer is not None:
self.writer.add_scalar('time_for_end', time_for_end, global_step)
self.writer.add_scalar('learning_rate', learning_rate, global_step)
self.writer.add_scalar('loss', loss.avg, global_step)
if fp16:
msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
"Fp16 Grad Scale: %2.f Required: %1.f hours" % (
speed_total, loss.avg, learning_rate, epoch, global_step,
grad_scaler.get_scale(), time_for_end
)
else:
msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \
"Required: %1.f hours" % (
speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
)
logging.info(msg)
loss.reset()
self.tic = time.time()
else:
self.init = True
self.tic = time.time()
class CallBackModelCheckpoint(object):
def __init__(self, rank, output="./"):
self.rank: int = rank
self.output: str = output
def __call__(self, global_step, backbone, partial_fc):
if global_step > 100 and self.rank == 0:
path_module = os.path.join(self.output, "backbone.pth")
torch.save(backbone.module.state_dict(), path_module)
logging.info("Pytorch Model Saved in '{}'".format(path_module))
if global_step > 100 and partial_fc is not None:
partial_fc.save_params()
import importlib
import os.path as osp
def get_config(config_file):
assert config_file.startswith('configs/'), 'config file setting must start with configs/'
temp_config_name = osp.basename(config_file)
temp_module_name = osp.splitext(temp_config_name)[0]
config = importlib.import_module("configs.base")
cfg = config.config
config = importlib.import_module("configs.%s" % temp_module_name)
job_cfg = config.config
cfg.update(job_cfg)
if cfg.output is None:
cfg.output = osp.join('work_dirs', temp_module_name)
return cfg
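# An illustrative job config (e.g. configs/ms1mv3_r50.py) with the fields that
# train.py reads via get_config(); the values below are assumptions, not the
# commit's actual settings. configs.base is assumed to expose an easydict `config`.
from easydict import EasyDict as edict

config = edict()
config.loss = 'arcface'
config.network = 'r50'
config.resume = False
config.output = None            # get_config() falls back to work_dirs/<name>
config.embedding_size = 512
config.sample_rate = 1.0        # 1.0 disables partial fc sampling
config.fp16 = True
config.batch_size = 128
config.lr = 0.1                 # scaled by total batch size in train.py
config.rec = 'synthetic'        # or the root dir of an MXNet record dataset
config.num_classes = 93431
config.num_epoch = 25
config.warmup_epoch = 1
config.decay_epoch = [10, 16, 22]
config.weight_decay = 5e-4
config.val_targets = ['lfw', 'cfp_fp', 'agedb_30']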
import logging
import os
import sys
class AverageMeter(object):
"""Computes and stores the average and current value
"""
def __init__(self):
self.val = None
self.avg = None
self.sum = None
self.count = None
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def init_logging(rank, models_root):
if rank == 0:
log_root = logging.getLogger()
log_root.setLevel(logging.INFO)
formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
handler_stream = logging.StreamHandler(sys.stdout)
handler_file.setFormatter(formatter)
handler_stream.setFormatter(formatter)
log_root.addHandler(handler_file)
log_root.addHandler(handler_stream)
log_root.info('rank_id: %d' % rank)
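# A minimal usage sketch of AverageMeter above: a running average over updates.
meter = AverageMeter()
for v in (1.0, 2.0, 3.0):
    meter.update(v, n=1)
assert meter.avg == 2.0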
"""This script defines the base network model for Deep3DFaceRecon_pytorch
"""
import os
import numpy as np
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)>
Then, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define the networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt = opt
self.isTrain = False
self.device = torch.device('cpu')
self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
self.loss_names = []
self.model_names = []
self.visual_names = []
self.parallel_names = []
self.optimizers = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
@staticmethod
def dict_grad_hook_factory(add_func=lambda x: x):
saved_dict = dict()
def hook_gen(name):
def grad_hook(grad):
saved_vals = add_func(grad)
saved_dict[name] = saved_vals
return grad_hook
return hook_gen, saved_dict
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
@abstractmethod
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def setup(self, opt):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
if not self.isTrain or opt.continue_train:
load_suffix = opt.epoch
self.load_networks(load_suffix)
# self.print_networks(opt.verbose)
def parallelize(self, convert_sync_batchnorm=True):
if not self.opt.use_ddp:
for name in self.parallel_names:
if isinstance(name, str):
module = getattr(self, name)
setattr(self, name, module.to(self.device))
else:
for name in self.model_names:
if isinstance(name, str):
module = getattr(self, name)
if convert_sync_batchnorm:
module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
device_ids=[self.device.index],
find_unused_parameters=True, broadcast_buffers=True))
# DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
for name in self.parallel_names:
if isinstance(name, str) and name not in self.model_names:
module = getattr(self, name)
setattr(self, name, module.to(self.device))
# put state_dict of optimizer to gpu device
if self.opt.phase != 'test':
if self.opt.continue_train:
for optim in self.optimizers:
for state in optim.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.device)
def data_dependent_initialize(self, data):
pass
def train(self):
"""Make models train mode"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
net.train()
def eval(self):
"""Make models eval mode"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
net.eval()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization"""
pass
def get_image_paths(self, name='A'):
""" Return image paths that are used to load current data"""
return self.image_paths if name =='A' else self.image_paths_B
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
for scheduler in self.schedulers:
if self.opt.lr_policy == 'plateau':
scheduler.step(self.metric)
else:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)[:, :3, ...]
return visual_ret
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
return errors_ret
def save_networks(self, epoch):
"""Save all the networks to the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
if not os.path.isdir(self.save_dir):
os.makedirs(self.save_dir)
save_filename = 'epoch_%s.pth' % (epoch)
save_path = os.path.join(self.save_dir, save_filename)
save_dict = {}
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
if isinstance(net, torch.nn.DataParallel) or isinstance(net,
torch.nn.parallel.DistributedDataParallel):
net = net.module
save_dict[name] = net.state_dict()
for i, optim in enumerate(self.optimizers):
save_dict['opt_%02d'%i] = optim.state_dict()
for i, sched in enumerate(self.schedulers):
save_dict['sched_%02d'%i] = sched.state_dict()
torch.save(save_dict, save_path)
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def load_networks(self, epoch):
"""Load all the networks from the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
if self.opt.isTrain and self.opt.pretrained_name is not None:
load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
else:
load_dir = self.save_dir
load_filename = 'epoch_%s.pth' % (epoch)
load_path = os.path.join(load_dir, load_filename)
state_dict = torch.load(load_path, map_location=self.device)
print('loading the model from %s' % load_path)
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
net.load_state_dict(state_dict[name])
if self.opt.phase != 'test':
if self.opt.continue_train:
print('loading the optim from %s' % load_path)
for i, optim in enumerate(self.optimizers):
optim.load_state_dict(state_dict['opt_%02d'%i])
try:
print('loading the sched from %s' % load_path)
for i, sched in enumerate(self.schedulers):
sched.load_state_dict(state_dict['sched_%02d'%i])
except Exception:
print('Failed to load schedulers, set schedulers according to epoch count manually')
for i, sched in enumerate(self.schedulers):
sched.last_epoch = self.opt.epoch_count - 1
def print_networks(self, verbose):
"""Print the total number of parameters in the network and (if verbose) network architecture
Parameters:
verbose (bool) -- if verbose: print the network architecture
"""
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
def generate_visuals_for_evaluation(self, data, mode):
return {}
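# A minimal illustrative BaseModel subclass wiring up the abstract methods
# listed in the class docstring above; DummyModel and its one-layer "network"
# are assumptions for demonstration only.
class DummyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.model_names = ['net_dummy']
        self.loss_names = ['dummy']
        self.net_dummy = torch.nn.Linear(4, 1)
        self.optimizers = [torch.optim.SGD(self.net_dummy.parameters(), lr=0.01)]

    def set_input(self, input):
        self.x = input['x'].to(self.device)

    def forward(self):
        self.pred = self.net_dummy(self.x)

    def optimize_parameters(self):
        self.forward()
        self.loss_dummy = self.pred.pow(2).mean()
        self.optimizers[0].zero_grad()
        self.loss_dummy.backward()
        self.optimizers[0].step()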
"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
"""
import numpy as np
import torch
import torch.nn.functional as F
from scipy.io import loadmat
from src.face3d.util.load_mats import transferBFM09
import os
def perspective_projection(focal, center):
# returns the transposed intrinsic matrix K^T, so projected points are computed as pts (N, 3) @ K^T (3, 3)
return np.array([
focal, 0, center,
0, focal, center,
0, 0, 1
]).reshape([3, 3]).astype(np.float32).transpose()
class SH:
def __init__(self):
self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
class ParametricFaceModel:
def __init__(self,
bfm_folder='./BFM',
recenter=True,
camera_distance=10.,
init_lit=np.array([
0.8, 0, 0, 0, 0, 0, 0, 0, 0
]),
focal=1015.,
center=112.,
is_train=True,
default_name='BFM_model_front.mat'):
if not os.path.isfile(os.path.join(bfm_folder, default_name)):
transferBFM09(bfm_folder)
model = loadmat(os.path.join(bfm_folder, default_name))
# mean face shape. [3*N,1]
self.mean_shape = model['meanshape'].astype(np.float32)
# identity basis. [3*N,80]
self.id_base = model['idBase'].astype(np.float32)
# expression basis. [3*N,64]
self.exp_base = model['exBase'].astype(np.float32)
# mean face texture. [3*N,1] (0-255)
self.mean_tex = model['meantex'].astype(np.float32)
# texture basis. [3*N,80]
self.tex_base = model['texBase'].astype(np.float32)
# indices of the faces adjacent to each vertex. starts from 0. [N,8]
self.point_buf = model['point_buf'].astype(np.int64) - 1
# vertex indices for each face. starts from 0. [F,3]
self.face_buf = model['tri'].astype(np.int64) - 1
# vertex indices for 68 landmarks. starts from 0. [68,1]
self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
if is_train:
# vertex indices for small face region to compute photometric error. starts from 0.
self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
# vertex indices for each face from small face region. starts from 0. [f,3]
self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
# vertex indices for pre-defined skin region to compute reflectance loss
self.skin_mask = np.squeeze(model['skinmask'])
if recenter:
mean_shape = self.mean_shape.reshape([-1, 3])
mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
self.mean_shape = mean_shape.reshape([-1, 1])
self.persc_proj = perspective_projection(focal, center)
self.device = 'cpu'
self.camera_distance = camera_distance
self.SH = SH()
self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
def to(self, device):
self.device = device
for key, value in self.__dict__.items():
if type(value).__module__ == np.__name__:
setattr(self, key, torch.tensor(value).to(device))
def compute_shape(self, id_coeff, exp_coeff):
"""
Return:
face_shape -- torch.tensor, size (B, N, 3)
Parameters:
id_coeff -- torch.tensor, size (B, 80), identity coeffs
exp_coeff -- torch.tensor, size (B, 64), expression coeffs
"""
batch_size = id_coeff.shape[0]
id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
return face_shape.reshape([batch_size, -1, 3])
def compute_texture(self, tex_coeff, normalize=True):
"""
Return:
face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
Parameters:
tex_coeff -- torch.tensor, size (B, 80)
"""
batch_size = tex_coeff.shape[0]
face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
if normalize:
face_texture = face_texture / 255.
return face_texture.reshape([batch_size, -1, 3])
def compute_norm(self, face_shape):
"""
Return:
vertex_norm -- torch.tensor, size (B, N, 3)
Parameters:
face_shape -- torch.tensor, size (B, N, 3)
"""
v1 = face_shape[:, self.face_buf[:, 0]]
v2 = face_shape[:, self.face_buf[:, 1]]
v3 = face_shape[:, self.face_buf[:, 2]]
e1 = v1 - v2
e2 = v2 - v3
face_norm = torch.cross(e1, e2, dim=-1)
face_norm = F.normalize(face_norm, dim=-1, p=2)
face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
return vertex_norm
def compute_color(self, face_texture, face_norm, gamma):
"""
Return:
face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
Parameters:
face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
face_norm -- torch.tensor, size (B, N, 3), rotated face normal
gamma -- torch.tensor, size (B, 27), SH coeffs
"""
batch_size = gamma.shape[0]
v_num = face_texture.shape[1]
a, c = self.SH.a, self.SH.c
gamma = gamma.reshape([batch_size, 3, 9])
gamma = gamma + self.init_lit
gamma = gamma.permute(0, 2, 1)
Y = torch.cat([
a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
-a[1] * c[1] * face_norm[..., 1:2],
a[1] * c[1] * face_norm[..., 2:],
-a[1] * c[1] * face_norm[..., :1],
a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
-a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
-a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
], dim=-1)
r = Y @ gamma[..., :1]
g = Y @ gamma[..., 1:2]
b = Y @ gamma[..., 2:]
face_color = torch.cat([r, g, b], dim=-1) * face_texture
return face_color
def compute_rotation(self, angles):
"""
Return:
rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
Parameters:
angles -- torch.tensor, size (B, 3), radian
"""
batch_size = angles.shape[0]
ones = torch.ones([batch_size, 1]).to(self.device)
zeros = torch.zeros([batch_size, 1]).to(self.device)
x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
rot_x = torch.cat([
ones, zeros, zeros,
zeros, torch.cos(x), -torch.sin(x),
zeros, torch.sin(x), torch.cos(x)
], dim=1).reshape([batch_size, 3, 3])
rot_y = torch.cat([
torch.cos(y), zeros, torch.sin(y),
zeros, ones, zeros,
-torch.sin(y), zeros, torch.cos(y)
], dim=1).reshape([batch_size, 3, 3])
rot_z = torch.cat([
torch.cos(z), -torch.sin(z), zeros,
torch.sin(z), torch.cos(z), zeros,
zeros, zeros, ones
], dim=1).reshape([batch_size, 3, 3])
rot = rot_z @ rot_y @ rot_x
return rot.permute(0, 2, 1)
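# Note: the transpose above means rotated points follow the row-vector
# convention used throughout this file, i.e. transform() below computes
# face_shape @ rot + trans.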
def to_camera(self, face_shape):
face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
return face_shape
def to_image(self, face_shape):
"""
Return:
face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
Parameters:
face_shape -- torch.tensor, size (B, N, 3)
"""
# to image_plane
face_proj = face_shape @ self.persc_proj
face_proj = face_proj[..., :2] / face_proj[..., 2:]
return face_proj
def transform(self, face_shape, rot, trans):
"""
Return:
face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
Parameters:
face_shape -- torch.tensor, size (B, N, 3)
rot -- torch.tensor, size (B, 3, 3)
trans -- torch.tensor, size (B, 3)
"""
return face_shape @ rot + trans.unsqueeze(1)
def get_landmarks(self, face_proj):
"""
Return:
face_lms -- torch.tensor, size (B, 68, 2)
Parameters:
face_proj -- torch.tensor, size (B, N, 2)
"""
return face_proj[:, self.keypoints]
def split_coeff(self, coeffs):
"""
Return:
coeffs_dict -- a dict of torch.tensors
Parameters:
coeffs -- torch.tensor, size (B, 257)
"""
id_coeffs = coeffs[:, :80]
exp_coeffs = coeffs[:, 80: 144]
tex_coeffs = coeffs[:, 144: 224]
angles = coeffs[:, 224: 227]
gammas = coeffs[:, 227: 254]
translations = coeffs[:, 254:]
return {
'id': id_coeffs,
'exp': exp_coeffs,
'tex': tex_coeffs,
'angle': angles,
'gamma': gammas,
'trans': translations
}
def compute_for_render(self, coeffs):
"""
Return:
face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
face_color -- torch.tensor, size (B, N, 3), in RGB order
landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
Parameters:
coeffs -- torch.tensor, size (B, 257)
"""
coef_dict = self.split_coeff(coeffs)
face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
rotation = self.compute_rotation(coef_dict['angle'])
face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
face_vertex = self.to_camera(face_shape_transformed)
face_proj = self.to_image(face_vertex)
landmark = self.get_landmarks(face_proj)
face_texture = self.compute_texture(coef_dict['tex'])
face_norm = self.compute_norm(face_shape)
face_norm_roted = face_norm @ rotation
face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
return face_vertex, face_texture, face_color, landmark
def compute_for_render_woRotation(self, coeffs):
"""
Return:
face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
face_color -- torch.tensor, size (B, N, 3), in RGB order
landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
Parameters:
coeffs -- torch.tensor, size (B, 257)
"""
coef_dict = self.split_coeff(coeffs)
face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
#rotation = self.compute_rotation(coef_dict['angle'])
#face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
face_vertex = self.to_camera(face_shape)
face_proj = self.to_image(face_vertex)
landmark = self.get_landmarks(face_proj)
face_texture = self.compute_texture(coef_dict['tex'])
face_norm = self.compute_norm(face_shape)
face_norm_roted = face_norm # @ rotation
face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
return face_vertex, face_texture, face_color, landmark
if __name__ == '__main__':
transferBFM09()
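# A minimal sanity check of the coefficient layout sliced by split_coeff() above:
# 80 id + 64 exp + 80 tex + 3 angle + 27 gamma + 3 trans = 257.
parts = {'id': 80, 'exp': 64, 'tex': 80, 'angle': 3, 'gamma': 27, 'trans': 3}
assert sum(parts.values()) == 257  # matches the (B, 257) input of compute_for_render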
"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
"""
import numpy as np
import torch
from src.face3d.models.base_model import BaseModel
from src.face3d.models import networks
from src.face3d.models.bfm import ParametricFaceModel
from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
from src.face3d.util import util
from src.face3d.util.nvdiffrast import MeshRenderer
from src.face3d.util.preprocess import estimate_norm_torch  # needed by compute_losses() when use_predef_M is False
import trimesh
from scipy.io import savemat
class FaceReconModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=False):
""" Configures options specific for CUT model
"""
# net structure and parameters
parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth')
parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
# renderer parameters
parser.add_argument('--focal', type=float, default=1015.)
parser.add_argument('--center', type=float, default=112.)
parser.add_argument('--camera_d', type=float, default=10.)
parser.add_argument('--z_near', type=float, default=5.)
parser.add_argument('--z_far', type=float, default=15.)
if is_train:
# training parameters
parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r34', 'r50'], help='face recog network structure')
parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
# augmentation parameters
parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
# loss weights
parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
parser.add_argument('--w_color', type=float, default=1.92, help='weight for color (photometric) loss')
parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
opt, _ = parser.parse_known_args()
parser.set_defaults(
focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
)
if is_train:
parser.set_defaults(
use_crop_face=True, use_predef_M=False
)
return parser
def __init__(self, opt):
"""Initialize this model class.
Parameters:
opt -- training/test options
A few things can be done here.
- (required) call the initialization function of BaseModel
- define loss function, visualization images, model names, and optimizers
"""
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
self.visual_names = ['output_vis']
self.model_names = ['net_recon']
self.parallel_names = self.model_names + ['renderer']
self.facemodel = ParametricFaceModel(
bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
is_train=self.isTrain, default_name=opt.bfm_model
)
fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
self.renderer = MeshRenderer(
rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
)
if self.isTrain:
self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
self.net_recog = networks.define_net_recog(
net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
)
# loss func name: (compute_%s_loss) % loss_name
self.compute_feat_loss = perceptual_loss
self.compute_color_loss = photo_loss
self.compute_lm_loss = landmark_loss
self.compute_reg_loss = reg_loss
self.compute_reflc_loss = reflectance_loss
self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
self.optimizers = [self.optimizer]
self.parallel_names += ['net_recog']
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
self.input_img = input['imgs'].to(self.device)
self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
self.trans_m = input['M'].to(self.device) if 'M' in input else None
self.image_paths = input['im_paths'] if 'im_paths' in input else None
def forward(self, output_coeff, device):
self.facemodel.to(device)
self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
self.facemodel.compute_for_render(output_coeff)
self.pred_mask, _, self.pred_face = self.renderer(
self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
def compute_losses(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
assert self.net_recog.training == False
trans_m = self.trans_m
if not self.opt.use_predef_M:
trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
pred_feat = self.net_recog(self.pred_face, trans_m)
gt_feat = self.net_recog(self.input_img, self.trans_m)
self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
face_mask = self.pred_mask
if self.opt.use_crop_face:
face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
face_mask = face_mask.detach()
self.loss_color = self.opt.w_color * self.compute_color_loss(
self.pred_face, self.input_img, self.atten_mask * face_mask)
loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
self.loss_reg = self.opt.w_reg * loss_reg
self.loss_gamma = self.opt.w_gamma * loss_gamma
self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
+ self.loss_lm + self.loss_reflc
def optimize_parameters(self, isTrain=True):
"""Update network weights; called in every training iteration."""
self.forward()
self.compute_losses()
if isTrain:
self.optimizer.zero_grad()
self.loss_all.backward()
self.optimizer.step()
def compute_visuals(self):
with torch.no_grad():
input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
if self.gt_lm is not None:
gt_lm_numpy = self.gt_lm.cpu().numpy()
pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
output_vis_numpy = np.concatenate((input_img_numpy,
output_vis_numpy_raw, output_vis_numpy), axis=-2)
else:
output_vis_numpy = np.concatenate((input_img_numpy,
output_vis_numpy_raw), axis=-2)
self.output_vis = torch.tensor(
output_vis_numpy / 255., dtype=torch.float32
).permute(0, 3, 1, 2).to(self.device)
def save_mesh(self, name):
recon_shape = self.pred_vertex # get reconstructed shape
recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
recon_shape = recon_shape.cpu().numpy()[0]
recon_color = self.pred_color
recon_color = recon_color.cpu().numpy()[0]
tri = self.facemodel.face_buf.cpu().numpy()
mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
mesh.export(name)
def save_coeff(self,name):
pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
pred_lm = self.pred_lm.cpu().numpy()
pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate
pred_coeffs['lm68'] = pred_lm
savemat(name,pred_coeffs)
import numpy as np
import torch
import torch.nn as nn
from kornia.geometry import warp_affine
import torch.nn.functional as F
def resize_n_crop(image, M, dsize=112):
# image: (b, c, h, w)
# M : (b, 2, 3)
return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
### perceptual level loss
class PerceptualLoss(nn.Module):
def __init__(self, recog_net, input_size=112):
super(PerceptualLoss, self).__init__()
self.recog_net = recog_net
self.preprocess = lambda x: 2 * x - 1
self.input_size=input_size
def forward(self, imageA, imageB, M):
"""
1 - cosine distance
Parameters:
imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
imageB --same as imageA
"""
imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
# freeze bn
self.recog_net.eval()
id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
# assert torch.sum((cosine_d > 1).float()) == 0
return torch.sum(1 - cosine_d) / cosine_d.shape[0]
def perceptual_loss(id_featureA, id_featureB):
cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
# assert torch.sum((cosine_d > 1).float()) == 0
return torch.sum(1 - cosine_d) / cosine_d.shape[0]
### image level loss
def photo_loss(imageA, imageB, mask, eps=1e-6):
"""
l2 norm (with sqrt; eps ensures backward stability, otherwise NaN may occur)
Parameters:
imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
imageB --same as imageA
"""
loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
return loss
def landmark_loss(predict_lm, gt_lm, weight=None):
"""
weighted mse loss
Parameters:
predict_lm --torch.tensor (B, 68, 2)
gt_lm --torch.tensor (B, 68, 2)
weight --numpy.array (1, 68)
"""
if weight is None:
weight = np.ones([68])
weight[28:31] = 20
weight[-8:] = 20
weight = np.expand_dims(weight, 0)
weight = torch.tensor(weight).to(predict_lm.device)
loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
return loss
### regularization
def reg_loss(coeffs_dict, opt=None):
"""
l2 norm without the sqrt, from yu's implementation (mse)
tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
Parameters:
coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
"""
# coefficient regularization to ensure plausible 3d faces
if opt:
w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
else:
w_id, w_exp, w_tex = 1, 1, 1
creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
w_tex * torch.sum(coeffs_dict['tex'] ** 2)
creg_loss = creg_loss / coeffs_dict['id'].shape[0]
# gamma regularization to ensure a nearly-monochromatic light
gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
    gamma_mean = torch.mean(gamma, dim=1, keepdim=True)
gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
return creg_loss, gamma_loss
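# Sketch: gamma holds 27 spherical-harmonics lighting coefficients, reshaped to
# 3 color channels x 9 bands; identical RGB channels therefore give a zero
# gamma penalty. The dummy coefficients are assumptions for illustration.
def _demo_reg_loss():
    coeffs = {'id': torch.rand(2, 80), 'exp': torch.rand(2, 64),
              'tex': torch.rand(2, 80),
              'gamma': torch.rand(2, 9).repeat(1, 3)}  # same 9 bands per channel
    creg_loss, gamma_loss = reg_loss(coeffs)
    assert gamma_loss.item() < 1e-10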
def reflectance_loss(texture, mask):
"""
    minimize texture variance (mse); albedo regularization to ensure a uniform skin albedo
Parameters:
texture --torch.tensor, (B, N, 3)
mask --torch.tensor, (N), 1 or 0
"""
mask = mask.reshape([1, mask.shape[0], 1])
    texture_mean = torch.sum(mask * texture, dim=1, keepdim=True) / torch.sum(mask)
loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
return loss
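# Sketch: a spatially constant albedo inside the mask incurs no penalty, which
# is exactly what the regularizer rewards. Dummy shapes for illustration.
def _demo_reflectance_loss():
    texture = torch.full((2, 100, 3), 0.5)  # constant skin albedo
    mask = torch.ones(100)
    assert reflectance_loss(texture, mask).item() < 1e-10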
"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
"""
import os
import numpy as np
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import torch
from torch import Tensor
import torch.nn as nn
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional
from .arcface_torch.backbones import get_model
from kornia.geometry import warp_affine
def resize_n_crop(image, M, dsize=112):
# image: (b, c, h, w)
# M : (b, 2, 3)
return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
def filter_state_dict(state_dict, remove_name='fc'):
new_state_dict = {}
for key in state_dict:
if remove_name in key:
continue
new_state_dict[key] = state_dict[key]
return new_state_dict
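# Sketch: drops the ImageNet classification head before loading backbone-only
# weights; the keys below are assumptions for illustration.
def _demo_filter_state_dict():
    sd = {'conv1.weight': torch.zeros(1), 'fc.weight': torch.zeros(1), 'fc.bias': torch.zeros(1)}
    assert list(filter_state_dict(sd, remove_name='fc')) == ['conv1.weight']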
def get_scheduler(optimizer, opt):
"""Return a learning rate scheduler
Parameters:
optimizer -- the optimizer of the network
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
    For 'linear', the learning rate is kept constant until epoch <opt.n_epochs> and then decayed linearly.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
See https://pytorch.org/docs/stable/optim.html for more details.
"""
if opt.lr_policy == 'linear':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
return scheduler
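# Sketch (assumption: a SimpleNamespace stands in for the option object, with
# only the fields that get_scheduler and lambda_rule actually read).
def _demo_get_scheduler():
    from types import SimpleNamespace
    net = nn.Linear(4, 4)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    opt = SimpleNamespace(lr_policy='linear', epoch_count=1, n_epochs=20)
    scheduler = get_scheduler(optimizer, opt)
    optimizer.step()  # dummy update so the scheduler step follows the expected order
    scheduler.step()  # called once per epoch to anneal the learning rate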
def define_net_recon(net_recon, use_last_fc=False, init_path=None):
return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
def define_net_recog(net_recog, pretrained_path=None):
net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
net.eval()
return net
class ReconNetWrapper(nn.Module):
fc_dim=257
def __init__(self, net_recon, use_last_fc=False, init_path=None):
super(ReconNetWrapper, self).__init__()
self.use_last_fc = use_last_fc
if net_recon not in func_dict:
            raise NotImplementedError('network [%s] is not implemented' % net_recon)
func, last_dim = func_dict[net_recon]
backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
if init_path and os.path.isfile(init_path):
state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
backbone.load_state_dict(state_dict)
print("loading init net_recon %s from %s" %(net_recon, init_path))
self.backbone = backbone
if not use_last_fc:
self.final_layers = nn.ModuleList([
conv1x1(last_dim, 80, bias=True), # id layer
conv1x1(last_dim, 64, bias=True), # exp layer
conv1x1(last_dim, 80, bias=True), # tex layer
conv1x1(last_dim, 3, bias=True), # angle layer
conv1x1(last_dim, 27, bias=True), # gamma layer
conv1x1(last_dim, 2, bias=True), # tx, ty
conv1x1(last_dim, 1, bias=True) # tz
])
for m in self.final_layers:
nn.init.constant_(m.weight, 0.)
nn.init.constant_(m.bias, 0.)
def forward(self, x):
x = self.backbone(x)
if not self.use_last_fc:
output = []
for layer in self.final_layers:
output.append(layer(x))
x = torch.flatten(torch.cat(output, dim=1), 1)
return x
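# Sketch: the seven heads emit 80 + 64 + 80 + 3 + 27 + 2 + 1 = 257 channels,
# matching fc_dim above. A split such as the following (key names are
# assumptions based on the layer comments) recovers the coefficient groups
# from a (B, 257) prediction.
def _split_coeffs(x):
    return {
        'id': x[:, :80],         # identity
        'exp': x[:, 80:144],     # expression
        'tex': x[:, 144:224],    # texture
        'angle': x[:, 224:227],  # rotation
        'gamma': x[:, 227:254],  # SH lighting
        'trans': x[:, 254:257],  # translation (tx, ty, tz)
    }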
class RecogNetWrapper(nn.Module):
def __init__(self, net_recog, pretrained_path=None, input_size=112):
super(RecogNetWrapper, self).__init__()
net = get_model(name=net_recog, fp16=False)
if pretrained_path:
state_dict = torch.load(pretrained_path, map_location='cpu')
net.load_state_dict(state_dict)
print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path))
for param in net.parameters():
param.requires_grad = False
self.net = net
self.preprocess = lambda x: 2 * x - 1
self.input_size=input_size
def forward(self, image, M):
image = self.preprocess(resize_n_crop(image, M, self.input_size))
id_feature = F.normalize(self.net(image), dim=-1, p=2)
return id_feature
# adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
class BasicBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1);
    # see "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
use_last_fc: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.use_last_fc = use_last_fc
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
if self.use_last_fc:
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
stride: int = 1, dilate: bool = False) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
if self.use_last_fc:
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any
) -> ResNet:
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
    The model is the same as ResNet except that the number of channels in the
    bottleneck is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
    The model is the same as ResNet except that the number of channels in the
    bottleneck is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
    channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
func_dict = {
'resnet18': (resnet18, 512),
'resnet50': (resnet50, 2048)
}
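# Sketch usage (assumption: no pretrained checkpoint on disk, so init_path is
# left as None and the backbone keeps its random initialization).
def _demo_define_net_recon():
    net = define_net_recon('resnet50', use_last_fc=False, init_path=None)
    coeffs = net(torch.rand(2, 3, 224, 224))
    assert coeffs.shape == (2, 257)  # see the coefficient split sketch above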
"""Model class template
This module provides a template for users to implement custom models.
You can specify '--model template' to use this model.
The class name should be consistent with both the filename and its model option.
The filename should be <model>_model.py
The class name should be <Model>Model
It implements a simple image-to-image translation baseline based on regression loss.
Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
min_<netG> ||netG(data_A) - data_B||_1
You need to implement the following functions:
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
<__init__>: Initialize this model class.
<set_input>: Unpack input data and perform data pre-processing.
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
<optimize_parameters>: Update network weights; it will be called in every training iteration.
"""
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
class TemplateModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Add new model-specific options and rewrite default values for existing options.
Parameters:
parser -- the option parser
is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
if is_train:
parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
return parser
def __init__(self, opt):
"""Initialize this model class.
Parameters:
opt -- training/test options
A few things can be done here.
- (required) call the initialization function of BaseModel
- define loss function, visualization images, model names, and optimizers
"""
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
# specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
        self.loss_names = ['G']  # base_model looks up 'loss_' + name, so 'G' maps to self.loss_G
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
self.visual_names = ['data_A', 'data_B', 'output']
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
self.model_names = ['G']
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
if self.isTrain: # only defined during training time
            # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
            # We also provide a GANLoss class "networks.GANLoss", e.g.:
            #   self.criterionGAN = networks.GANLoss().to(self.device)
self.criterionLoss = torch.nn.L1Loss()
# define and initialize optimizers. You can define one optimizer for each network.
# If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = [self.optimizer]
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
def forward(self):
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
self.output = self.netG(self.data_A) # generate output image given the input data_A
def backward(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
        # calculate the intermediate results if necessary; here self.output has been computed during function <forward>
# calculate loss given the input and intermediate results
self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
def optimize_parameters(self):
"""Update network weights; it will be called in every training iteration."""
self.forward() # first call forward to calculate intermediate results
self.optimizer.zero_grad() # clear network G's existing gradients
self.backward() # calculate gradients for network G
        self.optimizer.step()        # update network G's weights
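# Sketch of the per-iteration contract (the dataloader and a fully configured
# model are hypothetical here; model.setup(opt) is assumed to have run):
def _demo_training_step(model, dataloader):
    for data in dataloader:
        model.set_input(data)        # unpack the batch and move it to the device
        model.optimize_parameters()  # forward, backward, and optimizer step
        break                        # one iteration is enough for the sketch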
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
"""This script contains base options for Deep3DFaceRecon_pytorch
"""
import argparse
import os
from util import util
import numpy as np
import torch
import face3d.models as models
import face3d.data as data
class BaseOptions():
"""This class defines options used during both training and test time.
It also implements several helper functions such as parsing, printing, and saving the options.
It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
"""
    def __init__(self, cmd_line=None):
        """Reset the class; indicates the class hasn't been initialized"""
self.initialized = False
self.cmd_line = None
if cmd_line is not None:
self.cmd_line = cmd_line.split()
def initialize(self, parser):
"""Define the common options that are used in both training and test."""
# basic parameters
parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0; 0,1,2; 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        parser.add_argument('--vis_batch_nums', type=float, default=1, help='number of batches of images for visualization')
        parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='number of batches of images for evaluation')
        parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use distributed data parallel')
        parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
        parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether to log losses every batch')
        parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether to add images to tensorboard')
        parser.add_argument('--world_size', type=int, default=1, help='number of processes (GPUs) used for distributed training')
# model parameters
parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
# additional parameters
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
self.initialized = True
return parser
    def gather_options(self):
        """Initialize our parser with basic options (only once).
Add additional model-specific and dataset-specific options.
These options are defined in the <modify_commandline_options> function
in model and dataset classes.
"""
if not self.initialized: # check if it has been initialized
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)
# get the basic options
if self.cmd_line is None:
opt, _ = parser.parse_known_args()
else:
opt, _ = parser.parse_known_args(self.cmd_line)
# set cuda visible devices
os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
if self.cmd_line is None:
opt, _ = parser.parse_known_args() # parse again with new defaults
else:
opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
# modify dataset-related parser options
if opt.dataset_mode:
dataset_name = opt.dataset_mode
dataset_option_setter = data.get_option_setter(dataset_name)
parser = dataset_option_setter(parser, self.isTrain)
# save and return the parser
self.parser = parser
if self.cmd_line is None:
return parser.parse_args()
else:
return parser.parse_args(self.cmd_line)
def print_options(self, opt):
"""Print and save options
        It will print both current options and default values (if different).
        It will save options into a text file: [checkpoints_dir]/[name]/{phase}_opt.txt
"""
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
# save to the disk
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
try:
with open(file_name, 'wt') as opt_file:
opt_file.write(message)
opt_file.write('\n')
        except PermissionError as error:
            print("permission error {}".format(error))
def parse(self):
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
opt = self.gather_options()
opt.isTrain = self.isTrain # train or test
# process opt.suffix
if opt.suffix:
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
opt.name = opt.name + suffix
# set gpu ids
str_ids = opt.gpu_ids.split(',')
gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
gpu_ids.append(id)
opt.world_size = len(gpu_ids)
# if len(opt.gpu_ids) > 0:
# torch.cuda.set_device(gpu_ids[0])
if opt.world_size == 1:
opt.use_ddp = False
if opt.phase != 'test':
# set continue_train automatically
if opt.pretrained_name is None:
model_dir = os.path.join(opt.checkpoints_dir, opt.name)
else:
model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
if os.path.isdir(model_dir):
model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
if os.path.isdir(model_dir) and len(model_pths) != 0:
                    opt.continue_train = True
# update the latest epoch count
if opt.continue_train:
if opt.epoch == 'latest':
epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
if len(epoch_counts) != 0:
opt.epoch_count = max(epoch_counts) + 1
else:
opt.epoch_count = int(opt.epoch) + 1
self.print_options(opt)
self.opt = opt
return self.opt
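# Sketch (assumption: concrete option classes set self.isTrain before parse()
# runs; the project follows this pattern with separate train/test variants).
class _DemoTestOptions(BaseOptions):
    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # common flags first
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        self.isTrain = False
        return parser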