Commit b6c19984 authored by dengjb's avatar dengjb
Browse files

update

parents
#!/usr/bin/env python
# encoding: utf-8
"""
@author: L1aoXingyu, guan'an wang
@contact: sherlockliao01@gmail.com, guan.wang0706@gmail.com
"""
import sys
sys.path.append('.')
from fastreid.config import get_cfg
from fastreid.engine import default_argument_parser, default_setup, DefaultTrainer, launch
from fastreid.utils.checkpoint import Checkpointer
from fastdistill import *
def setup(args):
    """Build and freeze the experiment config, then run fastreid's default setup.

    The YAML file referenced by ``args.config_file`` is merged first, and any
    ``KEY VALUE`` overrides from ``args.opts`` are layered on top of it.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config
def main(args):
    """Entry point for one worker: evaluate a checkpoint when --eval-only is set, else train."""
    cfg = setup(args)

    if args.eval_only:
        # Evaluation-only path: build the model, restore weights, run the test sets.
        model = DefaultTrainer.build_model(cfg)
        checkpointer = Checkpointer(model, save_dir=cfg.OUTPUT_DIR)
        checkpointer.load(cfg.MODEL.WEIGHTS)
        return DefaultTrainer.test(cfg, model)

    # Training path: optionally resume from the last checkpoint.
    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()
if __name__ == "__main__":
    parser = default_argument_parser()
    args = parser.parse_args()
    print("Command Line Args:", args)
    # Spawn one process per GPU; `launch` also handles the single-process case.
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
# FastFace in FastReID
This project provides a baseline for face recognition.
## Datasets Preparation
| Function | Dataset |
| --- | --- |
| Train | MS-Celeb-1M |
| Test-1 | LFW |
| Test-2 | CPLFW |
| Test-3 | CALFW |
| Test-4 | VGG2_FP |
| Test-5 | AgeDB-30 |
| Test-6 | CFP_FF |
| Test-7 | CFP-FP |
We do data wrangling following the [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) instructions.
## Dependencies
- bcolz
- mxnet (optional) if you want to read `.rec` directly
## Experiment Results
We refer to [insightface_pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) as our baseline methods, and on top of it, we use circle loss and cosine lr scheduler.
| Method | LFW(%) | CFP-FF(%) | CFP-FP(%) | AgeDB-30(%) | CALFW(%) | CPLFW(%) | VGG2-FP(%) |
| :---: | :---: | :---: |:---: | :---: | :---: | :---: | :---: |
| [insightface_pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) | 99.52 | 99.62 | 95.04 | 96.22 | 95.57 | 91.07 | 93.86 |
| ir50_se | 99.70 | 99.60 | 96.43 | 97.87 | 95.95 | 91.10 | 94.32 |
| ir100_se | 99.65 | 99.69 | 97.10 | 97.98 | 96.00 | 91.53 | 94.62 |
| ir50_se_0.1 | | | | | | | |
| ir100_se_0.1 | | | | | | | |
MODEL:
META_ARCHITECTURE: FaceBaseline
PIXEL_MEAN: [127.5, 127.5, 127.5]
PIXEL_STD: [127.5, 127.5, 127.5]
BACKBONE:
NAME: build_iresnet_backbone
HEADS:
NAME: FaceHead
WITH_BNNECK: True
NORM: BN
NECK_FEAT: after
EMBEDDING_DIM: 512
POOL_LAYER: Flatten
CLS_LAYER: CosSoftmax
SCALE: 64
MARGIN: 0.4
NUM_CLASSES: 360232
PFC:
ENABLED: False
SAMPLE_RATE: 0.1
LOSSES:
NAME: ("CrossEntropyLoss",)
CE:
EPSILON: 0.
SCALE: 1.
DATASETS:
REC_PATH: /export/home/DATA/Glint360k/train.rec
NAMES: ("MS1MV2",)
TESTS: ("CFP_FP", "AgeDB_30", "LFW")
INPUT:
SIZE_TRAIN: [0,] # No need of resize
SIZE_TEST: [0,]
FLIP:
ENABLED: True
PROB: 0.5
DATALOADER:
SAMPLER_TRAIN: TrainingSampler
NUM_WORKERS: 8
SOLVER:
MAX_EPOCH: 20
AMP:
ENABLED: True
OPT: SGD
BASE_LR: 0.05
MOMENTUM: 0.9
SCHED: MultiStepLR
STEPS: [8, 12, 15, 18]
BIAS_LR_FACTOR: 1.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 256
WARMUP_FACTOR: 0.1
WARMUP_ITERS: 0
CHECKPOINT_PERIOD: 1
TEST:
EVAL_PERIOD: 1
IMS_PER_BATCH: 1024
CUDNN_BENCHMARK: True
\ No newline at end of file
_BASE_: face_base.yml
MODEL:
BACKBONE:
NAME: build_iresnet_backbone
DEPTH: 100x
FEAT_DIM: 25088 # 512x7x7
WITH_SE: True
HEADS:
PFC:
ENABLED: True
OUTPUT_DIR: projects/FastFace/logs/ir_se101-ms1mv2-circle
_BASE_: face_base.yml
MODEL:
BACKBONE:
DEPTH: 50x
FEAT_DIM: 25088 # 512x7x7
DROPOUT: 0.
HEADS:
PFC:
ENABLED: True
OUTPUT_DIR: projects/FastFace/logs/pfc0.1_insightface
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from .modeling import *
from .config import add_face_cfg
from .trainer import FaceTrainer
from .datasets import *
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from fastreid.config import CfgNode as CN
def add_face_cfg(cfg):
    """Register the FastFace-specific options on an existing fastreid config.

    Adds the MXNet record path, the backbone dropout rate, and the Partial-FC
    sub-node (disabled by default, sampling rate 0.1).
    """
    cfg.DATASETS.REC_PATH = ""
    cfg.MODEL.BACKBONE.DROPOUT = 0.

    pfc_node = CN({"ENABLED": False})
    pfc_node.SAMPLE_RATE = 0.1
    cfg.MODEL.HEADS.PFC = pfc_node
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from .ms1mv2 import MS1MV2
from .test_dataset import *
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import glob
import os
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
@DATASET_REGISTRY.register()
class MS1MV2(ImageDataset):
    """MS1M-V2 (refined MS-Celeb-1M) training set, laid out as one folder per identity.

    Expects ``<root>/MS_Celeb_1M/<face_id>/*.jpg``; every image in a folder is
    labelled ``ms1mv2_<face_id>``.
    """

    # Sub-folder under `root` that holds the identity directories.
    dataset_dir = "MS_Celeb_1M"
    # Prefix used to namespace person ids across datasets.
    dataset_name = "ms1mv2"

    def __init__(self, root="datasets", **kwargs):
        self.root = root
        self.dataset_dir = os.path.join(self.root, self.dataset_dir)

        required_files = [self.dataset_dir]
        self.check_before_run(required_files)

        # NOTE(review): only the first 10k samples are kept — this looks like a
        # debugging leftover; confirm before training on the full dataset.
        train = self.process_dirs()[:10000]
        super().__init__(train, [], [], **kwargs)

    def process_dirs(self):
        """Scan the identity folders and return ``[img_path, pid, camid]`` triplets."""
        train_list = []
        fid_list = os.listdir(self.dataset_dir)
        for fid in fid_list:
            all_imgs = glob.glob(os.path.join(self.dataset_dir, fid, "*.jpg"))
            for img_path in all_imgs:
                # camid is a constant '0' — faces carry no camera information.
                train_list.append([img_path, self.dataset_name + '_' + fid, '0'])
        return train_list
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import os
import bcolz
import numpy as np
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
__all__ = ["CPLFW", "VGG2_FP", "AgeDB_30", "CALFW", "CFP_FF", "CFP_FP", "LFW"]
@DATASET_REGISTRY.register()
class CPLFW(ImageDataset):
    """Face-verification test set stored as a bcolz carray plus a pair-label file.

    Loads ``<root>/faces_emore_val/<dataset_name>`` (a bcolz array of image
    pairs) and ``<dataset_name>_list.npy`` (per-pair same/different flags).
    The raw data is kept on the instance (``carray`` / ``is_same``) and the
    ImageDataset splits stay empty — evaluation is driven by the trainer's
    test loader, not by per-image items.
    """

    # All verification sets share this folder; each set is a sub-array inside it.
    dataset_dir = "faces_emore_val"
    dataset_name = "cplfw"

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.dataset_dir = os.path.join(self.root, self.dataset_dir)

        required_files = [self.dataset_dir]
        self.check_before_run(required_files)

        carray = bcolz.carray(rootdir=os.path.join(self.dataset_dir, self.dataset_name), mode='r')
        is_same = np.load(os.path.join(self.dataset_dir, "{}_list.npy".format(self.dataset_name)))

        self.carray = carray
        self.is_same = is_same
        super().__init__([], [], [], **kwargs)
# The remaining verification sets reuse CPLFW's loading logic; each subclass
# only points `dataset_name` at a different bcolz sub-array / label file.
@DATASET_REGISTRY.register()
class VGG2_FP(CPLFW):
    dataset_name = "vgg2_fp"


@DATASET_REGISTRY.register()
class AgeDB_30(CPLFW):
    dataset_name = "agedb_30"


@DATASET_REGISTRY.register()
class CALFW(CPLFW):
    dataset_name = "calfw"


@DATASET_REGISTRY.register()
class CFP_FF(CPLFW):
    dataset_name = "cfp_ff"


@DATASET_REGISTRY.register()
class CFP_FP(CPLFW):
    dataset_name = "cfp_fp"


@DATASET_REGISTRY.register()
class LFW(CPLFW):
    dataset_name = "lfw"
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from PIL import Image
import io
import logging
import numbers
import torch
from torch.utils.data import Dataset
from fastreid.data.common import CommDataset
logger = logging.getLogger("fastreid.face_data")
try:
import mxnet as mx
except ImportError:
logger.info("Please install mxnet if you want to use .rec file")
class MXFaceDataset(Dataset):
    """Training dataset backed by an MXNet RecordIO pack (``train.rec`` + ``train.idx``)."""

    def __init__(self, path_imgrec, transforms):
        super().__init__()
        self.transforms = transforms

        logger.info(f"loading recordio {path_imgrec}...")
        # The index file sits next to the record file with an .idx extension.
        path_imgidx = path_imgrec[0:-4] + ".idx"
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        # Record 0 is a metadata header; its label bounds are used to derive
        # the sample index range and the class count.
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = list(range(1, int(header.label[0])))
        else:
            # NOTE(review): this branch leaves `header0` unset, so the log line
            # below and `num_classes` would raise AttributeError — confirm
            # whether packs with flag == 0 are actually supported.
            self.imgidx = list(self.imgrec.keys)
        logger.info(f"Number of Samples: {len(self.imgidx)}, "
                    f"Number of Classes: {int(self.header0[1] - self.header0[0])}")

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            # Multi-field labels: the first entry is the class id.
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = Image.open(io.BytesIO(img))  # RGB
        if self.transforms is not None: sample = self.transforms(sample)
        return {
            "images": sample,
            "targets": label,
            "camids": 0,  # faces have no camera id; fastreid still expects the key
        }

    def __len__(self):
        return len(self.imgidx)

    @property
    def num_classes(self):
        # Class count encoded in the header record (label[1] - label[0]).
        return int(self.header0[1] - self.header0[0])
class TestFaceDataset(CommDataset):
    """Wraps pre-loaded verification images and their pair labels.

    Each stored item x is mapped to ``x * 127.5 + 127.5`` on access, restoring
    the pixel range expected by the model's PIXEL_MEAN/PIXEL_STD preprocessing.
    """

    def __init__(self, img_items, labels):
        self.img_items = img_items
        self.labels = labels

    def __getitem__(self, index):
        raw = self.img_items[index]
        restored = torch.tensor(raw) * 127.5 + 127.5
        return {"images": restored}
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import copy
import io
import logging
import os
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from PIL import Image
from fastreid.evaluation import DatasetEvaluator
from fastreid.utils import comm
from fastreid.utils.file_io import PathManager
from .verification import evaluate
logger = logging.getLogger("fastreid.face_evaluator")
def gen_plot(fpr, tpr):
    """Render the ROC curve (TPR vs. FPR) and return it as an in-memory JPEG buffer."""
    figure = plt.figure()
    axes = figure.add_subplot(1, 1, 1)
    axes.set_xlabel("FPR", fontsize=14)
    axes.set_ylabel("TPR", fontsize=14)
    axes.set_title("ROC Curve", fontsize=14)
    axes.plot(fpr, tpr, linewidth=2)

    jpeg_buf = io.BytesIO()
    figure.savefig(jpeg_buf, format='jpeg')
    jpeg_buf.seek(0)
    plt.close(figure)  # free the figure; the caller only needs the buffer
    return jpeg_buf
class FaceEvaluator(DatasetEvaluator):
    """Accumulates face embeddings and scores pair verification accuracy.

    ``labels`` are the per-pair same/different flags of the verification set;
    the evaluator gathers features across workers, L2-normalizes them, feeds
    them to ``verification.evaluate`` and saves an ROC curve image as a side
    effect.
    """

    def __init__(self, cfg, labels, dataset_name, output_dir=None):
        self.cfg = cfg
        self.labels = labels
        self.dataset_name = dataset_name
        self._output_dir = output_dir

        self.features = []

    def reset(self):
        # Called before every evaluation round.
        self.features = []

    def process(self, inputs, outputs):
        # Keep embeddings on CPU so GPU memory is not held across the dataset.
        self.features.append(outputs.cpu())

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            features = comm.gather(self.features)
            features = sum(features, [])  # flatten the list-of-lists from all ranks

            # fmt: off
            if not comm.is_main_process(): return {}
            # fmt: on
        else:
            features = self.features

        features = torch.cat(features, dim=0)
        features = F.normalize(features, p=2, dim=1).numpy()

        self._results = OrderedDict()
        tpr, fpr, accuracy, best_thresholds = evaluate(features, self.labels)
        self._results["Accuracy"] = accuracy.mean() * 100
        self._results["Threshold"] = best_thresholds.mean()
        # "metric" is the key fastreid uses to rank checkpoints.
        self._results["metric"] = accuracy.mean() * 100

        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        PathManager.mkdirs(self._output_dir)
        roc_curve.save(os.path.join(self._output_dir, self.dataset_name + "_roc.png"))

        return copy.deepcopy(self._results)
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from .partial_fc import PartialFC
from .face_baseline import FaceBaseline
from .face_head import FaceHead
from .iresnet import build_iresnet_backbone
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class FaceBaseline(Baseline):
    """fastreid Baseline variant with Partial-FC support.

    When PFC is enabled, the model returns the (features, targets) pair in
    training so that `PFCTrainer` can run the distributed partial-fc
    forward/backward itself instead of the head computing logits.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.pfc_enabled = cfg.MODEL.HEADS.PFC.ENABLED
        self.amp_enabled = cfg.SOLVER.AMP.ENABLED

    def forward(self, batched_inputs):
        # Without partial-fc, behave exactly like the stock Baseline.
        if not self.pfc_enabled:
            return super().forward(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        with torch.cuda.amp.autocast(self.amp_enabled):
            features = self.backbone(images)
        # Hand downstream code fp32 features when running under autocast.
        features = features.float() if self.amp_enabled else features

        if self.training:
            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
            targets = batched_inputs["targets"]

            # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
            # may be larger than that in the original dataset, so the circle/arcface will
            # throw an error. We just set all the targets to 0 to avoid this problem.
            if targets.sum() < 0: targets.zero_()

            outputs = self.heads(features, targets)
            return outputs, targets
        else:
            outputs = self.heads(features)
            return outputs
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from fastreid.config import configurable
from fastreid.modeling.heads import EmbeddingHead
from fastreid.modeling.heads.build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class FaceHead(EmbeddingHead):
    """EmbeddingHead variant for Partial-FC: emits embeddings without logits.

    With PFC enabled, the classifier weight lives in the external `PartialFC`
    module, so the head's own weight matrix is deleted and `forward` stops at
    the bottleneck embedding.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        self.pfc_enabled = False
        if cfg.MODEL.HEADS.PFC.ENABLED:
            # Delete pre-defined linear weights for partial fc sample
            del self.weight
            self.pfc_enabled = True

    def forward(self, features, targets=None):
        """
        Partial FC forward, which will sample positive weights and part of negative weights,
        then compute logits and get the grad of features.
        """
        if not self.pfc_enabled:
            return super().forward(features, targets)
        else:
            pool_feat = self.pool_layer(features)
            neck_feat = self.bottleneck(pool_feat)
            neck_feat = neck_feat[..., 0, 0]  # drop the trailing 1x1 spatial dims
            return neck_feat
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from fastreid.layers import get_norm
from fastreid.modeling.backbones import BACKBONE_REGISTRY
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Bias-free 3x3 convolution; padding tracks dilation so the spatial size is preserved at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class IBasicBlock(nn.Module):
    """Pre-activation residual block used by IResNet: BN-Conv-BN-PReLU-Conv-BN plus shortcut."""

    expansion = 1

    def __init__(self, inplanes, planes, bn_norm, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super().__init__()
        # Guard clauses: this block only supports the plain (non-grouped,
        # non-dilated) configuration.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Keep the submodule creation order stable so checkpoints stay compatible.
        self.bn1 = get_norm(bn_norm, inplanes)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = get_norm(bn_norm, planes)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = get_norm(bn_norm, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main path: BN -> conv -> BN -> PReLU -> conv -> BN.
        branch = self.bn3(self.conv2(self.prelu(self.bn2(self.conv1(self.bn1(x))))))
        # Shortcut path, projected when the shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)
        return branch + shortcut
class IResNet(nn.Module):
    """Improved ResNet ("IResNet") backbone for face recognition.

    Differences from torchvision ResNet: a stride-1 3x3 stem with no max-pool
    (sized for 112x112 faces), pre-activation IBasicBlocks with PReLU, and a
    final BN + dropout instead of global pooling, leaving a 512x7x7 map for
    the head to flatten.
    """

    # Spatial size of the final feature map for a 112x112 input (7*7).
    fc_scale = 7 * 7

    def __init__(self, block, layers, bn_norm, dropout=0, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super().__init__()
        self.inplanes = 64
        self.dilation = 1
        # NOTE: fp16 is stored but not used in forward() here; autocast is
        # applied by the meta-architecture (FaceBaseline) instead.
        self.fp16 = fp16
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: stride-1 3x3 conv keeps the input resolution (no max-pool).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = get_norm(bn_norm, self.inplanes)
        self.prelu = nn.PReLU(self.inplanes)
        # Every stage downsamples by 2: 112 -> 56 -> 28 -> 14 -> 7.
        self.layer1 = self._make_layer(block, 64, layers[0], bn_norm, stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       bn_norm,
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       bn_norm,
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       bn_norm,
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = get_norm(bn_norm, 512 * block.expansion)
        self.dropout = nn.Dropout(p=dropout, inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif m.__class__.__name__.find('Norm') != -1:
                # Name-based match covers whatever layer `get_norm` produced.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    # Zero the mid-block BN so each residual branch starts as zero
                    # and the block initially behaves like identity.
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, bn_norm, stride=1, dilate=False):
        """Stack `blocks` residual blocks; only the first one strides/projects."""
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation: keeps resolution, grows receptive field.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                get_norm(bn_norm, planes * block.expansion),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, bn_norm, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      bn_norm,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        # Residual stages.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Final BN + dropout; no pooling — the head consumes the full map.
        x = self.bn2(x)
        x = self.dropout(x)
        return x
@BACKBONE_REGISTRY.register()
def build_iresnet_backbone(cfg):
    """
    Create an IResNet backbone instance from config.

    Reads MODEL.BACKBONE.{NORM, DEPTH, DROPOUT} and SOLVER.AMP.ENABLED.

    Returns:
        IResNet: the constructed backbone.

    Raises:
        KeyError: if MODEL.BACKBONE.DEPTH is not one of the supported depths.
    """
    # fmt: off
    bn_norm = cfg.MODEL.BACKBONE.NORM
    depth = cfg.MODEL.BACKBONE.DEPTH
    dropout = cfg.MODEL.BACKBONE.DROPOUT
    fp16 = cfg.SOLVER.AMP.ENABLED
    # fmt: on

    num_blocks_per_stage = {
        '18x': [2, 2, 2, 2],
        '34x': [3, 4, 6, 3],
        '50x': [3, 4, 14, 3],
        '100x': [3, 13, 30, 3],
        '200x': [6, 26, 60, 6],
    }
    # Fail with an actionable message instead of a bare KeyError('101x').
    if depth not in num_blocks_per_stage:
        raise KeyError(f"Unsupported IResNet depth '{depth}'; "
                       f"expected one of {sorted(num_blocks_per_stage)}")
    model = IResNet(IBasicBlock, num_blocks_per_stage[depth], bn_norm, dropout, fp16=fp16)
    return model
# encoding: utf-8
# code based on:
# https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py
import logging
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from fastreid.layers import any_softmax
from fastreid.modeling.losses.utils import concat_all_gather
from fastreid.utils import comm
logger = logging.getLogger('fastreid.partial_fc')
class PartialFC(nn.Module):
    """
    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
    Partial FC: Training 10 Million Identities on a Single Machine
    See the original paper:
    https://arxiv.org/abs/2010.05222

    The classifier weight matrix is sharded across ranks; every step each rank
    samples the positive class centers (plus random negatives) from its shard,
    computes the softmax cross-entropy manually with all-reduces, and returns
    the gradient w.r.t. the input features so the trainer can backprop the
    backbone itself.
    """

    def __init__(
            self,
            embedding_size,
            num_classes,
            sample_rate,
            cls_type,
            scale,
            margin
    ):
        """
        Args:
            embedding_size: dimension of the face embedding.
            num_classes: total number of identities across all ranks.
            sample_rate: fraction of this rank's class centers sampled per step.
            cls_type: name of a margin-softmax class in `fastreid.layers.any_softmax`.
            scale: logit scale passed to the margin softmax.
            margin: margin passed to the margin softmax.
        """
        super().__init__()
        self.embedding_size = embedding_size
        self.num_classes = num_classes
        self.sample_rate = sample_rate

        self.world_size = comm.get_world_size()
        self.rank = comm.get_rank()
        self.local_rank = comm.get_local_rank()
        self.device = torch.device(f'cuda:{self.local_rank}')

        # Shard the classifier: this rank owns `num_local` consecutive classes
        # starting at `class_start` (the first num_classes % world_size ranks
        # get one extra class).
        self.num_local: int = self.num_classes // self.world_size + int(self.rank < self.num_classes % self.world_size)
        self.class_start: int = self.num_classes // self.world_size * self.rank + \
                                min(self.rank, self.num_classes % self.world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)

        # The full local weight shard (and its SGD momentum buffer) is a plain
        # tensor, NOT an nn.Parameter — the optimizer only ever sees the
        # sampled `sub_weight` slice.
        self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
        self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
        logger.info("softmax weight init successfully!")
        logger.info("softmax weight mom init successfully!")
        # Side stream used by `prepare` so sampling can overlap other work.
        self.stream: torch.cuda.Stream = torch.cuda.Stream(self.local_rank)

        self.index = None
        if int(self.sample_rate) == 1:
            # No sampling: train on the full shard and make `update` a no-op.
            self.update = lambda: 0
            self.sub_weight = nn.Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            self.sub_weight = nn.Parameter(torch.empty((0, 0), device=self.device))

    def forward(self, total_features):
        """Compute logits of the gathered features against the sampled sub-weights."""
        # Sampling ran on the side stream; wait for it before using sub_weight.
        torch.cuda.current_stream().wait_stream(self.stream)
        if self.cls_layer.__class__.__name__ == 'Linear':
            logits = F.linear(total_features, self.sub_weight)
        else:
            # Margin-based softmax layers operate on cosine similarity.
            logits = F.linear(F.normalize(total_features), F.normalize(self.sub_weight))
        return logits

    def forward_backward(self, features, targets, optimizer):
        """
        Partial FC forward, which will sample positive weights and part of negative weights,
        then compute logits and get the grad of features.
        """
        total_targets = self.prepare(targets, optimizer)

        if self.world_size > 1:
            total_features = concat_all_gather(features)
        else:
            total_features = features.detach()
        total_features.requires_grad_(True)

        logits = self.forward(total_features)
        logits = self.cls_layer(logits, total_targets)

        with torch.no_grad():
            # Numerically-stable softmax over the *sharded* classifier: the
            # global max and sum(exp) are obtained via all-reduce.
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            if self.world_size > 1:
                dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdim=True)
            if self.world_size > 1:
                dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            # Samples whose ground-truth class lives on this rank (others are -1).
            index = torch.where(total_targets != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_targets[index, None], 1)

            # calculate loss
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_targets[index, None])
            if self.world_size > 1:
                dist.all_reduce(loss, dist.ReduceOp.SUM)
            # Mean negative log-likelihood (clamped for numerical safety).
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad: d(CE)/d(logits) = softmax - one_hot
            grad[index] -= one_hot
            grad.div_(logits.size(0))

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features)
        # feature gradient all-reduce
        if self.world_size > 1:
            dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        else:
            x_grad = total_features.grad
        x_grad = x_grad * self.world_size
        # backward backbone
        return x_grad, loss_v

    @torch.no_grad()
    def sample(self, total_targets):
        """
        Get sub_weights according to total targets gathered from all GPUs, due to each weights in different
        GPU contains different class centers.
        """
        # Map global class ids into this rank's local range; foreign classes -> -1.
        index_positive = (self.class_start <= total_targets) & (total_targets < self.class_start + self.num_local)
        total_targets[~index_positive] = -1
        total_targets[index_positive] -= self.class_start

        if int(self.sample_rate) != 1:
            positive = torch.unique(total_targets[index_positive], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                # Keep all positives, fill the remainder with random negatives.
                perm = torch.rand(size=[self.num_local], device=self.weight.device)
                perm[positive] = 2.0  # > any rand value, so positives survive topk
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.index = index
            # Re-map local ids to their positions inside the sampled subset.
            total_targets[index_positive] = torch.searchsorted(index, total_targets[index_positive])
            self.sub_weight = nn.Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]

    @torch.no_grad()
    def update(self):
        # Write the trained sub-weights (and momentum) back into the full shard.
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, targets, optimizer):
        """Gather targets, resample sub-weights and splice them into `optimizer`."""
        with torch.cuda.stream(self.stream):
            if self.world_size > 1:
                total_targets = concat_all_gather(targets)
            else:
                total_targets = targets

            # update sub_weight
            self.sample(total_targets)
            # Swap the freshly sampled parameter (plus its momentum buffer) into
            # the optimizer's last param group, dropping the stale state entry.
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]["momentum_buffer"] = self.sub_weight_mom
        return total_targets
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import os
from typing import Any, Dict
import torch
from fastreid.engine.hooks import PeriodicCheckpointer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.file_io import PathManager
class PfcPeriodicCheckpointer(PeriodicCheckpointer):
    """Periodic checkpointer for the Partial-FC classifier shard.

    Every rank owns a different slice of the classifier, so each rank writes
    its own file (suffixed with its rank) instead of only the main process
    saving, as the base class does.
    """

    def step(self, epoch: int, **kwargs: Any):
        rank = comm.get_rank()
        is_last = epoch >= self.max_epoch - 1

        if not is_last and (epoch + 1) % self.period == 0:
            # Intermediate snapshot, tagged with both the epoch and the rank.
            self.checkpointer.save(f"softmax_weight_{epoch:04d}_rank_{rank:02d}")
        if is_last:
            # Final weights get a stable, epoch-free name per rank.
            self.checkpointer.save(f"softmax_weight_{rank:02d}")
class PfcCheckpointer(Checkpointer):
    """Checkpointer for the PartialFC module.

    Unlike the base Checkpointer, it serializes the raw weight shard and its
    momentum buffer (plain tensors, not parameters) and keeps a per-rank
    "last checkpoint" marker, since every rank owns a different shard.
    """

    def __init__(self, model, save_dir, *, save_to_disk=True, **checkpointables):
        super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)
        self.rank = comm.get_rank()

    def save(self, name: str, **kwargs: Dict[str, str]):
        """Serialize the shard (weight + momentum) and all checkpointables to `<name>.pth`."""
        if not self.save_dir or not self.save_to_disk:
            return

        data = {}
        data["model"] = {
            "weight": self.model.weight.data,
            "momentum": self.model.weight_mom,
        }
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = f"{name}.pth"
        save_file = os.path.join(self.save_dir, basename)
        assert os.path.basename(save_file) == basename, basename
        self.logger.info("Saving partial fc weights")
        with PathManager.open(save_file, "wb") as f:
            torch.save(data, f)
        self.tag_last_checkpoint(basename)

    def _load_model(self, checkpoint: Any):
        # Restore the shard in place; shapes must match the current sharding.
        checkpoint_state_dict = checkpoint.pop("model")
        self._convert_ndarray_to_tensor(checkpoint_state_dict)
        self.model.weight.data.copy_(checkpoint_state_dict.pop("weight"))
        self.model.weight_mom.data.copy_(checkpoint_state_dict.pop("momentum"))

    def has_checkpoint(self):
        # The marker file name embeds the rank — one marker per shard.
        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
        return PathManager.exists(save_file)

    def get_checkpoint_file(self):
        """
        Returns:
            str: The latest checkpoint file in target directory.
        """
        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
        try:
            with PathManager.open(save_file, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            return ""
        return os.path.join(self.save_dir, last_saved)

    def tag_last_checkpoint(self, last_filename_basename: str):
        # Record the basename of the most recent checkpoint for this rank.
        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
        with PathManager.open(save_file, "w") as f:
            f.write(last_filename_basename)
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import logging
import os
import time
from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils import clip_grad_norm_
from fastreid.data.build import _root, build_reid_test_loader, build_reid_train_loader
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.transforms import build_transforms
from fastreid.engine import hooks
from fastreid.engine.defaults import DefaultTrainer, TrainerBase
from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer
from fastreid.solver import build_optimizer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
from fastreid.utils.params import ContiguousParams
from .face_data import MXFaceDataset
from .face_data import TestFaceDataset
from .face_evaluator import FaceEvaluator
from .modeling import PartialFC
from .pfc_checkpointer import PfcPeriodicCheckpointer, PfcCheckpointer
from .utils_amp import MaxClipGradScaler
class FaceTrainer(DefaultTrainer):
    """DefaultTrainer specialized for face recognition.

    Adds: (1) training directly from an MXNet `.rec` pack when
    DATASETS.REC_PATH is set, (2) bcolz-backed verification test loaders and
    evaluators, and (3) optional Partial-FC training with its own module,
    optimizer, LR scheduler and per-rank checkpointer.
    """

    def __init__(self, cfg):
        TrainerBase.__init__(self)

        logger = logging.getLogger('fastreid.partial-fc.trainer')
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        data_loader = self.build_train_loader(cfg)
        cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
        model = self.build_model(cfg)
        optimizer, param_wrapper = self.build_optimizer(cfg, model)

        if cfg.MODEL.HEADS.PFC.ENABLED:
            # fmt: off
            feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
            embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
            num_classes = cfg.MODEL.HEADS.NUM_CLASSES
            sample_rate = cfg.MODEL.HEADS.PFC.SAMPLE_RATE
            cls_type = cfg.MODEL.HEADS.CLS_LAYER
            scale = cfg.MODEL.HEADS.SCALE
            margin = cfg.MODEL.HEADS.MARGIN
            # fmt: on
            # Partial-FC module
            embedding_size = embedding_dim if embedding_dim > 0 else feat_dim
            self.pfc_module = PartialFC(embedding_size, num_classes, sample_rate, cls_type, scale, margin)
            # The sharded classifier gets its own optimizer (not DDP-wrapped).
            self.pfc_optimizer, _ = build_optimizer(cfg, self.pfc_module, False)

        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
            )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            # Partial-FC needs a dedicated trainer that scales/clips gradients
            # around the sharded-classifier backward pass.
            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
            grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100)
            self._trainer = PFCTrainer(model, data_loader, optimizer, param_wrapper,
                                       self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler)
        else:
            self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
                model, data_loader, optimizer, param_wrapper
            )

        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)
        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_scheduler = self.build_lr_scheduler(cfg, self.pfc_optimizer, self.iters_per_epoch)

        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            **self.scheduler,
        )
        if cfg.MODEL.HEADS.PFC.ENABLED:
            # Each rank saves its own classifier shard.
            self.pfc_checkpointer = PfcCheckpointer(
                self.pfc_module,
                cfg.OUTPUT_DIR,
                optimizer=self.pfc_optimizer,
                **self.pfc_scheduler,
            )

        self.start_epoch = 0
        self.max_epoch = cfg.SOLVER.MAX_EPOCH
        self.max_iter = self.max_epoch * self.iters_per_epoch
        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
        self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def build_hooks(self):
        """Standard hooks plus Partial-FC checkpointing/LR-scheduling hooks."""
        ret = super().build_hooks()

        if self.cfg.MODEL.HEADS.PFC.ENABLED:
            # Make sure checkpointer is after writer
            ret.insert(
                len(ret) - 1,
                PfcPeriodicCheckpointer(self.pfc_checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
            )
            # partial fc scheduler hook
            ret.append(
                hooks.LRScheduler(self.pfc_optimizer, self.pfc_scheduler)
            )
        return ret

    def resume_or_load(self, resume=True):
        # Backbone loading state_dict
        super().resume_or_load(resume)
        # Partial-FC loading state_dict
        if self.cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_checkpointer.resume_or_load('', resume=resume)

    @classmethod
    def build_train_loader(cls, cfg):
        # Prefer the packed `.rec` dataset when a record path is configured.
        path_imgrec = cfg.DATASETS.REC_PATH
        if path_imgrec != "":
            transforms = build_transforms(cfg, is_train=True)
            train_set = MXFaceDataset(path_imgrec, transforms)
            return build_reid_train_loader(cfg, train_set=train_set)
        else:
            return DefaultTrainer.build_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """Build the loader for one verification set; also returns its pair labels."""
        dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
        test_set = TestFaceDataset(dataset.carray, dataset.is_same)
        data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
        return data_loader, test_set.labels

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
        if output_dir is None:
            output_dir = os.path.join(cfg.OUTPUT_DIR, "visualization")
        data_loader, labels = cls.build_test_loader(cfg, dataset_name)
        return data_loader, FaceEvaluator(cfg, labels, dataset_name, output_dir)
class PFCTrainer(SimpleTrainer):
    """
    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
    Partial FC: Training 10 Million Identities on a Single Machine
    See the original paper:
    https://arxiv.org/abs/2010.05222
    code based on:
    https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py
    """

    def __init__(self, model, data_loader, optimizer, param_wrapper, pfc_module, pfc_optimizer, amp_enabled,
                 grad_scaler):
        super().__init__(model, data_loader, optimizer, param_wrapper)

        self.pfc_module = pfc_module        # PartialFC head (sharded classifier)
        self.pfc_optimizer = pfc_optimizer  # separate optimizer for the shard
        self.amp_enabled = amp_enabled
        self.grad_scaler = grad_scaler      # MaxClipGradScaler, used only when amp is on

    def run_step(self):
        """One step: backbone forward, partial-fc forward/backward, then manual
        backprop of the feature gradient through the backbone and both
        optimizer updates."""
        assert self.model.training, "[PFCTrainer] model was changed to eval mode!"
        start = time.perf_counter()

        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        # FaceBaseline returns (embeddings, targets) in PFC mode.
        features, targets = self.model(data)

        # Partial-fc backward
        f_grad, loss_v = self.pfc_module.forward_backward(features, targets, self.pfc_optimizer)

        if self.amp_enabled:
            # Backprop the scaled feature gradient, then unscale before clipping.
            features.backward(self.grad_scaler.scale(f_grad))
            self.grad_scaler.unscale_(self.optimizer)
            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            features.backward(f_grad)
            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
            self.optimizer.step()

        loss_dict = {"loss_cls": loss_v}
        self._write_metrics(loss_dict, data_time)

        # Step the classifier shard, then copy the sampled slice back into the
        # full weight/momentum tensors.
        self.pfc_optimizer.step()
        self.pfc_module.update()
        self.optimizer.zero_grad()
        self.pfc_optimizer.zero_grad()

        if isinstance(self.param_wrapper, ContiguousParams):
            self.param_wrapper.assert_buffer_is_valid()
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from typing import Dict, List
import torch
from torch._six import container_abcs
from torch.cuda.amp import GradScaler
class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        # EAFP: serve from the cache, copying to the device only on a miss.
        try:
            return self._per_device_tensors[device]
        except KeyError:
            replica = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = replica
            return replica
class MaxClipGradScaler(GradScaler):
    """GradScaler whose loss scale is clamped to an upper bound.

    Once the scale reaches `max_scale`, growth is frozen (growth factor 1);
    below it, doubling growth is restored; above it, the scale is clamped
    back down to `max_scale`.
    """

    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        """
        Args:
            init_scale: initial loss scale.
            max_scale: hard upper bound for the loss scale.
            growth_interval: steps between scale-growth attempts (see GradScaler).
        """
        super().__init__(init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        """Clamp the current scale to `max_scale` and adjust the growth factor."""
        if self.get_scale() == self.max_scale:
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.
        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.
        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        # `torch._six.container_abcs` was removed from modern PyTorch; use the
        # stdlib ABC instead (behavior is identical).
        from collections.abc import Iterable as _Iterable

        if not self._enabled:
            return outputs
        self.scale_clip()
        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, _Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, list) or isinstance(val, tuple):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment