ModelZoo / Fast-ReID_pytorch / Commits

Commit b6c19984, authored Nov 18, 2025 by dengjb

update
Showing 20 changed files with 2448 additions and 0 deletions (+2448 -0)
fastreid/modeling/backbones/resnext.py             +335  -0
fastreid/modeling/backbones/shufflenet.py          +203  -0
fastreid/modeling/backbones/vision_transformer.py  +399  -0
fastreid/modeling/heads/__init__.py                +11   -0
fastreid/modeling/heads/build.py                   +25   -0
fastreid/modeling/heads/clas_head.py               +36   -0
fastreid/modeling/heads/embedding_head.py          +151  -0
fastreid/modeling/losses/__init__.py               +13   -0
fastreid/modeling/losses/circle_loss.py            +71   -0
fastreid/modeling/losses/cross_entroy_loss.py      +54   -0
fastreid/modeling/losses/focal_loss.py             +92   -0
fastreid/modeling/losses/triplet_loss.py           +113  -0
fastreid/modeling/losses/utils.py                  +48   -0
fastreid/modeling/meta_arch/__init__.py            +14   -0
fastreid/modeling/meta_arch/baseline.py            +188  -0
fastreid/modeling/meta_arch/build.py               +26   -0
fastreid/modeling/meta_arch/distiller.py           +140  -0
fastreid/modeling/meta_arch/mgn.py                 +394  -0
fastreid/modeling/meta_arch/moco.py                +126  -0
fastreid/solver/__init__.py                        +9    -0
fastreid/modeling/backbones/resnext.py  0 → 100644

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

# based on:
# https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py

import logging
import math

import torch
import torch.nn as nn

from fastreid.layers import *
from fastreid.utils import comm
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)
model_urls = {
    'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnext101_ibn_a-6ace051d.pth',
}


class Bottleneck(nn.Module):
    """
    ResNeXt bottleneck type C
    """
    expansion = 4

    def __init__(self, inplanes, planes, bn_norm, with_ibn, baseWidth, cardinality, stride=1, downsample=None):
        """ Constructor
        Args:
            inplanes: input channel dimensionality
            planes: output channel dimensionality
            baseWidth: base width.
            cardinality: num of convolution groups.
            stride: conv stride. Replaces pooling layer.
        """
        super(Bottleneck, self).__init__()

        D = int(math.floor(planes * (baseWidth / 64)))
        C = cardinality
        self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
        if with_ibn:
            self.bn1 = IBN(D * C, bn_norm)
        else:
            self.bn1 = get_norm(bn_norm, D * C)
        self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
        self.bn2 = get_norm(bn_norm, D * C)
        self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = get_norm(bn_norm, planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNeXt(nn.Module):
    """
    ResNeXt optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1611.05431.pdf
    """

    def __init__(self, last_stride, bn_norm, with_ibn, with_nl, block, layers, non_layers,
                 baseWidth=4, cardinality=32):
        """ Constructor
        Args:
            baseWidth: baseWidth for ResNeXt.
            cardinality: number of convolution groups.
            layers: config of layers, e.g., [3, 4, 6, 3]
        """
        super(ResNeXt, self).__init__()

        self.cardinality = cardinality
        self.baseWidth = baseWidth
        self.inplanes = 64
        self.output_size = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = get_norm(bn_norm, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, with_ibn=with_ibn)
        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, with_ibn=with_ibn)
        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, with_ibn=with_ibn)
        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, with_ibn=with_ibn)

        self.random_init()

        # fmt: off
        if with_nl: self._build_nonlocal(layers, non_layers, bn_norm)
        else:       self.NL_1_idx = self.NL_2_idx = self.NL_3_idx = self.NL_4_idx = []
        # fmt: on

    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', with_ibn=False):
        """ Stack n bottleneck modules where n is inferred from the depth of the network.
        Args:
            block: block type used to construct ResNext
            planes: number of output channels (need to multiply by block.expansion)
            blocks: number of blocks to be built
            stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
        Returns: a Module consisting of n sequential bottlenecks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                get_norm(bn_norm, planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
                            self.baseWidth, self.cardinality, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, bn_norm, with_ibn,
                                self.baseWidth, self.cardinality, 1, None))

        return nn.Sequential(*layers)

    def _build_nonlocal(self, layers, non_layers, bn_norm):
        self.NL_1 = nn.ModuleList(
            [Non_local(256, bn_norm) for _ in range(non_layers[0])])
        self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
        self.NL_2 = nn.ModuleList(
            [Non_local(512, bn_norm) for _ in range(non_layers[1])])
        self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
        self.NL_3 = nn.ModuleList(
            [Non_local(1024, bn_norm) for _ in range(non_layers[2])])
        self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
        self.NL_4 = nn.ModuleList(
            [Non_local(2048, bn_norm) for _ in range(non_layers[3])])
        self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool1(x)

        NL1_counter = 0
        if len(self.NL_1_idx) == 0:
            self.NL_1_idx = [-1]
        for i in range(len(self.layer1)):
            x = self.layer1[i](x)
            if i == self.NL_1_idx[NL1_counter]:
                _, C, H, W = x.shape
                x = self.NL_1[NL1_counter](x)
                NL1_counter += 1
        # Layer 2
        NL2_counter = 0
        if len(self.NL_2_idx) == 0:
            self.NL_2_idx = [-1]
        for i in range(len(self.layer2)):
            x = self.layer2[i](x)
            if i == self.NL_2_idx[NL2_counter]:
                _, C, H, W = x.shape
                x = self.NL_2[NL2_counter](x)
                NL2_counter += 1
        # Layer 3
        NL3_counter = 0
        if len(self.NL_3_idx) == 0:
            self.NL_3_idx = [-1]
        for i in range(len(self.layer3)):
            x = self.layer3[i](x)
            if i == self.NL_3_idx[NL3_counter]:
                _, C, H, W = x.shape
                x = self.NL_3[NL3_counter](x)
                NL3_counter += 1
        # Layer 4
        NL4_counter = 0
        if len(self.NL_4_idx) == 0:
            self.NL_4_idx = [-1]
        for i in range(len(self.layer4)):
            x = self.layer4[i](x)
            if i == self.NL_4_idx[NL4_counter]:
                _, C, H, W = x.shape
                x = self.NL_4[NL4_counter](x)
                NL4_counter += 1

        return x

    def random_init(self):
        self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def init_pretrained_weights(key):
    """Initializes model with pretrained weights.
    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    import os
    import errno
    import gdown

    def _get_torch_home():
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    filename = model_urls[key].split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        logger.info(f"Pretrain model don't exist, downloading from {model_urls[key]}")
        if comm.is_main_process():
            gdown.download(model_urls[key], cached_file, quiet=False)

    comm.synchronize()

    logger.info(f"Loading pretrained model from {cached_file}")
    state_dict = torch.load(cached_file, map_location=torch.device('cpu'))

    return state_dict


@BACKBONE_REGISTRY.register()
def build_resnext_backbone(cfg):
    """
    Create a ResNeXt instance from config.
    Returns:
        ResNeXt: a :class:`ResNeXt` instance.
    """

    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    last_stride   = cfg.MODEL.BACKBONE.LAST_STRIDE
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    with_ibn      = cfg.MODEL.BACKBONE.WITH_IBN
    with_nl       = cfg.MODEL.BACKBONE.WITH_NL
    depth         = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    num_blocks_per_stage = {
        '50x': [3, 4, 6, 3],
        '101x': [3, 4, 23, 3],
        '152x': [3, 8, 36, 3],
    }[depth]
    nl_layers_per_stage = {
        '50x': [0, 2, 3, 0],
        '101x': [0, 2, 3, 0]}[depth]
    model = ResNeXt(last_stride, bn_norm, with_ibn, with_nl, Bottleneck,
                    num_blocks_per_stage, nl_layers_per_stage)
    if pretrain:
        if pretrain_path:
            try:
                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                # Remove module.encoder in name
                new_state_dict = {}
                for k in state_dict:
                    new_k = '.'.join(k.split('.')[2:])
                    if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                        new_state_dict[new_k] = state_dict[k]
                state_dict = new_state_dict
                logger.info(f"Loading pretrained model from {pretrain_path}")
            except FileNotFoundError as e:
                logger.info(f'{pretrain_path} is not found! Please check this path.')
                raise e
            except KeyError as e:
                logger.info("State dict keys error! Please check the state dict.")
                raise e
        else:
            key = depth
            if with_ibn: key = 'ibn_' + key

            state_dict = init_pretrained_weights(key)

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
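Aside: the grouped-convolution width in `Bottleneck` above follows the ResNeXt `D*C` rule. A minimal standalone check, not part of the commit; it only reuses the `baseWidth=4`, `cardinality=32` defaults from `ResNeXt.__init__`:

import math

# D = floor(planes * baseWidth / 64); the grouped 3x3 conv then runs on D*C channels.
baseWidth, cardinality = 4, 32
for planes in (64, 128, 256, 512):
    D = int(math.floor(planes * (baseWidth / 64)))
    print(planes, '->', D * cardinality)  # 64->128, 128->256, 256->512, 512->1024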
fastreid/modeling/backbones/shufflenet.py  0 → 100644

"""
Author: Guan'an Wang
Contact: guan.wang0706@gmail.com
"""

import torch
from torch import nn
from collections import OrderedDict
import logging

from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from fastreid.layers import get_norm
from fastreid.modeling.backbones import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)


class ShuffleV2Block(nn.Module):
    """
    Reference:
        https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2
    """

    def __init__(self, bn_norm, inp, oup, mid_channels, *, ksize, stride):
        super(ShuffleV2Block, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.mid_channels = mid_channels
        self.ksize = ksize
        pad = ksize // 2
        self.pad = pad
        self.inp = inp

        outputs = oup - inp

        branch_main = [
            # pw
            nn.Conv2d(inp, mid_channels, 1, 1, 0, bias=False),
            get_norm(bn_norm, mid_channels),
            nn.ReLU(inplace=True),
            # dw
            nn.Conv2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False),
            get_norm(bn_norm, mid_channels),
            # pw-linear
            nn.Conv2d(mid_channels, outputs, 1, 1, 0, bias=False),
            get_norm(bn_norm, outputs),
            nn.ReLU(inplace=True),
        ]
        self.branch_main = nn.Sequential(*branch_main)

        if stride == 2:
            branch_proj = [
                # dw
                nn.Conv2d(inp, inp, ksize, stride, pad, groups=inp, bias=False),
                get_norm(bn_norm, inp),
                # pw-linear
                nn.Conv2d(inp, inp, 1, 1, 0, bias=False),
                get_norm(bn_norm, inp),
                nn.ReLU(inplace=True),
            ]
            self.branch_proj = nn.Sequential(*branch_proj)
        else:
            self.branch_proj = None

    def forward(self, old_x):
        if self.stride == 1:
            x_proj, x = self.channel_shuffle(old_x)
            return torch.cat((x_proj, self.branch_main(x)), 1)
        elif self.stride == 2:
            x_proj = old_x
            x = old_x
            return torch.cat((self.branch_proj(x_proj), self.branch_main(x)), 1)

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = x.data.size()
        assert (num_channels % 4 == 0)
        x = x.reshape(batchsize * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleNetV2(nn.Module):
    """
    Reference:
        https://github.com/megvii-model/ShuffleNet-Series/tree/master/ShuffleNetV2
    """

    def __init__(self, bn_norm, model_size='1.5x'):
        super(ShuffleNetV2, self).__init__()

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if model_size == '0.5x':
            self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
        elif model_size == '1.0x':
            self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
        elif model_size == '1.5x':
            self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
        elif model_size == '2.0x':
            self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
        else:
            raise NotImplementedError

        # building first layer
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, input_channel, 3, 2, 1, bias=False),
            get_norm(bn_norm, input_channel),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]

            for i in range(numrepeat):
                if i == 0:
                    self.features.append(ShuffleV2Block(bn_norm, input_channel, output_channel,
                                                        mid_channels=output_channel // 2, ksize=3, stride=2))
                else:
                    self.features.append(ShuffleV2Block(bn_norm, input_channel // 2, output_channel,
                                                        mid_channels=output_channel // 2, ksize=3, stride=1))

                input_channel = output_channel

        self.features = nn.Sequential(*self.features)

        self.conv_last = nn.Sequential(
            nn.Conv2d(input_channel, self.stage_out_channels[-1], 1, 1, 0, bias=False),
            get_norm(bn_norm, self.stage_out_channels[-1]),
            nn.ReLU(inplace=True)
        )
        self._initialize_weights()

    def forward(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.conv_last(x)
        return x

    def _initialize_weights(self):
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    nn.init.normal_(m.weight, 0, 0.01)
                else:
                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


@BACKBONE_REGISTRY.register()
def build_shufflenetv2_backbone(cfg):
    # fmt: off
    pretrain      = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    bn_norm       = cfg.MODEL.BACKBONE.NORM
    model_size    = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    model = ShuffleNetV2(bn_norm, model_size=model_size)

    if pretrain:
        new_state_dict = OrderedDict()
        state_dict = torch.load(pretrain_path)["state_dict"]
        for k, v in state_dict.items():
            if k[:7] == 'module.':
                k = k[7:]
            new_state_dict[k] = v

        incompatible = model.load_state_dict(new_state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
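Aside: `channel_shuffle` never reads `self`, so its reshape/permute/split logic can be sanity-checked standalone. A small illustrative sketch, not part of the commit:

import torch

x = torch.arange(1 * 8 * 2 * 2, dtype=torch.float32).reshape(1, 8, 2, 2)
# The method body ignores `self`, so passing None is enough for a quick check.
x_proj, x_main = ShuffleV2Block.channel_shuffle(None, x)
print(x_proj.shape, x_main.shape)  # both torch.Size([1, 4, 2, 2])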
fastreid/modeling/backbones/vision_transformer.py  0 → 100644

""" Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from fastreid.layers import DropPath, trunc_normal_, to_2tuple
from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .build import BACKBONE_REGISTRY

logger = logging.getLogger(__name__)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class PatchEmbed_overlap(nn.Module):
    """ Image to Patch Embedding with overlapping patches
    """

    def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride_size_tuple = to_2tuple(stride_size)
        self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
        self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
        num_patches = self.num_x * self.num_y
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        B, C, H, W = x.shape

        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)

        x = x.flatten(2).transpose(1, 2)  # [64, 8, 768]
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 camera=0, drop_path_rate=0., hybrid_backbone=None,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), sie_xishu=1.0):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed_overlap(
                img_size=img_size, patch_size=patch_size, stride_size=stride_size,
                in_chans=in_chans, embed_dim=embed_dim)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cam_num = camera
        self.sie_xishu = sie_xishu
        # Initialize SIE Embedding
        if camera > 1:
            self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim))
            trunc_normal_(self.sie_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x, camera_id=None):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        if self.cam_num > 0:
            x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id]
        else:
            x = x + self.pos_embed

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)

        return x[:, 0].reshape(x.shape[0], -1, 1, 1)


def resize_pos_embed(posemb, posemb_new, hight, width):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    ntok_new = posemb_new.shape[1]

    posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
    ntok_new -= 1

    gs_old = int(math.sqrt(len(posemb_grid)))
    logger.info('Resized position embedding from size:{} to size: {} with height:{} width: {}'
                .format(posemb.shape, posemb_new.shape, hight, width))
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
    posemb = torch.cat([posemb_token, posemb_grid], dim=1)
    return posemb


@BACKBONE_REGISTRY.register()
def build_vit_backbone(cfg):
    """
    Create a Vision Transformer instance from config.
    Returns:
        VisionTransformer: a :class:`VisionTransformer` instance.
    """
    # fmt: off
    input_size      = cfg.INPUT.SIZE_TRAIN
    pretrain        = cfg.MODEL.BACKBONE.PRETRAIN
    pretrain_path   = cfg.MODEL.BACKBONE.PRETRAIN_PATH
    depth           = cfg.MODEL.BACKBONE.DEPTH
    sie_xishu       = cfg.MODEL.BACKBONE.SIE_COE
    stride_size     = cfg.MODEL.BACKBONE.STRIDE_SIZE
    drop_ratio      = cfg.MODEL.BACKBONE.DROP_RATIO
    drop_path_ratio = cfg.MODEL.BACKBONE.DROP_PATH_RATIO
    attn_drop_rate  = cfg.MODEL.BACKBONE.ATT_DROP_RATE
    # fmt: on

    num_depth = {
        'small': 8,
        'base': 12,
    }[depth]

    num_heads = {
        'small': 8,
        'base': 12,
    }[depth]

    mlp_ratio = {
        'small': 3.,
        'base': 4.
    }[depth]

    qkv_bias = {
        'small': False,
        'base': True
    }[depth]

    qk_scale = {
        'small': 768 ** -0.5,
        'base': None,
    }[depth]

    model = VisionTransformer(img_size=input_size, sie_xishu=sie_xishu, stride_size=stride_size,
                              depth=num_depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                              qkv_bias=qkv_bias, qk_scale=qk_scale, drop_path_rate=drop_path_ratio,
                              drop_rate=drop_ratio, attn_drop_rate=attn_drop_rate)

    if pretrain:
        try:
            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
            logger.info(f"Loading pretrained model from {pretrain_path}")

            if 'model' in state_dict:
                state_dict = state_dict.pop('model')
            if 'state_dict' in state_dict:
                state_dict = state_dict.pop('state_dict')
            for k, v in state_dict.items():
                if 'head' in k or 'dist' in k:
                    continue
                if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
                    # For old models that I trained prior to conv based patchification
                    O, I, H, W = model.patch_embed.proj.weight.shape
                    v = v.reshape(O, -1, H, W)
                elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
                    # To resize pos embedding when using model at different size from pretrained weights
                    if 'distilled' in pretrain_path:
                        logger.info("distill need to choose right cls token in the pth.")
                        v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1)
                    v = resize_pos_embed(v, model.pos_embed.data, model.patch_embed.num_y, model.patch_embed.num_x)
                state_dict[k] = v
        except FileNotFoundError as e:
            logger.info(f'{pretrain_path} is not found! Please check this path.')
            raise e
        except KeyError as e:
            logger.info("State dict keys error! Please check the state dict.")
            raise e

        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    return model
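Aside: `resize_pos_embed` splits off the class token, bilinearly resizes the square patch grid, and re-concatenates. A hedged shape check with dummy tensors; the 24x24 -> 21x10 grids are made-up values, chosen only to mimic resizing a square ImageNet grid to a person-reid aspect ratio:

import torch

posemb = torch.randn(1, 24 * 24 + 1, 768)      # pretrained: cls token + 24x24 grid
posemb_new = torch.randn(1, 21 * 10 + 1, 768)  # target: cls token + 21x10 grid
out = resize_pos_embed(posemb, posemb_new, 21, 10)
print(out.shape)  # torch.Size([1, 211, 768])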
fastreid/modeling/heads/__init__.py  0 → 100644

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from .build import REID_HEADS_REGISTRY, build_heads

# import all the meta_arch, so they will be registered
from .embedding_head import EmbeddingHead
from .clas_head import ClasHead
fastreid/modeling/heads/build.py  0 → 100644

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from ...utils.registry import Registry

REID_HEADS_REGISTRY = Registry("HEADS")
REID_HEADS_REGISTRY.__doc__ = """
Registry for reid heads in a baseline model.
Heads take backbone feature maps and perform feature aggregation
and loss computation.
The registered object will be called with `obj(cfg)` and is
expected to return a head `nn.Module`.
"""


def build_heads(cfg):
    """
    Build REIDHeads defined by `cfg.MODEL.HEADS.NAME`.
    """
    head = cfg.MODEL.HEADS.NAME
    return REID_HEADS_REGISTRY.get(head)(cfg)
fastreid/modeling/heads/clas_head.py  0 → 100644

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch.nn.functional as F

from fastreid.modeling.heads import REID_HEADS_REGISTRY, EmbeddingHead


@REID_HEADS_REGISTRY.register()
class ClasHead(EmbeddingHead):
    def forward(self, features, targets=None):
        """
        See :class:`ClsHeads.forward`.
        """
        pool_feat = self.pool_layer(features)
        neck_feat = self.bottleneck(pool_feat)
        neck_feat = neck_feat.view(neck_feat.size(0), -1)

        if self.cls_layer.__class__.__name__ == 'Linear':
            logits = F.linear(neck_feat, self.weight)
        else:
            logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))

        # Evaluation
        if not self.training:
            return logits.mul_(self.cls_layer.s)

        cls_outputs = self.cls_layer(logits.clone(), targets)

        return {
            "cls_outputs": cls_outputs,
            "pred_class_logits": logits.mul_(self.cls_layer.s),
            "features": neck_feat,
        }
fastreid/modeling/heads/embedding_head.py  0 → 100644

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.config import configurable
from fastreid.layers import *
from fastreid.layers import pooling, any_softmax
from fastreid.layers.weight_init import weights_init_kaiming
from .build import REID_HEADS_REGISTRY


@REID_HEADS_REGISTRY.register()
class EmbeddingHead(nn.Module):
    """
    EmbeddingHead performs all feature aggregation in an embedding task, such as reid, image retrieval
    and face recognition.
    It typically contains logic to
    1. feature aggregation via global average pooling and generalized mean pooling
    2. (optional) batchnorm, dimension reduction, etc.
    3. (in training only) margin-based softmax logits computation
    """

    @configurable
    def __init__(
            self,
            *,
            feat_dim,
            embedding_dim,
            num_classes,
            neck_feat,
            pool_type,
            cls_type,
            scale,
            margin,
            with_bnneck,
            norm_type
    ):
        """
        NOTE: this interface is experimental.
        Args:
            feat_dim:
            embedding_dim:
            num_classes:
            neck_feat:
            pool_type:
            cls_type:
            scale:
            margin:
            with_bnneck:
            norm_type:
        """
        super().__init__()

        # Pooling layer
        assert hasattr(pooling, pool_type), "Expected pool types are {}, " \
                                            "but got {}".format(pooling.__all__, pool_type)
        self.pool_layer = getattr(pooling, pool_type)()

        self.neck_feat = neck_feat

        neck = []
        if embedding_dim > 0:
            neck.append(nn.Conv2d(feat_dim, embedding_dim, 1, 1, bias=False))
            feat_dim = embedding_dim

        if with_bnneck:
            neck.append(get_norm(norm_type, feat_dim, bias_freeze=True))

        self.bottleneck = nn.Sequential(*neck)

        # Classification head
        assert hasattr(any_softmax, cls_type), "Expected cls types are {}, " \
                                               "but got {}".format(any_softmax.__all__, cls_type)
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim))
        self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.bottleneck.apply(weights_init_kaiming)
        nn.init.normal_(self.weight, std=0.01)

    @classmethod
    def from_config(cls, cfg):
        # fmt: off
        feat_dim      = cfg.MODEL.BACKBONE.FEAT_DIM
        embedding_dim = cfg.MODEL.HEADS.EMBEDDING_DIM
        num_classes   = cfg.MODEL.HEADS.NUM_CLASSES
        neck_feat     = cfg.MODEL.HEADS.NECK_FEAT
        pool_type     = cfg.MODEL.HEADS.POOL_LAYER
        cls_type      = cfg.MODEL.HEADS.CLS_LAYER
        scale         = cfg.MODEL.HEADS.SCALE
        margin        = cfg.MODEL.HEADS.MARGIN
        with_bnneck   = cfg.MODEL.HEADS.WITH_BNNECK
        norm_type     = cfg.MODEL.HEADS.NORM
        # fmt: on
        return {
            'feat_dim': feat_dim,
            'embedding_dim': embedding_dim,
            'num_classes': num_classes,
            'neck_feat': neck_feat,
            'pool_type': pool_type,
            'cls_type': cls_type,
            'scale': scale,
            'margin': margin,
            'with_bnneck': with_bnneck,
            'norm_type': norm_type
        }

    def forward(self, features, targets=None):
        """
        See :class:`ReIDHeads.forward`.
        """
        pool_feat = self.pool_layer(features)
        neck_feat = self.bottleneck(pool_feat)
        neck_feat = neck_feat[..., 0, 0]

        # Evaluation
        # fmt: off
        if not self.training: return neck_feat
        # fmt: on

        # Training
        if self.cls_layer.__class__.__name__ == 'Linear':
            logits = F.linear(neck_feat, self.weight)
        else:
            logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))

        # Pass logits.clone() into cls_layer, because there are in-place operations
        cls_outputs = self.cls_layer(logits.clone(), targets)

        # fmt: off
        if self.neck_feat == 'before':  feat = pool_feat[..., 0, 0]
        elif self.neck_feat == 'after': feat = neck_feat
        else:                           raise KeyError(f"{self.neck_feat} is invalid for MODEL.HEADS.NECK_FEAT")
        # fmt: on

        return {
            "cls_outputs": cls_outputs,
            "pred_class_logits": logits.mul(self.cls_layer.s),
            "features": feat,
        }
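Aside: in the non-`Linear` branch above, both features and class weights are L2-normalized, so `F.linear` yields cosine similarities bounded by 1, which the margin-based `cls_layer` then rescales. A minimal illustration; the 751-class size is a hypothetical (e.g. Market-1501), not from this commit:

import torch
import torch.nn.functional as F

feat = torch.randn(4, 512)
weight = torch.randn(751, 512)  # hypothetical (num_classes, feat_dim)
cos_logits = F.linear(F.normalize(feat), F.normalize(weight))
print(cos_logits.shape, cos_logits.abs().max().item() <= 1.0 + 1e-6)  # torch.Size([4, 751]) True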
fastreid/modeling/losses/__init__.py  0 → 100644

# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

from .circle_loss import *
from .cross_entroy_loss import cross_entropy_loss, log_accuracy
from .focal_loss import focal_loss
from .triplet_loss import triplet_loss

__all__ = [k for k in globals().keys() if not k.startswith("_")]
\ No newline at end of file
fastreid/modeling/losses/circle_loss.py  0 → 100644

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F

__all__ = ["pairwise_circleloss", "pairwise_cosface"]


def pairwise_circleloss(
        embedding: torch.Tensor,
        targets: torch.Tensor,
        margin: float,
        gamma: float, ) -> torch.Tensor:
    embedding = F.normalize(embedding, dim=1)

    dist_mat = torch.matmul(embedding, embedding.t())

    N = dist_mat.size(0)

    is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
    is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

    # Mask scores related to itself
    is_pos = is_pos - torch.eye(N, N, device=is_pos.device)

    s_p = dist_mat * is_pos
    s_n = dist_mat * is_neg

    alpha_p = torch.clamp_min(-s_p.detach() + 1 + margin, min=0.)
    alpha_n = torch.clamp_min(s_n.detach() + margin, min=0.)
    delta_p = 1 - margin
    delta_n = margin

    logit_p = - gamma * alpha_p * (s_p - delta_p) + (-99999999.) * (1 - is_pos)
    logit_n = gamma * alpha_n * (s_n - delta_n) + (-99999999.) * (1 - is_neg)

    loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

    return loss


def pairwise_cosface(
        embedding: torch.Tensor,
        targets: torch.Tensor,
        margin: float,
        gamma: float, ) -> torch.Tensor:
    # Normalize embedding features
    embedding = F.normalize(embedding, dim=1)

    dist_mat = torch.matmul(embedding, embedding.t())

    N = dist_mat.size(0)

    is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
    is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

    # Mask scores related to itself
    is_pos = is_pos - torch.eye(N, N, device=is_pos.device)

    s_p = dist_mat * is_pos
    s_n = dist_mat * is_neg

    logit_p = -gamma * s_p + (-99999999.) * (1 - is_pos)
    logit_n = gamma * (s_n + margin) + (-99999999.) * (1 - is_neg)

    loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

    return loss
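Aside: both losses take raw (unnormalized) embeddings plus integer identity labels and normalize internally. A toy invocation of `pairwise_circleloss`; the margin/gamma values are illustrative (taken from the Circle Loss paper rather than this commit's configs):

import torch

emb = torch.randn(8, 256, requires_grad=True)
targets = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # two identities, four samples each
loss = pairwise_circleloss(emb, targets, margin=0.25, gamma=128)
loss.backward()
print(loss.item())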
fastreid/modeling/losses/cross_entroy_loss.py  0 → 100644

# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
import torch.nn.functional as F

from fastreid.utils.events import get_event_storage


def log_accuracy(pred_class_logits, gt_classes, topk=(1,)):
    """
    Log the accuracy metrics to EventStorage.
    """
    bsz = pred_class_logits.size(0)
    maxk = max(topk)
    _, pred_class = pred_class_logits.topk(maxk, 1, True, True)
    pred_class = pred_class.t()
    correct = pred_class.eq(gt_classes.view(1, -1).expand_as(pred_class))

    ret = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
        ret.append(correct_k.mul_(1. / bsz))

    storage = get_event_storage()
    storage.put_scalar("cls_accuracy", ret[0])


def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
    num_classes = pred_class_outputs.size(1)

    if eps >= 0:
        smooth_param = eps
    else:
        # Adaptive label smooth regularization
        soft_label = F.softmax(pred_class_outputs, dim=1)
        smooth_param = alpha * soft_label[torch.arange(soft_label.size(0)), gt_classes].unsqueeze(1)

    log_probs = F.log_softmax(pred_class_outputs, dim=1)
    with torch.no_grad():
        targets = torch.ones_like(log_probs)
        targets *= smooth_param / (num_classes - 1)
        targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))

    loss = (-targets * log_probs).sum(dim=1)

    with torch.no_grad():
        non_zero_cnt = max(loss.nonzero(as_tuple=False).size(0), 1)

    loss = loss.sum() / non_zero_cnt

    return loss
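Aside: the sign of `eps` selects the smoothing mode: `eps >= 0` applies fixed label smoothing, while `eps < 0` switches to the adaptive variant weighted by the model's own softmax confidence via `alpha`. A toy comparison (values illustrative):

import torch

logits = torch.randn(4, 10)
labels = torch.tensor([1, 3, 5, 7])
print(cross_entropy_loss(logits, labels, eps=0.1).item())   # fixed smoothing
print(cross_entropy_loss(logits, labels, eps=-1.0).item())  # adaptive; alpha defaults to 0.2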
fastreid/modeling/losses/focal_loss.py  0 → 100644

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F

# based on:
# https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py


def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'mean') -> torch.Tensor:
    r"""Criterion that computes Focal loss.
    See :class:`fastreid.modeling.losses.FocalLoss` for details.
    According to [1], the Focal loss is computed as follows:
    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
    where:
       - :math:`p_t` is the model's estimated probability for each class.
    Arguments:
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma >= 0`.
        reduction (str, optional): Specifies the reduction to apply to the
            output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied,
            ‘mean’: the sum of the output will be divided by the number of elements
            in the output, ‘sum’: the output will be summed. Default: ‘none’.
    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
    Examples:
        >>> N = 5  # num_classes
        >>> loss = FocalLoss(cfg)
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = loss(input, target)
        >>> output.backward()
    References:
        [1] https://arxiv.org/abs/1708.02002
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(input)))

    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}".format(input.shape))

    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(out_size, target.size()))

    if not input.device == target.device:
        raise ValueError("input and target must be in the same device. Got: {}".format(
            input.device, target.device))

    # compute softmax over the classes axis
    input_soft = F.softmax(input, dim=1)

    # create the labels one hot tensor
    target_one_hot = F.one_hot(target, num_classes=input.shape[1])

    # compute the actual focal loss
    weight = torch.pow(-input_soft + 1., gamma)

    focal = -alpha * weight * torch.log(input_soft)
    loss_tmp = torch.sum(target_one_hot * focal, dim=1)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}".format(reduction))
    return loss
fastreid/modeling/losses/triplet_loss.py  0 → 100644

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F

from .utils import euclidean_dist, cosine_dist


def softmax_weights(dist, mask):
    max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
    diff = dist - max_v
    Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    W = torch.exp(diff) * mask / Z
    return W


def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the hardest positive and negative sample.
    Args:
        dist_mat: pair wise distance between samples, shape [N, M]
        is_pos: positive index with shape [N, M]
        is_neg: negative index with shape [N, M]
    Returns:
        dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
        dist_an: pytorch Variable, distance(anchor, negative); shape [N]
        p_inds: pytorch LongTensor, with shape [N];
            indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
        n_inds: pytorch LongTensor, with shape [N];
            indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have same num of samples,
        thus we can cope with all anchors in parallel.
    """

    assert len(dist_mat.size()) == 2

    # `dist_ap` means distance(anchor, positive)
    # both `dist_ap` and `relative_p_inds` with shape [N]
    dist_ap, _ = torch.max(dist_mat * is_pos, dim=1)
    # `dist_an` means distance(anchor, negative)
    # both `dist_an` and `relative_n_inds` with shape [N]
    dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1)

    return dist_ap, dist_an


def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the weighted positive and negative sample.
    Args:
        dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
        is_pos:
        is_neg:
    Returns:
        dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
        dist_an: pytorch Variable, distance(anchor, negative); shape [N]
    """
    assert len(dist_mat.size()) == 2

    is_pos = is_pos
    is_neg = is_neg
    dist_ap = dist_mat * is_pos
    dist_an = dist_mat * is_neg

    weights_ap = softmax_weights(dist_ap, is_pos)
    weights_an = softmax_weights(-dist_an, is_neg)

    dist_ap = torch.sum(dist_ap * weights_ap, dim=1)
    dist_an = torch.sum(dist_an * weights_an, dim=1)

    return dist_ap, dist_an


def triplet_loss(embedding, targets, margin, norm_feat, hard_mining):
    r"""Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
    Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
    Loss for Person Re-Identification'."""

    if norm_feat:
        dist_mat = cosine_dist(embedding, embedding)
    else:
        dist_mat = euclidean_dist(embedding, embedding)

    # For distributed training, gather all features from different process.
    # if comm.get_world_size() > 1:
    #     all_embedding = torch.cat(GatherLayer.apply(embedding), dim=0)
    #     all_targets = concat_all_gather(targets)
    # else:
    #     all_embedding = embedding
    #     all_targets = targets

    N = dist_mat.size(0)
    is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float()
    is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float()

    if hard_mining:
        dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg)
    else:
        dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg)

    y = dist_an.new().resize_as_(dist_an).fill_(1)

    if margin > 0:
        loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin)
    else:
        loss = F.soft_margin_loss(dist_an - dist_ap, y)
        # fmt: off
        if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
        # fmt: on

    return loss
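Aside: a toy call showing the expected inputs; `margin=0.3` mirrors the fallback value hard-coded above, and the rest of the values are illustrative:

import torch

emb = torch.randn(8, 128, requires_grad=True)
targets = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
loss = triplet_loss(emb, targets, margin=0.3, norm_feat=False, hard_mining=True)
loss.backward()
print(loss.item())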
fastreid/modeling/losses/utils.py  0 → 100644

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F


def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def normalize(x, axis=-1):
    """Normalizing to unit length along the specified dimension.
    Args:
        x: pytorch Variable
    Returns:
        x: pytorch Variable, same shape as input
    """
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x


def euclidean_dist(x, y):
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy - 2 * torch.matmul(x, y.t())
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist


def cosine_dist(x, y):
    x = F.normalize(x, dim=1)
    y = F.normalize(y, dim=1)
    dist = 2 - 2 * torch.mm(x, y.t())
    return dist
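Aside: `euclidean_dist` expands the squared norms and uses a single matmul, so it should agree with `torch.cdist` up to the numerical-stability clamp. A quick illustrative check:

import torch

x, y = torch.randn(5, 32), torch.randn(7, 32)
print(torch.allclose(euclidean_dist(x, y), torch.cdist(x, y), atol=1e-4))  # True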
fastreid/modeling/meta_arch/__init__.py  0 → 100644

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from .build import META_ARCH_REGISTRY, build_model

# import all the meta_arch, so they will be registered
from .baseline import Baseline
from .mgn import MGN
from .moco import MoCo
from .distiller import Distiller
fastreid/modeling/meta_arch/baseline.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import
torch
from
torch
import
nn
from
fastreid.config
import
configurable
from
fastreid.modeling.backbones
import
build_backbone
from
fastreid.modeling.heads
import
build_heads
from
fastreid.modeling.losses
import
*
from
.build
import
META_ARCH_REGISTRY
@
META_ARCH_REGISTRY
.
register
()
class
Baseline
(
nn
.
Module
):
"""
Baseline architecture. Any models that contains the following two components:
1. Per-image feature extraction (aka backbone)
2. Per-image feature aggregation and loss computation
"""
@
configurable
def
__init__
(
self
,
*
,
backbone
,
heads
,
pixel_mean
,
pixel_std
,
loss_kwargs
=
None
):
"""
NOTE: this interface is experimental.
Args:
backbone:
heads:
pixel_mean:
pixel_std:
"""
super
().
__init__
()
# backbone
self
.
backbone
=
backbone
# head
self
.
heads
=
heads
self
.
loss_kwargs
=
loss_kwargs
self
.
register_buffer
(
'pixel_mean'
,
torch
.
Tensor
(
pixel_mean
).
view
(
1
,
-
1
,
1
,
1
),
False
)
self
.
register_buffer
(
'pixel_std'
,
torch
.
Tensor
(
pixel_std
).
view
(
1
,
-
1
,
1
,
1
),
False
)
@
classmethod
def
from_config
(
cls
,
cfg
):
backbone
=
build_backbone
(
cfg
)
heads
=
build_heads
(
cfg
)
return
{
'backbone'
:
backbone
,
'heads'
:
heads
,
'pixel_mean'
:
cfg
.
MODEL
.
PIXEL_MEAN
,
'pixel_std'
:
cfg
.
MODEL
.
PIXEL_STD
,
'loss_kwargs'
:
{
# loss name
'loss_names'
:
cfg
.
MODEL
.
LOSSES
.
NAME
,
# loss hyperparameters
'ce'
:
{
'eps'
:
cfg
.
MODEL
.
LOSSES
.
CE
.
EPSILON
,
'alpha'
:
cfg
.
MODEL
.
LOSSES
.
CE
.
ALPHA
,
'scale'
:
cfg
.
MODEL
.
LOSSES
.
CE
.
SCALE
},
'tri'
:
{
'margin'
:
cfg
.
MODEL
.
LOSSES
.
TRI
.
MARGIN
,
'norm_feat'
:
cfg
.
MODEL
.
LOSSES
.
TRI
.
NORM_FEAT
,
'hard_mining'
:
cfg
.
MODEL
.
LOSSES
.
TRI
.
HARD_MINING
,
'scale'
:
cfg
.
MODEL
.
LOSSES
.
TRI
.
SCALE
},
'circle'
:
{
'margin'
:
cfg
.
MODEL
.
LOSSES
.
CIRCLE
.
MARGIN
,
'gamma'
:
cfg
.
MODEL
.
LOSSES
.
CIRCLE
.
GAMMA
,
'scale'
:
cfg
.
MODEL
.
LOSSES
.
CIRCLE
.
SCALE
},
'cosface'
:
{
'margin'
:
cfg
.
MODEL
.
LOSSES
.
COSFACE
.
MARGIN
,
'gamma'
:
cfg
.
MODEL
.
LOSSES
.
COSFACE
.
GAMMA
,
'scale'
:
cfg
.
MODEL
.
LOSSES
.
COSFACE
.
SCALE
}
}
}
@
property
def
device
(
self
):
return
self
.
pixel_mean
.
device
def
forward
(
self
,
batched_inputs
):
images
=
self
.
preprocess_image
(
batched_inputs
)
features
=
self
.
backbone
(
images
)
if
self
.
training
:
assert
"targets"
in
batched_inputs
,
"Person ID annotation are missing in training!"
targets
=
batched_inputs
[
"targets"
]
# PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
# may be larger than that in the original dataset, so the circle/arcface will
# throw an error. We just set all the targets to 0 to avoid this problem.
if
targets
.
sum
()
<
0
:
targets
.
zero_
()
outputs
=
self
.
heads
(
features
,
targets
)
losses
=
self
.
losses
(
outputs
,
targets
)
return
losses
else
:
outputs
=
self
.
heads
(
features
)
return
outputs
def
preprocess_image
(
self
,
batched_inputs
):
"""
Normalize and batch the input images.
"""
if
isinstance
(
batched_inputs
,
dict
):
images
=
batched_inputs
[
'images'
]
elif
isinstance
(
batched_inputs
,
torch
.
Tensor
):
images
=
batched_inputs
else
:
raise
TypeError
(
"batched_inputs must be dict or torch.Tensor, but get {}"
.
format
(
type
(
batched_inputs
)))
images
.
sub_
(
self
.
pixel_mean
).
div_
(
self
.
pixel_std
)
return
images
def
losses
(
self
,
outputs
,
gt_labels
):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
# model predictions
# fmt: off
pred_class_logits
=
outputs
[
'pred_class_logits'
].
detach
()
cls_outputs
=
outputs
[
'cls_outputs'
]
pred_features
=
outputs
[
'features'
]
# fmt: on
# Log prediction accuracy
log_accuracy
(
pred_class_logits
,
gt_labels
)
loss_dict
=
{}
loss_names
=
self
.
loss_kwargs
[
'loss_names'
]
if
'CrossEntropyLoss'
in
loss_names
:
ce_kwargs
=
self
.
loss_kwargs
.
get
(
'ce'
)
loss_dict
[
'loss_cls'
]
=
cross_entropy_loss
(
cls_outputs
,
gt_labels
,
ce_kwargs
.
get
(
'eps'
),
ce_kwargs
.
get
(
'alpha'
)
)
*
ce_kwargs
.
get
(
'scale'
)
if
'TripletLoss'
in
loss_names
:
tri_kwargs
=
self
.
loss_kwargs
.
get
(
'tri'
)
loss_dict
[
'loss_triplet'
]
=
triplet_loss
(
pred_features
,
gt_labels
,
tri_kwargs
.
get
(
'margin'
),
tri_kwargs
.
get
(
'norm_feat'
),
tri_kwargs
.
get
(
'hard_mining'
)
)
*
tri_kwargs
.
get
(
'scale'
)
if
'CircleLoss'
in
loss_names
:
circle_kwargs
=
self
.
loss_kwargs
.
get
(
'circle'
)
loss_dict
[
'loss_circle'
]
=
pairwise_circleloss
(
pred_features
,
gt_labels
,
circle_kwargs
.
get
(
'margin'
),
circle_kwargs
.
get
(
'gamma'
)
)
*
circle_kwargs
.
get
(
'scale'
)
if
'Cosface'
in
loss_names
:
cosface_kwargs
=
self
.
loss_kwargs
.
get
(
'cosface'
)
loss_dict
[
'loss_cosface'
]
=
pairwise_cosface
(
pred_features
,
gt_labels
,
cosface_kwargs
.
get
(
'margin'
),
cosface_kwargs
.
get
(
'gamma'
),
)
*
cosface_kwargs
.
get
(
'scale'
)
return
loss_dict
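For orientation, here is a minimal usage sketch (not part of this commit) showing how the registered Baseline is built through the registry and driven in training and inference. The config file path below is a hypothetical placeholder; the keys it must define are exactly the ones read in `from_config` above.

import torch

from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model

cfg = get_cfg()
cfg.merge_from_file("configs/Market1501/bagtricks_R50.yml")  # hypothetical config path

model = build_model(cfg)  # looks up cfg.MODEL.META_ARCHITECTURE ("Baseline") in the registry

# Training step: inputs are a dict with images and person-ID targets.
images = torch.randn(4, 3, 256, 128, device=model.device)   # (N, C, H, W)
targets = torch.tensor([0, 0, 1, 1], device=model.device)   # person IDs
model.train()
loss_dict = model({"images": images, "targets": targets})   # dict of weighted losses

# Inference: no targets; the heads return per-image embeddings.
model.eval()
with torch.no_grad():
    feats = model({"images": torch.randn(4, 3, 256, 128, device=model.device)})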
fastreid/modeling/meta_arch/build.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch

from fastreid.utils.registry import Registry

META_ARCH_REGISTRY = Registry("META_ARCH")  # noqa F401 isort:skip
META_ARCH_REGISTRY.__doc__ = """
Registry for meta-architectures, i.e. the whole model.
The registered object will be called with `obj(cfg)`
and expected to return a `nn.Module` object.
"""


def build_model(cfg):
    """
    Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
    Note that it does not load any weights from ``cfg``.
    """
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
    model.to(torch.device(cfg.MODEL.DEVICE))
    return model
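The registry makes meta-architectures pluggable. A short sketch of how a new one would be added (`MyArch` is a hypothetical example, not part of this commit), assuming the `obj(cfg)` calling convention documented above:

from torch import nn

from fastreid.modeling.meta_arch import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class MyArch(nn.Module):  # hypothetical example architecture
    def __init__(self, cfg):
        super().__init__()
        self.proj = nn.Linear(2048, 512)  # placeholder layer

    def forward(self, batched_inputs):
        ...

Selecting it then becomes a pure config change, e.g. setting MODEL.META_ARCHITECTURE to "MyArch".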
fastreid/modeling/meta_arch/distiller.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import logging

import torch
import torch.nn.functional as F

from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY, build_model, Baseline
from fastreid.utils.checkpoint import Checkpointer

logger = logging.getLogger(__name__)


@META_ARCH_REGISTRY.register()
class Distiller(Baseline):
    def __init__(self, cfg):
        super().__init__(cfg)

        # Get teacher model config
        model_ts = []
        for i in range(len(cfg.KD.MODEL_CONFIG)):
            cfg_t = get_cfg()
            cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])
            cfg_t.defrost()
            cfg_t.MODEL.META_ARCHITECTURE = "Baseline"
            # Change syncBN to BN because there is no DDP wrapper
            if cfg_t.MODEL.BACKBONE.NORM == "syncBN":
                cfg_t.MODEL.BACKBONE.NORM = "BN"
            if cfg_t.MODEL.HEADS.NORM == "syncBN":
                cfg_t.MODEL.HEADS.NORM = "BN"

            model_t = build_model(cfg_t)

            # No gradients for teacher model
            for param in model_t.parameters():
                param.requires_grad_(False)

            logger.info("Loading teacher model weights ...")
            Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])

            model_ts.append(model_t)

        self.ema_enabled = cfg.KD.EMA.ENABLED
        self.ema_momentum = cfg.KD.EMA.MOMENTUM

        if self.ema_enabled:
            cfg_self = cfg.clone()
            cfg_self.defrost()
            cfg_self.MODEL.META_ARCHITECTURE = "Baseline"
            if cfg_self.MODEL.BACKBONE.NORM == "syncBN":
                cfg_self.MODEL.BACKBONE.NORM = "BN"
            if cfg_self.MODEL.HEADS.NORM == "syncBN":
                cfg_self.MODEL.HEADS.NORM = "BN"

            model_self = build_model(cfg_self)

            # No gradients for the self-distillation model
            for param in model_self.parameters():
                param.requires_grad_(False)

            if cfg_self.MODEL.WEIGHTS != '':
                logger.info("Loading self distillation model weights ...")
                Checkpointer(model_self).load(cfg_self.MODEL.WEIGHTS)
            else:
                # Make sure the initial states are the same
                for param_q, param_k in zip(self.parameters(), model_self.parameters()):
                    param_k.data.copy_(param_q.data)

            model_ts.insert(0, model_self)

        # The teacher models are deliberately not registered as `nn.Module`s,
        # which makes sure their weights are not saved in checkpoints.
        self.model_ts = model_ts

    @torch.no_grad()
    def _momentum_update_key_encoder(self, m=0.999):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.parameters(), self.model_ts[0].parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)

    def forward(self, batched_inputs):
        if self.training:
            images = self.preprocess_image(batched_inputs)

            # student model forward
            s_feat = self.backbone(images)
            assert "targets" in batched_inputs, "Labels are missing in training!"
            targets = batched_inputs["targets"].to(self.device)

            if targets.sum() < 0: targets.zero_()

            s_outputs = self.heads(s_feat, targets)

            t_outputs = []
            # teacher model forward
            with torch.no_grad():
                if self.ema_enabled:
                    # update the self-distillation model
                    self._momentum_update_key_encoder(self.ema_momentum)
                for model_t in self.model_ts:
                    t_feat = model_t.backbone(images)
                    t_output = model_t.heads(t_feat, targets)
                    t_outputs.append(t_output)

            losses = self.losses(s_outputs, t_outputs, targets)
            return losses

        # Eval mode, just conventional reid feature extraction
        else:
            return super().forward(batched_inputs)

    def losses(self, s_outputs, t_outputs, gt_labels):
        """
        Compute loss from the model's outputs; the loss function input arguments
        must match the outputs of the model forward pass.
        """
        loss_dict = super().losses(s_outputs, gt_labels)

        s_logits = s_outputs['pred_class_logits']
        loss_jsdiv = 0.
        for t_output in t_outputs:
            t_logits = t_output['pred_class_logits'].detach()
            loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)

        loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)

        return loss_dict

    @staticmethod
    def _kldiv(y_s, y_t, t):
        p_s = F.log_softmax(y_s / t, dim=1)
        p_t = F.softmax(y_t / t, dim=1)
        loss = F.kl_div(p_s, p_t, reduction="sum") * (t ** 2) / y_s.shape[0]
        return loss

    def jsdiv_loss(self, y_s, y_t, t=16):
        loss = (self._kldiv(y_s, y_t, t) + self._kldiv(y_t, y_s, t)) / 2
        return loss
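A standalone sketch of the distillation objective above, runnable on dummy logits. Note that despite the name, `jsdiv_loss` is the symmetrized, temperature-scaled KL divergence (the average of the two KL terms), not the exact Jensen-Shannon divergence, which would compare each distribution against their mixture.

import torch
import torch.nn.functional as F

def kldiv(y_s, y_t, t):
    # KL(teacher || student) on temperature-softened distributions,
    # scaled by t**2 and averaged over the batch, as in Distiller._kldiv.
    p_s = F.log_softmax(y_s / t, dim=1)
    p_t = F.softmax(y_t / t, dim=1)
    return F.kl_div(p_s, p_t, reduction="sum") * (t ** 2) / y_s.shape[0]

s_logits = torch.randn(8, 751)  # student logits; 751 IDs as in Market-1501
t_logits = torch.randn(8, 751)  # teacher logits (detached in the real code)
loss = (kldiv(s_logits, t_logits, t=16) + kldiv(t_logits, s_logits, t=16)) / 2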
fastreid/modeling/meta_arch/mgn.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import copy

import torch
from torch import nn

from fastreid.config import configurable
from fastreid.layers import get_norm
from fastreid.modeling.backbones import build_backbone
from fastreid.modeling.backbones.resnet import Bottleneck
from fastreid.modeling.heads import build_heads
from fastreid.modeling.losses import *
from .build import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class MGN(nn.Module):
    """
    Multiple Granularities Network architecture, which contains the following two components:
    1. Per-image feature extraction (aka backbone)
    2. Multi-branch feature aggregation
    """

    @configurable
    def __init__(
            self,
            *,
            backbone,
            neck1,
            neck2,
            neck3,
            b1_head,
            b2_head,
            b21_head,
            b22_head,
            b3_head,
            b31_head,
            b32_head,
            b33_head,
            pixel_mean,
            pixel_std,
            loss_kwargs=None
    ):
        """
        NOTE: this interface is experimental.

        Args:
            backbone:
            neck1:
            neck2:
            neck3:
            b1_head:
            b2_head:
            b21_head:
            b22_head:
            b3_head:
            b31_head:
            b32_head:
            b33_head:
            pixel_mean:
            pixel_std:
            loss_kwargs:
        """
        super().__init__()
        self.backbone = backbone

        # branch1
        self.b1 = neck1
        self.b1_head = b1_head

        # branch2
        self.b2 = neck2
        self.b2_head = b2_head
        self.b21_head = b21_head
        self.b22_head = b22_head

        # branch3
        self.b3 = neck3
        self.b3_head = b3_head
        self.b31_head = b31_head
        self.b32_head = b32_head
        self.b33_head = b33_head

        self.loss_kwargs = loss_kwargs
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(1, -1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(1, -1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        bn_norm = cfg.MODEL.BACKBONE.NORM
        with_se = cfg.MODEL.BACKBONE.WITH_SE

        all_blocks = build_backbone(cfg)

        # backbone
        backbone = nn.Sequential(
            all_blocks.conv1,
            all_blocks.bn1,
            all_blocks.relu,
            all_blocks.maxpool,
            all_blocks.layer1,
            all_blocks.layer2,
            all_blocks.layer3[0]
        )
        res_conv4 = nn.Sequential(*all_blocks.layer3[1:])
        res_g_conv5 = all_blocks.layer4

        res_p_conv5 = nn.Sequential(
            Bottleneck(1024, 512, bn_norm, False, with_se,
                       downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False), get_norm(bn_norm, 2048))),
            Bottleneck(2048, 512, bn_norm, False, with_se),
            Bottleneck(2048, 512, bn_norm, False, with_se))
        res_p_conv5.load_state_dict(all_blocks.layer4.state_dict())

        # branch1
        neck1 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_g_conv5)
        )
        b1_head = build_heads(cfg)

        # branch2
        neck2 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        b2_head = build_heads(cfg)
        b21_head = build_heads(cfg)
        b22_head = build_heads(cfg)

        # branch3
        neck3 = nn.Sequential(
            copy.deepcopy(res_conv4),
            copy.deepcopy(res_p_conv5)
        )
        b3_head = build_heads(cfg)
        b31_head = build_heads(cfg)
        b32_head = build_heads(cfg)
        b33_head = build_heads(cfg)

        return {
            'backbone': backbone,
            'neck1': neck1,
            'neck2': neck2,
            'neck3': neck3,
            'b1_head': b1_head,
            'b2_head': b2_head,
            'b21_head': b21_head,
            'b22_head': b22_head,
            'b3_head': b3_head,
            'b31_head': b31_head,
            'b32_head': b32_head,
            'b33_head': b33_head,
            'pixel_mean': cfg.MODEL.PIXEL_MEAN,
            'pixel_std': cfg.MODEL.PIXEL_STD,
            'loss_kwargs': {
                # loss name
                'loss_names': cfg.MODEL.LOSSES.NAME,

                # loss hyperparameters
                'ce': {
                    'eps': cfg.MODEL.LOSSES.CE.EPSILON,
                    'alpha': cfg.MODEL.LOSSES.CE.ALPHA,
                    'scale': cfg.MODEL.LOSSES.CE.SCALE
                },
                'tri': {
                    'margin': cfg.MODEL.LOSSES.TRI.MARGIN,
                    'norm_feat': cfg.MODEL.LOSSES.TRI.NORM_FEAT,
                    'hard_mining': cfg.MODEL.LOSSES.TRI.HARD_MINING,
                    'scale': cfg.MODEL.LOSSES.TRI.SCALE
                },
                'circle': {
                    'margin': cfg.MODEL.LOSSES.CIRCLE.MARGIN,
                    'gamma': cfg.MODEL.LOSSES.CIRCLE.GAMMA,
                    'scale': cfg.MODEL.LOSSES.CIRCLE.SCALE
                },
                'cosface': {
                    'margin': cfg.MODEL.LOSSES.COSFACE.MARGIN,
                    'gamma': cfg.MODEL.LOSSES.COSFACE.GAMMA,
                    'scale': cfg.MODEL.LOSSES.COSFACE.SCALE
                }
            }
        }

    @property
    def device(self):
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images)  # (bs, 2048, 16, 8)

        # branch1
        b1_feat = self.b1(features)

        # branch2
        b2_feat = self.b2(features)
        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)

        # branch3
        b3_feat = self.b3(features)
        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)

        if self.training:
            assert "targets" in batched_inputs, "Person ID annotations are missing in training!"
            targets = batched_inputs["targets"]

            if targets.sum() < 0: targets.zero_()

            b1_outputs = self.b1_head(b1_feat, targets)
            b2_outputs = self.b2_head(b2_feat, targets)
            b21_outputs = self.b21_head(b21_feat, targets)
            b22_outputs = self.b22_head(b22_feat, targets)
            b3_outputs = self.b3_head(b3_feat, targets)
            b31_outputs = self.b31_head(b31_feat, targets)
            b32_outputs = self.b32_head(b32_feat, targets)
            b33_outputs = self.b33_head(b33_feat, targets)

            losses = self.losses(b1_outputs,
                                 b2_outputs, b21_outputs, b22_outputs,
                                 b3_outputs, b31_outputs, b32_outputs, b33_outputs,
                                 targets)
            return losses
        else:
            b1_pool_feat = self.b1_head(b1_feat)
            b2_pool_feat = self.b2_head(b2_feat)
            b21_pool_feat = self.b21_head(b21_feat)
            b22_pool_feat = self.b22_head(b22_feat)
            b3_pool_feat = self.b3_head(b3_feat)
            b31_pool_feat = self.b31_head(b31_feat)
            b32_pool_feat = self.b32_head(b32_feat)
            b33_pool_feat = self.b33_head(b33_feat)

            pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
                                   b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
            return pred_feat

    def preprocess_image(self, batched_inputs):
        r"""
        Normalize and batch the input images.
        """
        if isinstance(batched_inputs, dict):
            images = batched_inputs["images"].to(self.device)
        elif isinstance(batched_inputs, torch.Tensor):
            images = batched_inputs.to(self.device)
        else:
            raise TypeError("batched_inputs must be dict or torch.Tensor, but got {}".format(type(batched_inputs)))

        images.sub_(self.pixel_mean).div_(self.pixel_std)
        return images

    def losses(self, b1_outputs,
               b2_outputs, b21_outputs, b22_outputs,
               b3_outputs, b31_outputs, b32_outputs, b33_outputs, gt_labels):
        # model predictions
        # fmt: off
        pred_class_logits = b1_outputs['pred_class_logits'].detach()
        b1_logits         = b1_outputs['cls_outputs']
        b2_logits         = b2_outputs['cls_outputs']
        b21_logits        = b21_outputs['cls_outputs']
        b22_logits        = b22_outputs['cls_outputs']
        b3_logits         = b3_outputs['cls_outputs']
        b31_logits        = b31_outputs['cls_outputs']
        b32_logits        = b32_outputs['cls_outputs']
        b33_logits        = b33_outputs['cls_outputs']
        b1_pool_feat      = b1_outputs['features']
        b2_pool_feat      = b2_outputs['features']
        b3_pool_feat      = b3_outputs['features']
        b21_pool_feat     = b21_outputs['features']
        b22_pool_feat     = b22_outputs['features']
        b31_pool_feat     = b31_outputs['features']
        b32_pool_feat     = b32_outputs['features']
        b33_pool_feat     = b33_outputs['features']
        # fmt: on

        # Log prediction accuracy
        log_accuracy(pred_class_logits, gt_labels)

        b22_pool_feat = torch.cat((b21_pool_feat, b22_pool_feat), dim=1)
        b33_pool_feat = torch.cat((b31_pool_feat, b32_pool_feat, b33_pool_feat), dim=1)

        loss_dict = {}
        loss_names = self.loss_kwargs['loss_names']

        if "CrossEntropyLoss" in loss_names:
            ce_kwargs = self.loss_kwargs.get('ce')
            eps = ce_kwargs.get('eps')
            alpha = ce_kwargs.get('alpha')
            ce_scale = ce_kwargs.get('scale') * 0.125  # each of the 8 classifiers contributes 1/8
            loss_dict['loss_cls_b1'] = cross_entropy_loss(b1_logits, gt_labels, eps, alpha) * ce_scale
            loss_dict['loss_cls_b2'] = cross_entropy_loss(b2_logits, gt_labels, eps, alpha) * ce_scale
            loss_dict['loss_cls_b21'] = cross_entropy_loss(b21_logits, gt_labels, eps, alpha) * ce_scale
            loss_dict['loss_cls_b22'] = cross_entropy_loss(b22_logits, gt_labels, eps, alpha) * ce_scale
            loss_dict['loss_cls_b3'] = cross_entropy_loss(b3_logits, gt_labels, eps, alpha) * ce_scale
            loss_dict['loss_cls_b31'] = cross_entropy_loss(b31_logits, gt_labels, eps, alpha) * ce_scale
            loss_dict['loss_cls_b32'] = cross_entropy_loss(b32_logits, gt_labels, eps, alpha) * ce_scale
            loss_dict['loss_cls_b33'] = cross_entropy_loss(b33_logits, gt_labels, eps, alpha) * ce_scale

        if "TripletLoss" in loss_names:
            tri_kwargs = self.loss_kwargs.get('tri')
            margin = tri_kwargs.get('margin')
            norm_feat = tri_kwargs.get('norm_feat')
            hard_mining = tri_kwargs.get('hard_mining')
            tri_scale = tri_kwargs.get('scale') * 0.2  # each of the 5 triplet terms contributes 1/5
            loss_dict['loss_triplet_b1'] = triplet_loss(b1_pool_feat, gt_labels, margin, norm_feat, hard_mining) * tri_scale
            loss_dict['loss_triplet_b2'] = triplet_loss(b2_pool_feat, gt_labels, margin, norm_feat, hard_mining) * tri_scale
            loss_dict['loss_triplet_b3'] = triplet_loss(b3_pool_feat, gt_labels, margin, norm_feat, hard_mining) * tri_scale
            loss_dict['loss_triplet_b22'] = triplet_loss(b22_pool_feat, gt_labels, margin, norm_feat, hard_mining) * tri_scale
            loss_dict['loss_triplet_b33'] = triplet_loss(b33_pool_feat, gt_labels, margin, norm_feat, hard_mining) * tri_scale

        return loss_dict
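A small sketch of the part split in MGN's `forward`: branch 2 halves and branch 3 thirds the feature map along the height axis (dim=2), so each part head sees one horizontal stripe of the person. Shapes assume the (bs, 2048, 16, 8) map noted in the code; since 16 is not divisible by 3, torch.chunk makes the last third smaller.

import torch

feat = torch.randn(4, 2048, 16, 8)          # (bs, C, H, W) backbone output
upper, lower = torch.chunk(feat, 2, dim=2)  # two (4, 2048, 8, 8) halves
p1, p2, p3 = torch.chunk(feat, 3, dim=2)    # stripe heights 6, 6 and 4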
fastreid/modeling/meta_arch/moco.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
import torch.nn.functional as F
from torch import nn

from fastreid.modeling.losses.utils import concat_all_gather
from fastreid.utils import comm
from .baseline import Baseline
from .build import META_ARCH_REGISTRY


@META_ARCH_REGISTRY.register()
class MoCo(Baseline):
    def __init__(self, cfg):
        super().__init__(cfg)

        dim = cfg.MODEL.HEADS.EMBEDDING_DIM if cfg.MODEL.HEADS.EMBEDDING_DIM \
            else cfg.MODEL.BACKBONE.FEAT_DIM
        size = cfg.MODEL.QUEUE_SIZE
        self.memory = Memory(dim, size)

    def losses(self, outputs, gt_labels):
        """
        Compute loss from the model's outputs; the loss function input arguments
        must match the outputs of the model forward pass.
        """
        # regular reid loss
        loss_dict = super().losses(outputs, gt_labels)

        # memory loss
        pred_features = outputs['features']
        loss_mb = self.memory(pred_features, gt_labels)
        loss_dict['loss_mb'] = loss_mb
        return loss_dict


class Memory(nn.Module):
    """
    Build a MoCo memory with a queue
    https://arxiv.org/abs/1911.05722
    """

    def __init__(self, dim=512, K=65536):
        """
        dim: feature dimension (default: 512)
        K: queue size; number of negative keys (default: 65536)
        """
        super().__init__()
        self.K = K
        self.margin = 0.25
        self.gamma = 32

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_label", torch.zeros((1, K), dtype=torch.long))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, targets):
        # gather keys/targets before updating queue
        if comm.get_world_size() > 1:
            keys = concat_all_gather(keys)
            targets = concat_all_gather(targets)
        else:
            keys = keys.detach()
            targets = targets.detach()

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.queue_label[:, ptr:ptr + batch_size] = targets
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    def forward(self, feat_q, targets):
        """
        Memory bank enqueue and compute metric loss
        Args:
            feat_q: model features
            targets: gt labels

        Returns:
        """
        # normalize embedding features
        feat_q = F.normalize(feat_q, p=2, dim=1)
        # dequeue and enqueue
        self._dequeue_and_enqueue(feat_q.detach(), targets)
        # compute loss
        loss = self._pairwise_cosface(feat_q, targets)
        return loss

    def _pairwise_cosface(self, feat_q, targets):
        dist_mat = torch.matmul(feat_q, self.queue)

        N, M = dist_mat.size()  # (bsz, memory)
        is_pos = targets.view(N, 1).expand(N, M).eq(self.queue_label.expand(N, M)).float()
        is_neg = targets.view(N, 1).expand(N, M).ne(self.queue_label.expand(N, M)).float()

        # Mask scores related to themselves
        same_indx = torch.eye(N, N, device=is_pos.device)
        other_indx = torch.zeros(N, M - N, device=is_pos.device)
        same_indx = torch.cat((same_indx, other_indx), dim=1)
        is_pos = is_pos - same_indx

        s_p = dist_mat * is_pos
        s_n = dist_mat * is_neg

        logit_p = -self.gamma * s_p + (-99999999.) * (1 - is_pos)
        logit_n = self.gamma * (s_n + self.margin) + (-99999999.) * (1 - is_neg)

        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

        return loss
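A toy sketch of the ring-buffer update in `_dequeue_and_enqueue`: the queue stores K feature columns, and each batch overwrites the next batch_size columns before the pointer wraps around, which is why the code asserts that K is divisible by the (gathered) batch size.

import torch

dim, K, bs = 512, 8, 4          # toy sizes; the real default queue size is 65536
queue = torch.randn(dim, K)
ptr = 0

for step in range(3):
    keys = torch.randn(bs, dim)          # detached, L2-normalized in the real code
    queue[:, ptr:ptr + bs] = keys.T      # enqueue: overwrite the oldest columns
    ptr = (ptr + bs) % K                 # advance the pointer with wrap-around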
fastreid/solver/__init__.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from .build import build_lr_scheduler, build_optimizer