# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# based on:
# https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py
import logging
import math
import torch
import torch.nn as nn
from fastreid.layers import *
from fastreid.utils import comm
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY
logger = logging.getLogger(__name__)
model_urls = {
'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnext101_ibn_a-6ace051d.pth',
}
class Bottleneck(nn.Module):
"""
    ResNeXt bottleneck type C
"""
expansion = 4
def __init__(self, inplanes, planes, bn_norm, with_ibn, baseWidth, cardinality, stride=1,
downsample=None):
""" Constructor
Args:
            inplanes: input channel dimensionality
            planes: bottleneck channel dimensionality (the block outputs planes * 4 channels)
            bn_norm: normalization layer type
            with_ibn: whether to use an IBN layer after conv1
            baseWidth: base width of the grouped convolution
            cardinality: number of convolution groups
            stride: conv stride; replaces the pooling layer
            downsample: optional projection module for the residual branch
"""
super(Bottleneck, self).__init__()
D = int(math.floor(planes * (baseWidth / 64)))
C = cardinality
self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
if with_ibn:
self.bn1 = IBN(D * C, bn_norm)
else:
self.bn1 = get_norm(bn_norm, D * C)
self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
self.bn2 = get_norm(bn_norm, D * C)
self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
self.bn3 = get_norm(bn_norm, planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
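# --- Illustrative sketch (not in the original file): the grouped-conv width used inside
# Bottleneck for the common 32x4d setting. planes/baseWidth/cardinality below are example
# values only; the real values come from the ResNeXt constructor further down.
if __name__ == "__main__":
    planes, baseWidth, cardinality = 64, 4, 32
    D = int(math.floor(planes * (baseWidth / 64)))
    print(D, D * cardinality, planes * Bottleneck.expansion)  # 4 128 256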
class ResNeXt(nn.Module):
"""
ResNext optimized for the ImageNet dataset, as specified in
https://arxiv.org/pdf/1611.05431.pdf
"""
def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers,
baseWidth=4, cardinality=32):
""" Constructor
Args:
baseWidth: baseWidth for ResNeXt.
cardinality: number of convolution groups.
layers: config of layers, e.g., [3, 4, 6, 3]
"""
super(ResNeXt, self).__init__()
self.cardinality = cardinality
self.baseWidth = baseWidth
self.inplanes = 64
self.output_size = 64
self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
self.bn1 = get_norm(bn_norm, 64)
self.relu = nn.ReLU(inplace=True)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn)
self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn)
self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn)
self.random_init()
# fmt: off
if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
else: self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
# fmt: on
def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', with_ibn=False):
""" Stack n bottleneck modules where n is inferred from the depth of the network.
Args:
            block: block type used to construct ResNeXt
planes: number of output channels (need to multiply by block.expansion)
blocks: number of blocks to be built
stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
Returns: a Module consisting of n sequential bottlenecks.
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
get_norm(bn_norm, planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
self.baseWidth, self.cardinality, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(self.inplanes, planes, bn_norm, with_ibn, self.baseWidth, self.cardinality, 1, None))
return nn.Sequential(*layers)
def _build_nonlocal(self, layers, non_layers, bn_norm):
self.NL_1 = nn.ModuleList(
[Non_local(256, bn_norm) for _ in range(non_layers[0])])
self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
self.NL_2 = nn.ModuleList(
[Non_local(512, bn_norm) for _ in range(non_layers[1])])
self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
self.NL_3 = nn.ModuleList(
[Non_local(1024, bn_norm) for _ in range(non_layers[2])])
self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
self.NL_4 = nn.ModuleList(
[Non_local(2048, bn_norm) for _ in range(non_layers[3])])
self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool1(x)
NL1_counter = 0
if len(self.NL_1_idx) == 0:
self.NL_1_idx = [-1]
for i in range(len(self.layer1)):
x = self.layer1[i](x)
if i == self.NL_1_idx[NL1_counter]:
_, C, H, W = x.shape
x = self.NL_1[NL1_counter](x)
NL1_counter += 1
# Layer 2
NL2_counter = 0
if len(self.NL_2_idx) == 0:
self.NL_2_idx = [-1]
for i in range(len(self.layer2)):
x = self.layer2[i](x)
if i == self.NL_2_idx[NL2_counter]:
_, C, H, W = x.shape
x = self.NL_2[NL2_counter](x)
NL2_counter += 1
# Layer 3
NL3_counter = 0
if len(self.NL_3_idx) == 0:
self.NL_3_idx = [-1]
for i in range(len(self.layer3)):
x = self.layer3[i](x)
if i == self.NL_3_idx[NL3_counter]:
_, C, H, W = x.shape
x = self.NL_3[NL3_counter](x)
NL3_counter += 1
# Layer 4
NL4_counter = 0
if len(self.NL_4_idx) == 0:
self.NL_4_idx = [-1]
for i in range(len(self.layer4)):
x = self.layer4[i](x)
if i == self.NL_4_idx[NL4_counter]:
_, C, H, W = x.shape
x = self.NL_4[NL4_counter](x)
NL4_counter += 1
return x
def random_init(self):
self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def init_pretrained_weights(key):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
import os
import errno
import gdown
def _get_torch_home():
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
torch_home = os.path.expanduser(
os.getenv(
ENV_TORCH_HOME,
os.path.join(
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
)
)
)
return torch_home
torch_home = _get_torch_home()
model_dir = os.path.join(torch_home, 'checkpoints')
try:
os.makedirs(model_dir)
except OSError as e:
if e.errno == errno.EEXIST:
# Directory already exists, ignore.
pass
else:
# Unexpected OSError, re-raise.
raise
filename = model_urls[key].split('/')[-1]
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
        logger.info(f"Pretrained model not found locally, downloading from {model_urls[key]}")
if comm.is_main_process():
gdown.download(model_urls[key], cached_file, quiet=False)
comm.synchronize()
logger.info(f"Loading pretrained model from {cached_file}")
state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
return state_dict
@BACKBONE_REGISTRY.register()
def build_resnext_backbone(cfg):
"""
Create a ResNeXt instance from config.
Returns:
ResNeXt: a :class:`ResNeXt` instance.
"""
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
bn_norm = cfg.MODEL.BACKBONE.NORM
with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
with_nl = cfg.MODEL.BACKBONE.WITH_NL
depth = cfg.MODEL.BACKBONE.DEPTH
# fmt: on
num_blocks_per_stage = {
'50x': [3, 4, 6, 3],
'101x': [3, 4, 23, 3],
'152x': [3, 8, 36, 3], }[depth]
nl_layers_per_stage = {
'50x': [0, 2, 3, 0],
'101x': [0, 2, 3, 0]}[depth]
model = ResNeXt(last_stride, bn_norm, with_ibn, with_nl, Bottleneck,
num_blocks_per_stage, nl_layers_per_stage)
if pretrain:
if pretrain_path:
try:
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
# Remove module.encoder in name
new_state_dict = {}
for k in state_dict:
new_k = '.'.join(k.split('.')[2:])
if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
new_state_dict[new_k] = state_dict[k]
state_dict = new_state_dict
logger.info(f"Loading pretrained model from {pretrain_path}")
except FileNotFoundError as e:
logger.info(f'{pretrain_path} is not found! Please check this path.')
raise e
except KeyError as e:
logger.info("State dict keys error! Please check the state dict.")
raise e
else:
key = depth
if with_ibn: key = 'ibn_' + key
state_dict = init_pretrained_weights(key)
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
logger.info(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
return model
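# --- Illustrative sketch (not in the original file): building a ResNeXt-50 (32x4d) backbone
# directly, bypassing the config system; the argument values are examples, and the 256x128
# input mirrors a typical person re-id crop.
if __name__ == "__main__":
    net = ResNeXt(last_stride=1, bn_norm="BN", with_ibn=False, with_nl=False,
                  block=Bottleneck, layers=[3, 4, 6, 3], non_layers=[0, 2, 3, 0])
    print(net(torch.randn(1, 3, 256, 128)).shape)  # torch.Size([1, 2048, 16, 8])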
"""
Author: Guan'an Wang
Contact: guan.wang0706@gmail.com
"""
import torch
from torch import nn
from collections import OrderedDict
import logging
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from fastreid.layers import get_norm
from fastreid.modeling.backbones import BACKBONE_REGISTRY
logger = logging.getLogger(__name__)
class ShuffleV2Block(nn.Module):
"""
Reference:
https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2
"""
def __init__(self, bn_norm, inp, oup, mid_channels, *, ksize, stride):
super(ShuffleV2Block, self).__init__()
self.stride = stride
assert stride in [1, 2]
self.mid_channels = mid_channels
self.ksize = ksize
pad = ksize // 2
self.pad = pad
self.inp = inp
outputs = oup - inp
branch_main = [
# pw
nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
get_norm(bn_norm, mid_channels),
nn.ReLU(inplace=True),
# dw
nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
get_norm(bn_norm, mid_channels),
# pw-linear
nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
get_norm(bn_norm, outputs),
nn.ReLU(inplace=True),
]
self.branch_main = nn.Sequential(*branch_main)
if stride == 2:
branch_proj = [
# dw
nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
get_norm(bn_norm, inp),
# pw-linear
nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
get_norm(bn_norm, inp),
nn.ReLU(inplace=True),
]
self.branch_proj = nn.Sequential(*branch_proj)
else:
self.branch_proj = None
def forward(self, old_x):
if self.stride == 1:
x_proj, x = self.channel_shuffle(old_x)
return torch.cat((x_proj, self.branch_main(x)), 1)
elif self.stride == 2:
x_proj = old_x
x = old_x
return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)
def channel_shuffle(self, x):
batchsize, num_channels, height, width = x.data.size()
assert (num_channels % 4 == 0)
x = x.reshape(batchsize * num_channels // 2, 2, height * width)
x = x.permute(1, 0, 2)
x = x.reshape(2, -1, num_channels // 2, height, width)
return x[0], x[1]
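# --- Illustrative sketch (not in the original file): what channel_shuffle returns. A toy
# (1, 4, 1, 1) tensor with channel values [0, 1, 2, 3] is split into two interleaved halves,
# [0, 2] and [1, 3]; the tiny block below uses example sizes only.
if __name__ == "__main__":
    blk = ShuffleV2Block("BN", inp=2, oup=4, mid_channels=2, ksize=3, stride=1)
    a, b = blk.channel_shuffle(torch.arange(4.).view(1, 4, 1, 1))
    print(a.flatten().tolist(), b.flatten().tolist())  # [0.0, 2.0] [1.0, 3.0]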
class ShuffleNetV2(nn.Module):
"""
Reference:
https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2
"""
def __init__(self, bn_norm, model_size='1.5x'):
super(ShuffleNetV2, self).__init__()
self.stage_repeats = [4, 8, 4]
self.model_size = model_size
if model_size == '0.5x':
self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
elif model_size == '1.0x':
self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
elif model_size == '1.5x':
self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
elif model_size == '2.0x':
self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
else:
raise NotImplementedError
# building first layer
input_channel = self.stage_out_channels[1]
self.first_conv = nn.Sequential(
nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
get_norm(bn_norm, input_channel),
nn.ReLU(inplace=True),
)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.features = []
for idxstage in range(len(self.stage_repeats)):
numrepeat = self.stage_repeats[idxstage]
output_channel = self.stage_out_channels[idxstage + 2]
for i in range(numrepeat):
if i == 0:
self.features.append(ShuffleV2Block(bn_norm, input_channel, output_channel,
mid_channels=output_channel // 2, ksize=3, stride=2))
else:
self.features.append(ShuffleV2Block(bn_norm, input_channel // 2, output_channel,
mid_channels=output_channel // 2, ksize=3, stride=1))
input_channel = output_channel
self.features = nn.Sequential(*self.features)
self.conv_last = nn.Sequential(
nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=False),
get_norm(bn_norm, self.stage_out_channels[-1]),
nn.ReLU(inplace=True)
)
self._initialize_weights()
def forward(self, x):
x = self.first_conv(x)
x = self.maxpool(x)
x = self.features(x)
x = self.conv_last(x)
return x
def _initialize_weights(self):
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'first' in name:
nn.init.normal_(m.weight, 0, 0.01)
else:
nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0001)
nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0.0001)
nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
@BACKBONE_REGISTRY.register()
def build_shufflenetv2_backbone(cfg):
# fmt: off
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
bn_norm = cfg.MODEL.BACKBONE.NORM
model_size = cfg.MODEL.BACKBONE.DEPTH
# fmt: on
model = ShuffleNetV2(bn_norm, model_size=model_size)
if pretrain:
new_state_dict = OrderedDict()
state_dict = torch.load(pretrain_path)["state_dict"]
for k, v in state_dict.items():
if k[:7] == 'module.':
k = k[7:]
new_state_dict[k] = v
incompatible = model.load_state_dict(new_state_dict, strict=False)
if incompatible.missing_keys:
logger.info(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
return model
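# --- Illustrative sketch (not in the original file): the output shape of a '1.0x'
# ShuffleNetV2 on a 256x128 re-id crop; the stem and three stride-2 stages reduce the
# input resolution by 32x overall. Sizes here are example values.
if __name__ == "__main__":
    net = ShuffleNetV2(bn_norm="BN", model_size='1.0x')
    print(net(torch.randn(1, 3, 256, 128)).shape)  # torch.Size([1, 1024, 8, 4])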
""" Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastreid.layers import DropPath, trunc_normal_, to_2tuple
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY
logger = logging.getLogger(__name__)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
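# --- Illustrative sketch (not in the original file): shape flow through Attention. A random
# (batch=2, tokens=5, dim=64) input keeps its shape; 64 dims split into 8 heads of 8 dims each.
# All sizes are example values, unrelated to the ViT configs below.
if __name__ == "__main__":
    attn = Attention(dim=64, num_heads=8, qkv_bias=True)
    print(attn(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 64])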
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
"""
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
# map for all networks, the feature metadata has reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1] # last feature if backbone outputs list/tuple of features
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class PatchEmbed_overlap(nn.Module):
""" Image to Patch Embedding with overlapping patches
"""
def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
stride_size_tuple = to_2tuple(stride_size)
self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
num_patches = self.num_x * self.num_y
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
x = x.flatten(2).transpose(1, 2) # [64, 8, 768]
return x
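# --- Illustrative sketch (not in the original file): the overlapping patch grid for a 256x128
# crop with 16x16 patches and stride 12 (example numbers only): num_y = (256 - 16) // 12 + 1 = 21,
# num_x = (128 - 16) // 12 + 1 = 10, i.e. 210 tokens of dimension 768.
if __name__ == "__main__":
    embed = PatchEmbed_overlap(img_size=(256, 128), patch_size=16, stride_size=12)
    print(embed.num_y, embed.num_x, embed(torch.randn(1, 3, 256, 128)).shape)  # 21 10 torch.Size([1, 210, 768])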
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
- https://arxiv.org/abs/2012.12877
"""
def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, embed_dim=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop_rate=0., attn_drop_rate=0., camera=0, drop_path_rate=0., hybrid_backbone=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6), sie_xishu=1.0):
super().__init__()
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed_overlap(
img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.cam_num = camera
self.sie_xishu = sie_xishu
# Initialize SIE Embedding
if camera > 1:
self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim))
trunc_normal_(self.sie_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def forward(self, x, camera_id=None):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.cam_num > 0:
x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id]
else:
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 0].reshape(x.shape[0], -1, 1, 1)
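# --- Illustrative sketch (not in the original file): a tiny 2-block VisionTransformer on a
# 256x128 crop with non-overlapping 16x16 patches; forward returns only the class token,
# reshaped to (B, 768, 1, 1) so downstream reid heads can treat it like a CNN feature map.
# depth=2 is an example value chosen only to keep the sketch light.
if __name__ == "__main__":
    vit = VisionTransformer(img_size=(256, 128), stride_size=16, depth=2, num_heads=12)
    print(vit(torch.randn(1, 3, 256, 128)).shape)  # torch.Size([1, 768, 1, 1])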
def resize_pos_embed(posemb, posemb_new, height, width):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    ntok_new = posemb_new.shape[1]
    posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
    ntok_new -= 1
    gs_old = int(math.sqrt(len(posemb_grid)))
    logger.info('Resized position embedding from size: {} to size: {} with height: {} width: {}'.format(
        posemb.shape, posemb_new.shape, height, width))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(height, width), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, height * width, -1)
    posemb = torch.cat([posemb_token, posemb_grid], dim=1)
    return posemb
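# --- Illustrative sketch (not in the original file): resizing a 14x14 ImageNet position
# embedding (196 patches + 1 cls token) to a 21x10 re-id grid; both tensors are random
# stand-ins for real checkpoint weights.
if __name__ == "__main__":
    imagenet_posemb = torch.randn(1, 197, 768)
    reid_posemb = torch.zeros(1, 1 + 21 * 10, 768)
    print(resize_pos_embed(imagenet_posemb, reid_posemb, 21, 10).shape)  # torch.Size([1, 211, 768])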
@BACKBONE_REGISTRY.register()
def build_vit_backbone(cfg):
"""
Create a Vision Transformer instance from config.
Returns:
        VisionTransformer: a :class:`VisionTransformer` instance.
"""
# fmt: off
input_size = cfg.INPUT.SIZE_TRAIN
pretrain = cfg.MODEL.BACKBONE.PRETRAIN
pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
depth = cfg.MODEL.BACKBONE.DEPTH
sie_xishu = cfg.MODEL.BACKBONE.SIE_COE
stride_size = cfg.MODEL.BACKBONE.STRIDE_SIZE
drop_ratio = cfg.MODEL.BACKBONE.DROP_RATIO
drop_path_ratio = cfg.MODEL.BACKBONE.DROP_PATH_RATIO
attn_drop_rate = cfg.MODEL.BACKBONE.ATT_DROP_RATE
# fmt: on
num_depth = {
'small': 8,
'base': 12,
}[depth]
num_heads = {
'small': 8,
'base': 12,
}[depth]
mlp_ratio = {
'small': 3.,
'base': 4.
}[depth]
qkv_bias = {
'small': False,
'base': True
}[depth]
qk_scale = {
'small': 768 ** -0.5,
'base': None,
}[depth]
model = VisionTransformer(img_size=input_size, sie_xishu=sie_xishu, stride_size=stride_size, depth=num_depth,
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop_path_rate=drop_path_ratio, drop_rate=drop_ratio, attn_drop_rate=attn_drop_rate)
if pretrain:
try:
state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
logger.info(f"Loading pretrained model from {pretrain_path}")
if 'model' in state_dict:
state_dict = state_dict.pop('model')
if 'state_dict' in state_dict:
state_dict = state_dict.pop('state_dict')
for k, v in state_dict.items():
if 'head' in k or 'dist' in k:
continue
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights
if 'distilled' in pretrain_path:
logger.info("distill need to choose right cls token in the pth.")
v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1)
v = resize_pos_embed(v, model.pos_embed.data, model.patch_embed.num_y, model.patch_embed.num_x)
state_dict[k] = v
except FileNotFoundError as e:
logger.info(f'{pretrain_path} is not found! Please check this path.')
raise e
except KeyError as e:
logger.info("State dict keys error! Please check the state dict.")
raise e
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
logger.info(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
return model
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import REID_HEADS_REGISTRY, build_heads
# import all the meta_arch, so they will be registered
from .embedding_head import EmbeddingHead
from .clas_head import ClasHead
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from ...utils.registry import Registry
REID_HEADS_REGISTRY = Registry("HEADS")
REID_HEADS_REGISTRY.__doc__ = """
Registry for reid heads in a baseline model.
Reid heads take backbone feature maps and compute embeddings (and, during training,
classification logits) for each image.
The registered object will be called with `obj(cfg)`.
The call is expected to return an `nn.Module` such as :class:`EmbeddingHead`.
"""
def build_heads(cfg):
"""
    Build reid heads defined by `cfg.MODEL.HEADS.NAME`.
"""
head = cfg.MODEL.HEADS.NAME
return REID_HEADS_REGISTRY.get(head)(cfg)
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch.nn.functional as F
from fastreid.modeling.heads import REID_HEADS_REGISTRY, EmbeddingHead
@REID_HEADS_REGISTRY.register()
class ClasHead(EmbeddingHead):
def forward(self, features, targets=None):
"""
        See :class:`EmbeddingHead.forward`.
"""
pool_feat = self.pool_layer(features)
neck_feat = self.bottleneck(pool_feat)
neck_feat = neck_feat.view(neck_feat.size(0), -1)
if self.cls_layer.__class__.__name__ == 'Linear':
logits = F.linear(neck_feat, self.weight)
else:
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
# Evaluation
if not self.training: return logits.mul_(self.cls_layer.s)
cls_outputs = self.cls_layer(logits.clone(), targets)
return {
"cls_outputs": cls_outputs,
"pred_class_logits": logits.mul_(self.cls_layer.s),
"features": neck_feat,
}
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from torch import nn
from fastreid.config import configurable
from fastreid.layers import *
from fastreid.layers import pooling, any_softmax
from fastreid.layers.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY
@REID_HEADS_REGISTRY.register()
class EmbeddingHead(nn.Module):
"""
    EmbeddingHead performs all feature aggregation in an embedding task, such as reid, image retrieval
    and face recognition.
    It typically contains logic to
    1. aggregate features via global average pooling or generalized mean pooling
    2. (optional) apply batchnorm, dimension reduction, etc.
    3. (training only) compute margin-based softmax logits
"""
@configurable
def __init__(
self,
*,
feat_dim,
embedding_dim,
num_classes,
neck_feat,
pool_type,
cls_type,
scale,
margin,
with_bnneck,
norm_type
):
"""
NOTE: this interface is experimental.
Args:
feat_dim:
embedding_dim:
num_classes:
neck_feat:
pool_type:
cls_type:
scale:
margin:
with_bnneck:
norm_type:
"""
super().__init__()
# Pooling layer
assert hasattr(pooling, pool_type), "Expected pool types are {}, " \
"but got {}".format(pooling.__all__, pool_type)
self.pool_layer = getattr(pooling, pool_type)()
self.neck_feat = neck_feat
neck = []
if embedding_dim > 0:
neck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
feat_dim = embedding_dim
if with_bnneck:
neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))
self.bottleneck = nn.Sequential(*neck)
# Classification head
assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
"but got {}".format(any_softmax.__all__, cls_type)
self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
self.reset_parameters()
def reset_parameters(self) -> None:
self.bottleneck.apply(weights_init_kaiming)
nn.init.normal_(self.weight, std=0.01)
@classmethod
def from_config(cls, cfg):
# fmt: off
feat_dim = cfg.MODEL.BACKBONE.FEAT_DIM
embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
num_classes = cfg.MODEL.HEADS.NUM_CLASSES
neck_feat = cfg.MODEL.HEADS.NECK_FEAT
pool_type = cfg.MODEL.HEADS.POOL_LAYER
cls_type = cfg.MODEL.HEADS.CLS_LAYER
scale = cfg.MODEL.HEADS.SCALE
margin = cfg.MODEL.HEADS.MARGIN
with_bnneck = cfg.MODEL.HEADS.WITH_BNNECK
norm_type = cfg.MODEL.HEADS.NORM
# fmt: on
return {
'feat_dim': feat_dim,
'embedding_dim': embedding_dim,
'num_classes': num_classes,
'neck_feat': neck_feat,
'pool_type': pool_type,
'cls_type': cls_type,
'scale': scale,
'margin': margin,
'with_bnneck': with_bnneck,
'norm_type': norm_type
}
def forward(self, features, targets=None):
"""
See :class:`ReIDHeads.forward`.
"""
pool_feat = self.pool_layer(features)
neck_feat = self.bottleneck(pool_feat)
neck_feat = neck_feat[..., 0, 0]
# Evaluation
# fmt: off
if not self.training: return neck_feat
# fmt: on
# Training
if self.cls_layer.__class__.__name__ == 'Linear':
logits = F.linear(neck_feat, self.weight)
else:
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
        # Pass logits.clone() into cls_layer, because it applies in-place operations
cls_outputs = self.cls_layer(logits.clone(), targets)
# fmt: off
if self.neck_feat == 'before': feat = pool_feat[..., 0, 0]
elif self.neck_feat == 'after': feat = neck_feat
else: raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
# fmt: on
return {
"cls_outputs": cls_outputs,
"pred_class_logits": logits.mul(self.cls_layer.s),
"features": feat,
}
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from .circle_loss import *
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
from .focal_loss import focal_loss
from .triplet_loss import triplet_loss
__all__ = [k for k in globals().keys() if not k.startswith("_")]
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
__all__ = ["pairwise_circleloss", "pairwise_cosface"]
def pairwise_circleloss(
embedding: torch.Tensor,
targets: torch.Tensor,
margin: float,
gamma: float, ) -> torch.Tensor:
embedding = F.normalize(embedding, dim=1)
dist_mat = torch.matmul(embedding, embedding.t())
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()
# Mask scores related to itself
is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
s_p = dist_mat * is_pos
s_n = dist_mat * is_neg
alpha_p = torch.clamp_min(-s_p.detach() + 1 + margin, min=0.)
alpha_n = torch.clamp_min(s_n.detach() + margin, min=0.)
delta_p = 1 - margin
delta_n = margin
logit_p = - gamma * alpha_p * (s_p - delta_p) + (-99999999.) * (1 - is_pos)
logit_n = gamma * alpha_n * (s_n - delta_n) + (-99999999.) * (1 - is_neg)
loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
return loss
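# --- Illustrative sketch (not in the original file): calling pairwise_circleloss on random
# embeddings; pairwise_cosface below takes the same arguments. margin=0.25 and gamma=128 are
# example values here, fastreid normally reads them from cfg.MODEL.LOSSES.CIRCLE.
if __name__ == "__main__":
    emb = torch.randn(8, 128)
    labels = torch.arange(4).repeat_interleave(2)  # 4 identities, 2 samples each
    print(pairwise_circleloss(emb, labels, margin=0.25, gamma=128.0))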
def pairwise_cosface(
embedding: torch.Tensor,
targets: torch.Tensor,
margin: float,
gamma: float, ) -> torch.Tensor:
# Normalize embedding features
embedding = F.normalize(embedding, dim=1)
dist_mat = torch.matmul(embedding, embedding.t())
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()
# Mask scores related to itself
is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
s_p = dist_mat * is_pos
s_n = dist_mat * is_neg
logit_p = -gamma * s_p + (-99999999.) * (1 - is_pos)
logit_n = gamma * (s_n + margin) + (-99999999.) * (1 - is_neg)
loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
return loss
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from fastreid.utils.events import get_event_storage
def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
"""
Log the accuracy metrics to EventStorage.
"""
bsz = pred_class_logits.size(0)
maxk = max(topk)
_, pred_class = pred_class_logits.topk(maxk, 1, True, True)
pred_class = pred_class.t()
correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))
ret = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
ret.append(correct_k.mul_(1. / bsz))
storage = get_event_storage()
storage.put_scalar("cls_accuracy", ret[0])
def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
num_classes = pred_class_outputs.size(1)
if eps >= 0:
smooth_param = eps
else:
# Adaptive label smooth regularization
soft_label = F.softmax(pred_class_outputs, dim=1)
smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)
log_probs = F.log_softmax(pred_class_outputs, dim=1)
with torch.no_grad():
targets = torch.ones_like(log_probs)
targets *= smooth_param / (num_classes - 1)
targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
loss = (-targets * log_probs).sum(dim=1)
with torch.no_grad():
non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)
loss = loss.sum() / non_zero_cnt
return loss
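# --- Illustrative sketch (not in the original file): with eps=0.1 and 4 classes, the smoothed
# target row built inside cross_entropy_loss puts 1 - 0.1 = 0.9 on the true class and
# 0.1 / 3 on each of the others. The logits and label below are arbitrary example values.
if __name__ == "__main__":
    logits = torch.tensor([[2.0, 0.5, -1.0, 0.1]])
    labels = torch.tensor([0])
    print(cross_entropy_loss(logits, labels, eps=0.1))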
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
# based on:
# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
def focal_loss(
input: torch.Tensor,
target: torch.Tensor,
alpha: float,
gamma: float = 2.0,
reduction: str = 'mean') -> torch.Tensor:
r"""Criterion that computes Focal loss.
See :class:`fastreid.modeling.losses.FocalLoss` for details.
According to [1], the Focal loss is computed as follows:
.. math::
\text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
where:
- :math:`p_t` is the model's estimated probability for each class.
Arguments:
alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
gamma (float): Focusing parameter :math:`\gamma >= 0`.
        reduction (str, optional): Specifies the reduction to apply to the
         output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
         'mean': the sum of the output will be divided by the number of elements
         in the output, 'sum': the output will be summed. Default: 'mean'.
Shape:
- Input: :math:`(N, C, *)` where C = number of classes.
- Target: :math:`(N, *)` where each value is
:math:`0 ≤ targets[i] ≤ C−1`.
    Examples:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
        >>> output.backward()
References:
[1] https://arxiv.org/abs/1708.02002
"""
if not torch.is_tensor(input):
raise TypeError("Input type is not a torch.Tensor. Got {}"
.format(type(input)))
if not len(input.shape) >= 2:
raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
.format(input.shape))
if input.size(0) != target.size(0):
raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
.format(input.size(0), target.size(0)))
n = input.size(0)
out_size = (n,) + input.size()[2:]
if target.size()[1:] != input.size()[2:]:
raise ValueError('Expected target size {}, got {}'.format(
out_size, target.size()))
if not input.device == target.device:
        raise ValueError(
            "input and target must be on the same device. Got: {} and {}".format(
                input.device, target.device))
# compute softmax over the classes axis
input_soft = F.softmax(input, dim=1)
    # create the labels one-hot tensor and move the class dim next to the batch dim: (N, *) -> (N, C, *)
    target_one_hot = F.one_hot(target, num_classes=input.shape[1])
    target_one_hot = target_one_hot.permute(0, -1, *range(1, target.dim())).to(input.dtype)
# compute the actual focal loss
weight = torch.pow(-input_soft + 1., gamma)
focal = -alpha * weight * torch.log(input_soft)
loss_tmp = torch.sum(target_one_hot * focal, dim=1)
if reduction == 'none':
loss = loss_tmp
elif reduction == 'mean':
loss = torch.mean(loss_tmp)
elif reduction == 'sum':
loss = torch.sum(loss_tmp)
else:
raise NotImplementedError("Invalid reduction mode: {}"
.format(reduction))
return loss
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from .utils import euclidean_dist, cosine_dist
def softmax_weights(dist, mask):
max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
diff = dist - max_v
Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero
W = torch.exp(diff) * mask / Z
return W
def hard_example_mining(dist_mat, is_pos, is_neg):
"""For each anchor, find the hardest positive and negative sample.
    Args:
        dist_mat: pairwise distance between samples, shape [N, N]
        is_pos: positive-pair mask with shape [N, N]
        is_neg: negative-pair mask with shape [N, N]
    Returns:
        dist_ap: pytorch Tensor, distance(anchor, hardest positive); shape [N]
        dist_an: pytorch Tensor, distance(anchor, hardest negative); shape [N]
    NOTE: Only the case in which all labels have the same number of samples is considered,
      so all anchors can be processed in parallel.
"""
assert len(dist_mat.size()) == 2
# `dist_ap` means distance(anchor, positive)
# both `dist_ap` and `relative_p_inds` with shape [N]
dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N]
dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)
return dist_ap, dist_an
def weighted_example_mining(dist_mat, is_pos, is_neg):
"""For each anchor, find the weighted positive and negative sample.
Args:
dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
is_pos:
is_neg:
Returns:
dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
dist_an: pytorch Variable, distance(anchor, negative); shape [N]
"""
assert len(dist_mat.size()) == 2
dist_ap = dist_mat * is_pos
dist_an = dist_mat * is_neg
weights_ap = softmax_weights(dist_ap, is_pos)
weights_an = softmax_weights(-dist_an, is_neg)
dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
dist_an = torch.sum(dist_an * weights_an, dim=1)
return dist_ap, dist_an
def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
Loss for Person Re-Identification'."""
if norm_feat:
dist_mat = cosine_dist(embedding, embedding)
else:
dist_mat = euclidean_dist(embedding, embedding)
# For distributed training, gather all features from different process.
# if comm.get_world_size() > 1:
# all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
# all_targets = concat_all_gather(targets)
# else:
# all_embedding = embedding
# all_targets = targets
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()
if hard_mining:
dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
else:
dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)
    y = torch.ones_like(dist_an)
if margin > 0:
loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
else:
loss = F.soft_margin_loss(dist_an - dist_ap, y)
# fmt: off
if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
# fmt: on
return loss
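# --- Illustrative sketch (not in the original file): hard-mined triplet loss on a toy batch
# of 8 embeddings from 4 identities (2 images each); margin=0.3 is an example value that
# fastreid would normally take from cfg.MODEL.LOSSES.TRI.MARGIN.
if __name__ == "__main__":
    emb = torch.randn(8, 128)
    labels = torch.arange(4).repeat_interleave(2)
    print(triplet_loss(emb, labels, margin=0.3, norm_feat=False, hard_mining=True))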
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
def concat_all_gather(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
tensors_gather = [torch.ones_like(tensor)
for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
def normalize(x, axis=-1):
"""Normalizing to unit length along the specified dimension.
Args:
x: pytorch Variable
Returns:
x: pytorch Variable, same shape as input
"""
x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
return x
def euclidean_dist(x, y):
m, n = x.size(0), y.size(0)
xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
dist = xx + yy - 2 * torch.matmul(x, y.t())
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
return dist
def cosine_dist(x, y):
x = F.normalize(x, dim=1)
y = F.normalize(y, dim=1)
dist = 2 - 2 * torch.mm(x, y.t())
return dist
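# --- Illustrative sketch (not in the original file): euclidean_dist matches torch.cdist up to
# the small clamp used for numerical stability; the random shapes are example values.
if __name__ == "__main__":
    a, b = torch.randn(4, 16), torch.randn(6, 16)
    print(torch.allclose(euclidean_dist(a, b), torch.cdist(a, b), atol=1e-4))  # True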
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import META_ARCH_REGISTRY, build_model
# import all the meta_arch, so they will be registered
from .baseline import Baseline
from .mgn import MGN
from .moco import MoCo
from .distiller import Distiller
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from fastreid.config import configurable
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.heads import build_heads
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class Baseline(nn.Module):
"""
Baseline architecture. Any models that contains the following two components:
1. Per-image feature extraction (aka backbone)
2. Per-image feature aggregation and loss computation
"""
@configurable
def __init__(
self,
*,
backbone,
heads,
pixel_mean,
pixel_std,
loss_kwargs=None
):
"""
NOTE: this interface is experimental.
Args:
backbone:
heads:
pixel_mean:
pixel_std:
"""
super().__init__()
# backbone
self.backbone = backbone
# head
self.heads = heads
self.loss_kwargs = loss_kwargs
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
@classmethod
def from_config(cls, cfg):
backbone = build_backbone(cfg)
heads = build_heads(cfg)
return {
'backbone': backbone,
'heads': heads,
'pixel_mean': cfg.MODEL.PIXEL_MEAN,
'pixel_std': cfg.MODEL.PIXEL_STD,
'loss_kwargs':
{
# loss name
'loss_names': cfg.MODEL.LOSSES.NAME,
# loss hyperparameters
'ce': {
'eps': cfg.MODEL.LOSSES.CE.EPSILON,
'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
'scale': cfg.MODEL.LOSSES.CE.SCALE
},
'tri': {
'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
'scale': cfg.MODEL.LOSSES.TRI.SCALE
},
'circle': {
'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
},
'cosface': {
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
}
}
}
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
images = self.preprocess_image(batched_inputs)
features = self.backbone(images)
if self.training:
            assert "targets" in batched_inputs, "Person ID annotations are missing in training!"
            targets = batched_inputs["targets"].to(self.device)
            # PreciseBN flag: when running PreciseBN on a different dataset, the number of classes
            # in the new dataset may be larger than in the original one, so circle/arcface softmax
            # would throw an error. Setting all targets to 0 avoids this problem.
            if targets.sum() < 0: targets.zero_()
outputs = self.heads(features, targets)
losses = self.losses(outputs, targets)
return losses
else:
outputs = self.heads(features)
return outputs
def preprocess_image(self, batched_inputs):
"""
Normalize and batch the input images.
"""
        if isinstance(batched_inputs, dict):
            images = batched_inputs['images'].to(self.device)
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs.to(self.device)
        else:
            raise TypeError("batched_inputs must be dict or torch.Tensor, but got {}".format(type(batched_inputs)))
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self, outputs, gt_labels):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
# model predictions
# fmt: off
pred_class_logits = outputs['pred_class_logits'].detach()
cls_outputs = outputs['cls_outputs']
pred_features = outputs['features']
# fmt: on
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
loss_dict = {}
loss_names = self.loss_kwargs['loss_names']
if 'CrossEntropyLoss' in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls'] = cross_entropy_loss(
cls_outputs,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale')
if 'TripletLoss' in loss_names:
tri_kwargs = self.loss_kwargs.get('tri')
loss_dict['loss_triplet'] = triplet_loss(
pred_features,
gt_labels,
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale')
if 'CircleLoss' in loss_names:
circle_kwargs = self.loss_kwargs.get('circle')
loss_dict['loss_circle'] = pairwise_circleloss(
pred_features,
gt_labels,
circle_kwargs.get('margin'),
circle_kwargs.get('gamma')
) * circle_kwargs.get('scale')
if 'Cosface' in loss_names:
cosface_kwargs = self.loss_kwargs.get('cosface')
loss_dict['loss_cosface'] = pairwise_cosface(
pred_features,
gt_labels,
cosface_kwargs.get('margin'),
cosface_kwargs.get('gamma'),
) * cosface_kwargs.get('scale')
return loss_dict
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from fastreid.utils.registry import Registry
META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.
The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""
def build_model(cfg):
"""
Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
Note that it does not load any weights from ``cfg``.
"""
meta_arch = cfg.MODEL.META_ARCHITECTURE
model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
model.to(torch.device(cfg.MODEL.DEVICE))
return model
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import logging
import torch
import torch.nn.functional as F
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY, build_model, Baseline
from fastreid.utils.checkpoint import Checkpointer
logger = logging.getLogger(__name__)
@META_ARCH_REGISTRY.register()
class Distiller(Baseline):
def __init__(self, cfg):
super().__init__(cfg)
# Get teacher model config
model_ts = []
for i in range(len(cfg.KD.MODEL_CONFIG)):
cfg_t = get_cfg()
cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
cfg_t.defrost()
cfg_t.MODEL.META_ARCHITECTURE = "Baseline"
            # Change syncBN to BN because the teacher model is not wrapped in DDP
if cfg_t.MODEL.BACKBONE.NORM == "syncBN":
cfg_t.MODEL.BACKBONE.NORM = "BN"
if cfg_t.MODEL.HEADS.NORM == "syncBN":
cfg_t.MODEL.HEADS.NORM = "BN"
model_t = build_model(cfg_t)
# No gradients for teacher model
for param in model_t.parameters():
param.requires_grad_(False)
logger.info("Loading teacher model weights ...")
Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])
model_ts.append(model_t)
self.ema_enabled = cfg.KD.EMA.ENABLED
self.ema_momentum = cfg.KD.EMA.MOMENTUM
if self.ema_enabled:
cfg_self = cfg.clone()
cfg_self.defrost()
cfg_self.MODEL.META_ARCHITECTURE = "Baseline"
if cfg_self.MODEL.BACKBONE.NORM == "syncBN":
cfg_self.MODEL.BACKBONE.NORM = "BN"
if cfg_self.MODEL.HEADS.NORM == "syncBN":
cfg_self.MODEL.HEADS.NORM = "BN"
model_self = build_model(cfg_self)
# No gradients for self model
for param in model_self.parameters():
param.requires_grad_(False)
if cfg_self.MODEL.WEIGHTS != '':
logger.info("Loading self distillation model weights ...")
Checkpointer(model_self).load(cfg_self.MODEL.WEIGHTS)
else:
# Make sure the initial state is same
for param_q, param_k in zip(self.parameters(), model_self.parameters()):
param_k.data.copy_(param_q.data)
model_ts.insert(0, model_self)
        # Do not register the teacher models as `nn.Module` attributes; this
        # makes sure the teacher weights are not saved in checkpoints
self.model_ts = model_ts
@torch.no_grad()
def _momentum_update_key_encoder(self, m=0.999):
"""
Momentum update of the key encoder
"""
for param_q, param_k in zip(self.parameters(), self.model_ts[0].parameters()):
param_k.data = param_k.data * m + param_q.data * (1. - m)
def forward(self, batched_inputs):
if self.training:
images = self.preprocess_image(batched_inputs)
# student model forward
s_feat = self.backbone(images)
assert "targets" in batched_inputs, "Labels are missing in training!"
targets = batched_inputs["targets"].to(self.device)
if targets.sum() < 0: targets.zero_()
s_outputs = self.heads(s_feat, targets)
t_outputs = []
# teacher model forward
with torch.no_grad():
if self.ema_enabled:
self._momentum_update_key_encoder(self.ema_momentum) # update self distill model
for model_t in self.model_ts:
t_feat = model_t.backbone(images)
t_output = model_t.heads(t_feat, targets)
t_outputs.append(t_output)
losses = self.losses(s_outputs, t_outputs, targets)
return losses
# Eval mode, just conventional reid feature extraction
else:
return super().forward(batched_inputs)
def losses(self, s_outputs, t_outputs, gt_labels):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
loss_dict = super().losses(s_outputs, gt_labels)
s_logits = s_outputs['pred_class_logits']
loss_jsdiv = 0.
for t_output in t_outputs:
t_logits = t_output['pred_class_logits'].detach()
loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)
loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)
return loss_dict
@staticmethod
def _kldiv(y_s, y_t, t):
p_s = F.log_softmax(y_s / t, dim=1)
p_t = F.softmax(y_t / t, dim=1)
loss = F.kl_div(p_s, p_t, reduction="sum") * (t ** 2) / y_s.shape[0]
return loss
def jsdiv_loss(self, y_s, y_t, t=16):
loss = (self._kldiv(y_s, y_t, t) + self._kldiv(y_t, y_s, t)) / 2
return loss
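# --- Illustrative sketch (not in the original file): the temperature-scaled KL term behind
# jsdiv_loss is (near) zero when student and teacher produce identical logits; the logits
# and temperature below are example values.
if __name__ == "__main__":
    y = torch.randn(4, 10)
    sym_kl = (Distiller._kldiv(y, y, 16) + Distiller._kldiv(y, y, 16)) / 2
    print(float(sym_kl) < 1e-3)  # True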
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import torch
from torch import nn
from fastreid.config import configurable
from fastreid.layers import get_norm
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.heads import build_heads
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
"""
Multiple Granularities Network architecture, which contains the following two components:
1. Per-image feature extraction (aka backbone)
2. Multi-branch feature aggregation
"""
@configurable
def __init__(
self,
*,
backbone,
neck1,
neck2,
neck3,
b1_head,
b2_head,
b21_head,
b22_head,
b3_head,
b31_head,
b32_head,
b33_head,
pixel_mean,
pixel_std,
loss_kwargs=None
):
"""
NOTE: this interface is experimental.
Args:
backbone:
neck1:
neck2:
neck3:
b1_head:
b2_head:
b21_head:
b22_head:
b3_head:
b31_head:
b32_head:
b33_head:
pixel_mean:
pixel_std:
loss_kwargs:
"""
super().__init__()
self.backbone = backbone
# branch1
self.b1 = neck1
self.b1_head = b1_head
# branch2
self.b2 = neck2
self.b2_head = b2_head
self.b21_head = b21_head
self.b22_head = b22_head
# branch3
self.b3 = neck3
self.b3_head = b3_head
self.b31_head = b31_head
self.b32_head = b32_head
self.b33_head = b33_head
self.loss_kwargs = loss_kwargs
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)
@classmethod
def from_config(cls, cfg):
bn_norm = cfg.MODEL.BACKBONE.NORM
with_se = cfg.MODEL.BACKBONE.WITH_SE
all_blocks = build_backbone(cfg)
# backbone
backbone = nn.Sequential(
all_blocks.conv1,
all_blocks.bn1,
all_blocks.relu,
all_blocks.maxpool,
all_blocks.layer1,
all_blocks.layer2,
all_blocks.layer3[0]
)
res_conv4 = nn.Sequential(*all_blocks.layer3[1:])
res_g_conv5 = all_blocks.layer4
res_p_conv5 = nn.Sequential(
Bottleneck(1024, 512, bn_norm, False, with_se, downsample=nn.Sequential(
nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
Bottleneck(2048, 512, bn_norm, False, with_se),
Bottleneck(2048, 512, bn_norm, False, with_se))
res_p_conv5.load_state_dict(all_blocks.layer4.state_dict())
# branch
neck1 = nn.Sequential(
copy.deepcopy(res_conv4),
copy.deepcopy(res_g_conv5)
)
b1_head = build_heads(cfg)
# branch2
neck2 = nn.Sequential(
copy.deepcopy(res_conv4),
copy.deepcopy(res_p_conv5)
)
b2_head = build_heads(cfg)
b21_head = build_heads(cfg)
b22_head = build_heads(cfg)
# branch3
neck3 = nn.Sequential(
copy.deepcopy(res_conv4),
copy.deepcopy(res_p_conv5)
)
b3_head = build_heads(cfg)
b31_head = build_heads(cfg)
b32_head = build_heads(cfg)
b33_head = build_heads(cfg)
return {
'backbone': backbone,
'neck1': neck1,
'neck2': neck2,
'neck3': neck3,
'b1_head': b1_head,
'b2_head': b2_head,
'b21_head': b21_head,
'b22_head': b22_head,
'b3_head': b3_head,
'b31_head': b31_head,
'b32_head': b32_head,
'b33_head': b33_head,
'pixel_mean': cfg.MODEL.PIXEL_MEAN,
'pixel_std': cfg.MODEL.PIXEL_STD,
'loss_kwargs':
{
# loss name
'loss_names': cfg.MODEL.LOSSES.NAME,
# loss hyperparameters
'ce': {
'eps': cfg.MODEL.LOSSES.CE.EPSILON,
'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
'scale': cfg.MODEL.LOSSES.CE.SCALE
},
'tri': {
'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
'scale': cfg.MODEL.LOSSES.TRI.SCALE
},
'circle': {
'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
},
'cosface': {
'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
}
}
}
@property
def device(self):
return self.pixel_mean.device
def forward(self, batched_inputs):
images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)  # (bs, 1024, H/16, W/16): the shared backbone stops after layer3[0]
# branch1
b1_feat = self.b1(features)
# branch2
b2_feat = self.b2(features)
b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
# branch3
b3_feat = self.b3(features)
b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"]
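            # Negative targets are presumably dummy labels (e.g. forward passes run
            # without valid person IDs, such as precise-BN statistics updates); they
            # are zeroed so the classification heads never see invalid class indices.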
if targets.sum() < 0: targets.zero_()
b1_outputs = self.b1_head(b1_feat, targets)
b2_outputs = self.b2_head(b2_feat, targets)
b21_outputs = self.b21_head(b21_feat, targets)
b22_outputs = self.b22_head(b22_feat, targets)
b3_outputs = self.b3_head(b3_feat, targets)
b31_outputs = self.b31_head(b31_feat, targets)
b32_outputs = self.b32_head(b32_feat, targets)
b33_outputs = self.b33_head(b33_feat, targets)
losses = self.losses(b1_outputs,
b2_outputs, b21_outputs, b22_outputs,
b3_outputs, b31_outputs, b32_outputs, b33_outputs,
targets)
return losses
else:
b1_pool_feat = self.b1_head(b1_feat)
b2_pool_feat = self.b2_head(b2_feat)
b21_pool_feat = self.b21_head(b21_feat)
b22_pool_feat = self.b22_head(b22_feat)
b3_pool_feat = self.b3_head(b3_feat)
b31_pool_feat = self.b31_head(b31_feat)
b32_pool_feat = self.b32_head(b32_feat)
b33_pool_feat = self.b33_head(b33_feat)
pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
return pred_feat
def preprocess_image(self, batched_inputs):
r"""
Normalize and batch the input images.
"""
if isinstance(batched_inputs, dict):
images = batched_inputs["images"].to(self.device)
elif isinstance(batched_inputs, torch.Tensor):
images = batched_inputs.to(self.device)
else:
            raise TypeError("batched_inputs must be a dict or torch.Tensor, but got {}".format(type(batched_inputs)))
images.sub_(self.pixel_mean).div_(self.pixel_std)
return images
def losses(self,
b1_outputs,
b2_outputs, b21_outputs, b22_outputs,
b3_outputs, b31_outputs, b32_outputs, b33_outputs, gt_labels):
# model predictions
# fmt: off
pred_class_logits = b1_outputs['pred_class_logits'].detach()
b1_logits = b1_outputs['cls_outputs']
b2_logits = b2_outputs['cls_outputs']
b21_logits = b21_outputs['cls_outputs']
b22_logits = b22_outputs['cls_outputs']
b3_logits = b3_outputs['cls_outputs']
b31_logits = b31_outputs['cls_outputs']
b32_logits = b32_outputs['cls_outputs']
b33_logits = b33_outputs['cls_outputs']
b1_pool_feat = b1_outputs['features']
b2_pool_feat = b2_outputs['features']
b3_pool_feat = b3_outputs['features']
b21_pool_feat = b21_outputs['features']
b22_pool_feat = b22_outputs['features']
b31_pool_feat = b31_outputs['features']
b32_pool_feat = b32_outputs['features']
b33_pool_feat = b33_outputs['features']
# fmt: on
# Log prediction accuracy
log_accuracy(pred_class_logits, gt_labels)
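        # Concatenate the part-level pooled features of each branch so that a single
        # triplet loss is applied per branch's parts (loss_triplet_b22 / loss_triplet_b33
        # below) instead of one triplet loss per individual stripe.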
b22_pool_feat = torch.cat((b21_pool_feat, b22_pool_feat), dim=1)
b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)
loss_dict = {}
loss_names = self.loss_kwargs['loss_names']
if "CrossEntropyLoss" in loss_names:
ce_kwargs = self.loss_kwargs.get('ce')
loss_dict['loss_cls_b1'] = cross_entropy_loss(
b1_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b2'] = cross_entropy_loss(
b2_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b21'] = cross_entropy_loss(
b21_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b22'] = cross_entropy_loss(
b22_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b3'] = cross_entropy_loss(
b3_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b31'] = cross_entropy_loss(
b31_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b32'] = cross_entropy_loss(
b32_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
loss_dict['loss_cls_b33'] = cross_entropy_loss(
b33_logits,
gt_labels,
ce_kwargs.get('eps'),
ce_kwargs.get('alpha')
) * ce_kwargs.get('scale') * 0.125
if "TripletLoss" in loss_names:
tri_kwargs = self.loss_kwargs.get('tri')
loss_dict['loss_triplet_b1'] = triplet_loss(
b1_pool_feat,
gt_labels,
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b2'] = triplet_loss(
b2_pool_feat,
gt_labels,
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b3'] = triplet_loss(
b3_pool_feat,
gt_labels,
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b22'] = triplet_loss(
b22_pool_feat,
gt_labels,
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
loss_dict['loss_triplet_b33'] = triplet_loss(
b33_pool_feat,
gt_labels,
tri_kwargs.get('margin'),
tri_kwargs.get('norm_feat'),
tri_kwargs.get('hard_mining')
) * tri_kwargs.get('scale') * 0.2
return loss_dict
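# --- Illustrative sketch (not part of the original file) ---------------------
# MGN splits the branch-2 and branch-3 feature maps along the height axis
# (dim=2) into 2 and 3 horizontal stripes; each stripe is fed to its own head.
# The snippet below shows the split on a dummy feature map and the weighting
# used in `losses`: the eight cross-entropy terms are scaled by 0.125 and the
# five triplet terms by 0.2, so each loss family sums to one.  The tensor
# shapes here are illustrative assumptions only.
if __name__ == '__main__':
    import torch

    feat = torch.randn(2, 2048, 24, 8)              # (bs, C, H, W), H divisible by 2 and 3
    upper, lower = torch.chunk(feat, 2, dim=2)      # branch-2 stripes: (2, 2048, 12, 8) each
    top, mid, bottom = torch.chunk(feat, 3, dim=2)  # branch-3 stripes: (2, 2048, 8, 8) each
    print(upper.shape, top.shape)

    ce_weight, tri_weight = 0.125, 0.2
    print(8 * ce_weight, 5 * tri_weight)            # both families sum to 1.0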
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F
from torch import nn
from fastreid.modeling.losses.utils import concat_all_gather
from fastreid.utils import comm
from .baseline import Baseline
from .build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class MoCo(Baseline):
def __init__(self, cfg):
super().__init__(cfg)
dim = cfg.MODEL.HEADS.EMBEDDING_DIM if cfg.MODEL.HEADS.EMBEDDING_DIM \
else cfg.MODEL.BACKBONE.FEAT_DIM
size = cfg.MODEL.QUEUE_SIZE
self.memory = Memory(dim, size)
def losses(self, outputs, gt_labels):
"""
        Compute losses from the model outputs; the loss-function inputs must
        match what the model's forward pass returns.
"""
# regular reid loss
loss_dict = super().losses(outputs, gt_labels)
# memory loss
pred_features = outputs['features']
loss_mb = self.memory(pred_features, gt_labels)
loss_dict['loss_mb'] = loss_mb
return loss_dict
class Memory(nn.Module):
"""
Build a MoCo memory with a queue
https://arxiv.org/abs/1911.05722
"""
def __init__(self, dim=512, K=65536):
"""
        dim: feature dimension (default: 512)
K: queue size; number of negative keys (default: 65536)
"""
super().__init__()
self.K = K
self.margin = 0.25
self.gamma = 32
# create the queue
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_label", torch.zeros((1, K), dtype=torch.long))
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _dequeue_and_enqueue(self, keys, targets):
# gather keys/targets before updating queue
if comm.get_world_size() > 1:
keys = concat_all_gather(keys)
targets = concat_all_gather(targets)
else:
keys = keys.detach()
targets = targets.detach()
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
assert self.K % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr:ptr + batch_size] = keys.T
self.queue_label[:, ptr:ptr + batch_size] = targets
ptr = (ptr + batch_size) % self.K # move pointer
self.queue_ptr[0] = ptr
def forward(self, feat_q, targets):
"""
        Enqueue the current batch into the memory bank and compute the metric loss.
        Args:
            feat_q: query features from the model
            targets: ground-truth identity labels
        Returns:
            the pairwise CosFace-style loss between the queries and the memory bank
"""
# normalize embedding features
feat_q = F.normalize(feat_q, p=2, dim=1)
# dequeue and enqueue
self._dequeue_and_enqueue(feat_q.detach(), targets)
# compute loss
loss = self._pairwise_cosface(feat_q, targets)
return loss
def _pairwise_cosface(self, feat_q, targets):
dist_mat = torch.matmul(feat_q, self.queue)
N, M = dist_mat.size() # (bsz, memory)
is_pos = targets.view(N, 1).expand(N, M).eq(self.queue_label.expand(N, M)).float()
is_neg = targets.view(N, 1).expand(N, M).ne(self.queue_label.expand(N, M)).float()
# Mask scores related to themselves
same_indx = torch.eye(N, N, device=is_pos.device)
other_indx = torch.zeros(N, M - N, device=is_pos.device)
same_indx = torch.cat((same_indx, other_indx), dim=1)
is_pos = is_pos - same_indx
s_p = dist_mat * is_pos
s_n = dist_mat * is_neg
logit_p = -self.gamma * s_p + (-99999999.) * (1 - is_pos)
logit_n = self.gamma * (s_n + self.margin) + (-99999999.) * (1 - is_neg)
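        # The -99999999 terms push masked entries towards -inf so they vanish inside
        # logsumexp; the final loss softplus(logsumexp_p + logsumexp_n) equals
        # log(1 + sum_{p,n} exp(gamma * (s_n + margin - s_p))), i.e. a pairwise
        # CosFace / circle-style objective against the memory bank.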
loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
return loss
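# --- Illustrative sketch (not part of the original file) ---------------------
# The memory above is a fixed-size ring buffer: each batch of keys overwrites
# batch_size columns starting at `queue_ptr`, and the pointer wraps around
# modulo K.  The toy loop below mimics that pointer arithmetic with small
# illustrative numbers (K=8, batch_size=4), assuming K is divisible by the
# batch size as the assert in `_dequeue_and_enqueue` requires.
if __name__ == '__main__':
    K, batch_size = 8, 4
    ptr = 0
    for step in range(5):
        # columns [ptr, ptr + batch_size) would be overwritten at this step
        print(f"step {step}: overwrite columns {list(range(ptr, ptr + batch_size))}")
        ptr = (ptr + batch_size) % K  # move pointer with wraparound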
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .build import build_lr_scheduler, build_optimizer