Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
Fast-ReID_pytorch
Commits
b6c19984
Commit
b6c19984
authored
Nov 18, 2025
by
dengjb
Browse files
update
parents
Changes
435
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1219 additions
and
0 deletions
+1219
-0
projects/FastAttr/fastattr/modeling/__init__.py
projects/FastAttr/fastattr/modeling/__init__.py
+9
-0
projects/FastAttr/fastattr/modeling/attr_baseline.py
projects/FastAttr/fastattr/modeling/attr_baseline.py
+44
-0
projects/FastAttr/fastattr/modeling/attr_head.py
projects/FastAttr/fastattr/modeling/attr_head.py
+43
-0
projects/FastAttr/fastattr/modeling/bce_loss.py
projects/FastAttr/fastattr/modeling/bce_loss.py
+33
-0
projects/FastAttr/train_net.py
projects/FastAttr/train_net.py
+125
-0
projects/FastClas/README.md
projects/FastClas/README.md
+16
-0
projects/FastClas/configs/base-clas.yaml
projects/FastClas/configs/base-clas.yaml
+77
-0
projects/FastClas/fastclas/__init__.py
projects/FastClas/fastclas/__init__.py
+10
-0
projects/FastClas/fastclas/bee_ant.py
projects/FastClas/fastclas/bee_ant.py
+50
-0
projects/FastClas/fastclas/dataset.py
projects/FastClas/fastclas/dataset.py
+50
-0
projects/FastClas/fastclas/trainer.py
projects/FastClas/fastclas/trainer.py
+82
-0
projects/FastClas/train_net.py
projects/FastClas/train_net.py
+73
-0
projects/FastDistill/README.md
projects/FastDistill/README.md
+52
-0
projects/FastDistill/configs/Base-kd.yml
projects/FastDistill/configs/Base-kd.yml
+28
-0
projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml
projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml
+19
-0
projects/FastDistill/configs/sbs_r101ibn.yml
projects/FastDistill/configs/sbs_r101ibn.yml
+13
-0
projects/FastDistill/configs/sbs_r34.yml
projects/FastDistill/configs/sbs_r34.yml
+14
-0
projects/FastDistill/fastdistill/__init__.py
projects/FastDistill/fastdistill/__init__.py
+8
-0
projects/FastDistill/fastdistill/overhaul.py
projects/FastDistill/fastdistill/overhaul.py
+126
-0
projects/FastDistill/fastdistill/resnet_distill.py
projects/FastDistill/fastdistill/resnet_distill.py
+347
-0
No files found.
projects/FastAttr/fastattr/modeling/__init__.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from
.attr_baseline
import
AttrBaseline
from
.attr_head
import
AttrHead
from
.bce_loss
import
cross_entropy_sigmoid_loss
projects/FastAttr/fastattr/modeling/attr_baseline.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from
fastreid.modeling.meta_arch.baseline
import
Baseline
from
fastreid.modeling.meta_arch.build
import
META_ARCH_REGISTRY
from
.bce_loss
import
cross_entropy_sigmoid_loss
@
META_ARCH_REGISTRY
.
register
()
class
AttrBaseline
(
Baseline
):
@
classmethod
def
from_config
(
cls
,
cfg
):
base_res
=
Baseline
.
from_config
(
cfg
)
base_res
[
"loss_kwargs"
].
update
({
'bce'
:
{
'scale'
:
cfg
.
MODEL
.
LOSSES
.
BCE
.
SCALE
}
})
return
base_res
def
losses
(
self
,
outputs
,
gt_labels
):
r
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
# model predictions
cls_outputs
=
outputs
[
"cls_outputs"
]
loss_dict
=
{}
loss_names
=
self
.
loss_kwargs
[
"loss_names"
]
if
"BinaryCrossEntropyLoss"
in
loss_names
:
bce_kwargs
=
self
.
loss_kwargs
.
get
(
'bce'
)
loss_dict
[
"loss_bce"
]
=
cross_entropy_sigmoid_loss
(
cls_outputs
,
gt_labels
,
self
.
sample_weights
,
)
*
bce_kwargs
.
get
(
'scale'
)
return
loss_dict
projects/FastAttr/fastattr/modeling/attr_head.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
fastreid.modeling.heads
import
EmbeddingHead
from
fastreid.modeling.heads.build
import
REID_HEADS_REGISTRY
from
fastreid.layers.weight_init
import
weights_init_kaiming
@
REID_HEADS_REGISTRY
.
register
()
class
AttrHead
(
EmbeddingHead
):
def
__init__
(
self
,
cfg
):
super
().
__init__
(
cfg
)
num_classes
=
cfg
.
MODEL
.
HEADS
.
NUM_CLASSES
self
.
bnneck
=
nn
.
BatchNorm1d
(
num_classes
)
self
.
bnneck
.
apply
(
weights_init_kaiming
)
def
forward
(
self
,
features
,
targets
=
None
):
"""
See :class:`ReIDHeads.forward`.
"""
pool_feat
=
self
.
pool_layer
(
features
)
neck_feat
=
self
.
bottleneck
(
pool_feat
)
neck_feat
=
neck_feat
.
view
(
neck_feat
.
size
(
0
),
-
1
)
logits
=
F
.
linear
(
neck_feat
,
self
.
weight
)
logits
=
self
.
bnneck
(
logits
)
# Evaluation
if
not
self
.
training
:
cls_outptus
=
torch
.
sigmoid
(
logits
)
return
cls_outptus
return
{
"cls_outputs"
:
logits
,
}
projects/FastAttr/fastattr/modeling/bce_loss.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import
torch
import
torch.nn.functional
as
F
def
ratio2weight
(
targets
,
ratio
):
pos_weights
=
targets
*
(
1
-
ratio
)
neg_weights
=
(
1
-
targets
)
*
ratio
weights
=
torch
.
exp
(
neg_weights
+
pos_weights
)
weights
[
targets
>
1
]
=
0.0
return
weights
def
cross_entropy_sigmoid_loss
(
pred_class_logits
,
gt_classes
,
sample_weight
=
None
):
loss
=
F
.
binary_cross_entropy_with_logits
(
pred_class_logits
,
gt_classes
,
reduction
=
'none'
)
if
sample_weight
is
not
None
:
targets_mask
=
torch
.
where
(
gt_classes
.
detach
()
>
0.5
,
torch
.
ones
(
1
,
device
=
"cuda"
),
torch
.
zeros
(
1
,
device
=
"cuda"
))
# dtype float32
weight
=
ratio2weight
(
targets_mask
,
sample_weight
)
loss
=
loss
*
weight
with
torch
.
no_grad
():
non_zero_cnt
=
max
(
loss
.
nonzero
(
as_tuple
=
False
).
size
(
0
),
1
)
loss
=
loss
.
sum
()
/
non_zero_cnt
return
loss
projects/FastAttr/train_net.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import
logging
import
sys
sys
.
path
.
append
(
'.'
)
from
fastreid.config
import
get_cfg
from
fastreid.engine
import
DefaultTrainer
from
fastreid.engine
import
default_argument_parser
,
default_setup
,
launch
from
fastreid.utils.checkpoint
import
Checkpointer
from
fastreid.data.datasets
import
DATASET_REGISTRY
from
fastreid.data.build
import
_root
,
build_reid_train_loader
,
build_reid_test_loader
from
fastreid.data.transforms
import
build_transforms
from
fastreid.utils
import
comm
from
fastattr
import
*
class
AttrTrainer
(
DefaultTrainer
):
sample_weights
=
None
@
classmethod
def
build_model
(
cls
,
cfg
):
"""
Returns:
torch.nn.Module:
It now calls :func:`fastreid.modeling.build_model`.
Overwrite it if you'd like a different model.
"""
model
=
DefaultTrainer
.
build_model
(
cfg
)
if
cfg
.
MODEL
.
LOSSES
.
BCE
.
WEIGHT_ENABLED
and
\
AttrTrainer
.
sample_weights
is
not
None
:
setattr
(
model
,
"sample_weights"
,
AttrTrainer
.
sample_weights
.
to
(
model
.
device
))
else
:
setattr
(
model
,
"sample_weights"
,
None
)
return
model
@
classmethod
def
build_train_loader
(
cls
,
cfg
):
logger
=
logging
.
getLogger
(
"fastreid.attr_dataset"
)
train_items
=
list
()
attr_dict
=
None
for
d
in
cfg
.
DATASETS
.
NAMES
:
dataset
=
DATASET_REGISTRY
.
get
(
d
)(
root
=
_root
,
combineall
=
cfg
.
DATASETS
.
COMBINEALL
)
if
comm
.
is_main_process
():
dataset
.
show_train
()
if
attr_dict
is
not
None
:
assert
attr_dict
==
dataset
.
attr_dict
,
f
"attr_dict in
{
d
}
does not match with previous ones"
else
:
attr_dict
=
dataset
.
attr_dict
train_items
.
extend
(
dataset
.
train
)
train_transforms
=
build_transforms
(
cfg
,
is_train
=
True
)
train_set
=
AttrDataset
(
train_items
,
train_transforms
,
attr_dict
)
data_loader
=
build_reid_train_loader
(
cfg
,
train_set
=
train_set
)
AttrTrainer
.
sample_weights
=
data_loader
.
dataset
.
sample_weights
return
data_loader
@
classmethod
def
build_test_loader
(
cls
,
cfg
,
dataset_name
):
dataset
=
DATASET_REGISTRY
.
get
(
dataset_name
)(
root
=
_root
)
attr_dict
=
dataset
.
attr_dict
if
comm
.
is_main_process
():
dataset
.
show_test
()
test_items
=
dataset
.
test
test_transforms
=
build_transforms
(
cfg
,
is_train
=
False
)
test_set
=
AttrDataset
(
test_items
,
test_transforms
,
attr_dict
)
data_loader
,
_
=
build_reid_test_loader
(
cfg
,
test_set
=
test_set
)
return
data_loader
@
classmethod
def
build_evaluator
(
cls
,
cfg
,
dataset_name
,
output_folder
=
None
):
data_loader
=
cls
.
build_test_loader
(
cfg
,
dataset_name
)
return
data_loader
,
AttrEvaluator
(
cfg
,
output_folder
)
def
setup
(
args
):
"""
Create configs and perform basic setups.
"""
cfg
=
get_cfg
()
add_attr_config
(
cfg
)
cfg
.
merge_from_file
(
args
.
config_file
)
cfg
.
merge_from_list
(
args
.
opts
)
cfg
.
freeze
()
default_setup
(
cfg
,
args
)
return
cfg
def
main
(
args
):
cfg
=
setup
(
args
)
if
args
.
eval_only
:
cfg
.
defrost
()
cfg
.
MODEL
.
BACKBONE
.
PRETRAIN
=
False
model
=
AttrTrainer
.
build_model
(
cfg
)
Checkpointer
(
model
).
load
(
cfg
.
MODEL
.
WEIGHTS
)
# load trained model
res
=
AttrTrainer
.
test
(
cfg
,
model
)
return
res
trainer
=
AttrTrainer
(
cfg
)
trainer
.
resume_or_load
(
resume
=
args
.
resume
)
return
trainer
.
train
()
if
__name__
==
"__main__"
:
args
=
default_argument_parser
().
parse_args
()
print
(
"Command Line Args:"
,
args
)
launch
(
main
,
args
.
num_gpus
,
num_machines
=
args
.
num_machines
,
machine_rank
=
args
.
machine_rank
,
dist_url
=
args
.
dist_url
,
args
=
(
args
,),
)
projects/FastClas/README.md
0 → 100644
View file @
b6c19984
# FastClas in FastReID
This project provides a baseline and example for image classification based on fastreid.
## Datasets Preparation
We refer to
[
pytorch tutorial
](
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
)
for dataset
preparation. This is just an example for building a classification task based on fastreid. You can customize
your own datasets and model.
## Usage
If you want to train models with 4 gpus, you can run
```
bash
python3 projects/FastClas/train_net.py
--config-file
projects/FastClas/config/base-clas.yml
--num-gpus
4
```
projects/FastClas/configs/base-clas.yaml
0 → 100644
View file @
b6c19984
MODEL
:
META_ARCHITECTURE
:
Baseline
BACKBONE
:
NAME
:
build_resnet_backbone
DEPTH
:
18x
NORM
:
BN
LAST_STRIDE
:
2
FEAT_DIM
:
512
PRETRAIN
:
True
HEADS
:
NAME
:
ClasHead
WITH_BNNECK
:
False
EMBEDDING_DIM
:
0
POOL_LAYER
:
FastGlobalAvgPool
CLS_LAYER
:
Linear
NUM_CLASSES
:
2
LOSSES
:
NAME
:
("CrossEntropyLoss",)
CE
:
EPSILON
:
0.1
SCALE
:
1.
INPUT
:
SIZE_TRAIN
:
[
0
,]
# no need for resize when training
SIZE_TEST
:
[
256
,]
CROP
:
ENABLED
:
True
SIZE
:
[
224
,]
SCALE
:
[
0.08
,
1
]
RATIO
:
[
0.75
,
1.333333333
]
FLIP
:
ENABLED
:
True
DATALOADER
:
SAMPLER_TRAIN
:
TrainingSampler
NUM_WORKERS
:
8
SOLVER
:
MAX_EPOCH
:
100
AMP
:
ENABLED
:
True
OPT
:
SGD
SCHED
:
CosineAnnealingLR
BASE_LR
:
0.001
MOMENTUM
:
0.9
NESTEROV
:
False
BIAS_LR_FACTOR
:
1.
WEIGHT_DECAY
:
0.0005
WEIGHT_DECAY_BIAS
:
0.
IMS_PER_BATCH
:
16
ETA_MIN_LR
:
0.00003
WARMUP_FACTOR
:
0.1
WARMUP_ITERS
:
100
CHECKPOINT_PERIOD
:
10
TEST
:
EVAL_PERIOD
:
10
IMS_PER_BATCH
:
256
DATASETS
:
NAMES
:
("Hymenoptera",)
TESTS
:
("Hymenoptera",)
OUTPUT_DIR
:
projects/FastClas/logs/r18_demo
\ No newline at end of file
projects/FastClas/fastclas/__init__.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from
.bee_ant
import
*
from
.distracted_driver
import
*
from
.dataset
import
ClasDataset
from
.trainer
import
ClasTrainer
projects/FastClas/fastclas/bee_ant.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import
glob
import
os
from
fastreid.data.datasets
import
DATASET_REGISTRY
from
fastreid.data.datasets.bases
import
ImageDataset
__all__
=
[
"Hymenoptera"
]
@
DATASET_REGISTRY
.
register
()
class
Hymenoptera
(
ImageDataset
):
"""This is a demo dataset for smoke test, you can refer to
https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
"""
dataset_dir
=
'hymenoptera_data'
dataset_name
=
"hyt"
def
__init__
(
self
,
root
=
'datasets'
,
**
kwargs
):
self
.
root
=
root
self
.
dataset_dir
=
os
.
path
.
join
(
self
.
root
,
self
.
dataset_dir
)
train_dir
=
os
.
path
.
join
(
self
.
dataset_dir
,
"train"
)
val_dir
=
os
.
path
.
join
(
self
.
dataset_dir
,
"val"
)
required_files
=
[
self
.
dataset_dir
,
train_dir
,
val_dir
,
]
self
.
check_before_run
(
required_files
)
train
=
self
.
process_dir
(
train_dir
)
val
=
self
.
process_dir
(
val_dir
)
super
().
__init__
(
train
,
val
,
[],
**
kwargs
)
def
process_dir
(
self
,
data_dir
):
data
=
[]
all_dirs
=
[
d
.
name
for
d
in
os
.
scandir
(
data_dir
)
if
d
.
is_dir
()]
for
dir_name
in
all_dirs
:
all_imgs
=
glob
.
glob
(
os
.
path
.
join
(
data_dir
,
dir_name
,
"*.jpg"
))
for
img_name
in
all_imgs
:
data
.
append
([
img_name
,
dir_name
,
'0'
])
return
data
projects/FastClas/fastclas/dataset.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from
torch.utils.data
import
Dataset
from
fastreid.data.data_utils
import
read_image
class
ClasDataset
(
Dataset
):
"""Image Person ReID Dataset"""
def
__init__
(
self
,
img_items
,
transform
=
None
,
idx_to_class
=
None
):
self
.
img_items
=
img_items
self
.
transform
=
transform
if
idx_to_class
is
not
None
:
self
.
idx_to_class
=
idx_to_class
self
.
class_to_idx
=
{
clas_name
:
int
(
i
)
for
i
,
clas_name
in
self
.
idx_to_class
.
items
()}
self
.
classes
=
sorted
(
list
(
self
.
idx_to_class
.
values
()))
else
:
classes
=
set
()
for
i
in
img_items
:
classes
.
add
(
i
[
1
])
self
.
classes
=
sorted
(
list
(
classes
))
self
.
class_to_idx
=
{
cls_name
:
i
for
i
,
cls_name
in
enumerate
(
self
.
classes
)}
self
.
idx_to_class
=
{
idx
:
clas
for
clas
,
idx
in
self
.
class_to_idx
.
items
()}
def
__len__
(
self
):
return
len
(
self
.
img_items
)
def
__getitem__
(
self
,
index
):
img_item
=
self
.
img_items
[
index
]
img_path
=
img_item
[
0
]
label
=
self
.
class_to_idx
[
img_item
[
1
]]
img
=
read_image
(
img_path
)
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
return
{
"images"
:
img
,
"targets"
:
label
,
"img_paths"
:
img_path
,
}
@
property
def
num_classes
(
self
):
return
len
(
self
.
classes
)
projects/FastClas/fastclas/trainer.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import
json
import
logging
import
os
from
fastreid.data.build
import
_root
from
fastreid.data.build
import
build_reid_train_loader
,
build_reid_test_loader
from
fastreid.data.datasets
import
DATASET_REGISTRY
from
fastreid.data.transforms
import
build_transforms
from
fastreid.engine
import
DefaultTrainer
from
fastreid.evaluation.clas_evaluator
import
ClasEvaluator
from
fastreid.utils
import
comm
from
fastreid.utils.checkpoint
import
PathManager
from
.dataset
import
ClasDataset
class
ClasTrainer
(
DefaultTrainer
):
idx2class
=
None
@
classmethod
def
build_train_loader
(
cls
,
cfg
):
"""
Returns:
iterable
It now calls :func:`fastreid.data.build_reid_train_loader`.
Overwrite it if you'd like a different data loader.
"""
logger
=
logging
.
getLogger
(
"fastreid.clas_dataset"
)
logger
.
info
(
"Prepare training set"
)
train_items
=
list
()
for
d
in
cfg
.
DATASETS
.
NAMES
:
data
=
DATASET_REGISTRY
.
get
(
d
)(
root
=
_root
)
if
comm
.
is_main_process
():
data
.
show_train
()
train_items
.
extend
(
data
.
train
)
transforms
=
build_transforms
(
cfg
,
is_train
=
True
)
train_set
=
ClasDataset
(
train_items
,
transforms
)
cls
.
idx2class
=
train_set
.
idx_to_class
data_loader
=
build_reid_train_loader
(
cfg
,
train_set
=
train_set
)
return
data_loader
@
classmethod
def
build_test_loader
(
cls
,
cfg
,
dataset_name
):
"""
Returns:
iterable
It now calls :func:`fastreid.data.build_reid_test_loader`.
Overwrite it if you'd like a different data loader.
"""
data
=
DATASET_REGISTRY
.
get
(
dataset_name
)(
root
=
_root
)
if
comm
.
is_main_process
():
data
.
show_test
()
transforms
=
build_transforms
(
cfg
,
is_train
=
False
)
test_set
=
ClasDataset
(
data
.
query
,
transforms
,
cls
.
idx2class
)
data_loader
,
_
=
build_reid_test_loader
(
cfg
,
test_set
=
test_set
)
return
data_loader
@
classmethod
def
build_evaluator
(
cls
,
cfg
,
dataset_name
,
output_dir
=
None
):
data_loader
=
cls
.
build_test_loader
(
cfg
,
dataset_name
)
return
data_loader
,
ClasEvaluator
(
cfg
,
output_dir
)
@
staticmethod
def
auto_scale_hyperparams
(
cfg
,
num_classes
):
cfg
=
DefaultTrainer
.
auto_scale_hyperparams
(
cfg
,
num_classes
)
# Save index to class dictionary
output_dir
=
cfg
.
OUTPUT_DIR
if
comm
.
is_main_process
()
and
output_dir
:
path
=
os
.
path
.
join
(
output_dir
,
"idx2class.json"
)
with
PathManager
.
open
(
path
,
"w"
)
as
f
:
json
.
dump
(
ClasTrainer
.
idx2class
,
f
)
return
cfg
projects/FastClas/train_net.py
0 → 100644
View file @
b6c19984
#!/usr/bin/env python
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
import
json
import
logging
import
os
import
sys
sys
.
path
.
append
(
'.'
)
from
fastreid.config
import
get_cfg
from
fastreid.engine
import
default_argument_parser
,
default_setup
,
launch
from
fastreid.utils.checkpoint
import
Checkpointer
,
PathManager
from
fastclas
import
*
def
setup
(
args
):
"""
Create configs and perform basic setups.
"""
cfg
=
get_cfg
()
cfg
.
merge_from_file
(
args
.
config_file
)
cfg
.
merge_from_list
(
args
.
opts
)
cfg
.
freeze
()
default_setup
(
cfg
,
args
)
return
cfg
def
main
(
args
):
cfg
=
setup
(
args
)
if
args
.
eval_only
:
cfg
.
defrost
()
cfg
.
MODEL
.
BACKBONE
.
PRETRAIN
=
False
model
=
ClasTrainer
.
build_model
(
cfg
)
Checkpointer
(
model
).
load
(
cfg
.
MODEL
.
WEIGHTS
)
# load trained model
try
:
output_dir
=
os
.
path
.
dirname
(
cfg
.
MODEL
.
WEIGHTS
)
path
=
os
.
path
.
join
(
output_dir
,
"idx2class.json"
)
with
PathManager
.
open
(
path
,
'r'
)
as
f
:
idx2class
=
json
.
load
(
f
)
ClasTrainer
.
idx2class
=
idx2class
except
:
logger
=
logging
.
getLogger
(
"fastreid.fastclas"
)
logger
.
info
(
f
"Cannot find idx2class dict in
{
os
.
path
.
dirname
(
cfg
.
MODEL
.
WEIGHTS
)
}
"
)
res
=
ClasTrainer
.
test
(
cfg
,
model
)
return
res
trainer
=
ClasTrainer
(
cfg
)
trainer
.
resume_or_load
(
resume
=
args
.
resume
)
return
trainer
.
train
()
if
__name__
==
"__main__"
:
args
=
default_argument_parser
().
parse_args
()
print
(
"Command Line Args:"
,
args
)
launch
(
main
,
args
.
num_gpus
,
num_machines
=
args
.
num_machines
,
machine_rank
=
args
.
machine_rank
,
dist_url
=
args
.
dist_url
,
args
=
(
args
,),
)
projects/FastDistill/README.md
0 → 100644
View file @
b6c19984
# FastDistill in FastReID
This project provides a strong distillation method for both embedding and classification training.
The feature distillation comes from
[
overhaul-distillation
](
https://github.com/clovaai/overhaul-distillation/tree/master/ImageNet
)
.
## Datasets Prepration
-
DukeMTMC-reID
## Train and Evaluation
```
shell
# teacher model training
python3 projects/FastDistill/train_net.py
\
--config-file
projects/FastDistill/configs/sbs_r101ibn.yml
\
--num-gpus
4
# loss distillation
python3 projects/FastDistill/train_net.py
\
--config-file
projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yaml
\
--num-gpus
4
\
MODEL.META_ARCHITECTURE Distiller
KD.MODEL_CONFIG
'("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)'
\
KD.MODEL_WEIGHTS
'("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)'
# loss+overhaul distillation
python3 projects/FastDistill/train_net.py
\
--config-file
projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yaml
\
--num-gpus
4
\
MODEL.META_ARCHITECTURE DistillerOverhaul
KD.MODEL_CONFIG
'("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)'
\
KD.MODEL_WEIGHTS
'("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)'
```
## Experimental Results
### Settings
All the experiments are conducted with 4 V100 GPUs.
### DukeMTMC-reID
| Model | Rank@1 | mAP |
| --- | --- | --- |
| R101_ibn (teacher) | 90.66 | 81.14 |
| R34 (student) | 86.31 | 73.28 |
| JS Div | 88.60 | 77.80 |
| JS Div + Overhaul | 88.73 | 78.25 |
## Contact
This project is conducted by
[
Xingyu Liao
](
https://github.com/L1aoXingyu
)
and
[
Guan'an Wang
](
https://wangguanan.github.io/
)
(
guan.wang0706@gmail
)
.
projects/FastDistill/configs/Base-kd.yml
0 → 100644
View file @
b6c19984
_BASE_
:
../../../configs/Base-SBS.yml
MODEL
:
BACKBONE
:
NAME
:
build_resnet_backbone_distill
WITH_IBN
:
False
WITH_NL
:
False
PRETRAIN
:
True
INPUT
:
SIZE_TRAIN
:
[
256
,
128
]
SIZE_TEST
:
[
256
,
128
]
SOLVER
:
MAX_EPOCH
:
60
BASE_LR
:
0.0007
IMS_PER_BATCH
:
256
DELAY_EPOCHS
:
30
FREEZE_ITERS
:
500
CHECKPOINT_PERIOD
:
20
TEST
:
EVAL_PERIOD
:
20
IMS_PER_BATCH
:
128
CUDNN_BENCHMARK
:
True
projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml
0 → 100644
View file @
b6c19984
_BASE_
:
Base-kd.yml
MODEL
:
META_ARCHITECTURE
:
Distiller
BACKBONE
:
DEPTH
:
34x
FEAT_DIM
:
512
WITH_IBN
:
False
KD
:
MODEL_CONFIG
:
("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)
MODEL_WEIGHTS
:
("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)
DATASETS
:
NAMES
:
("DukeMTMC",)
TESTS
:
("DukeMTMC",)
OUTPUT_DIR
:
projects/FastDistill/logs/dukemtmc/kd-r34-r101_ibn
\ No newline at end of file
projects/FastDistill/configs/sbs_r101ibn.yml
0 → 100644
View file @
b6c19984
_BASE_
:
Base-kd.yml
MODEL
:
BACKBONE
:
WITH_IBN
:
True
DEPTH
:
101x
DATASETS
:
NAMES
:
("DukeMTMC",)
TESTS
:
("DukeMTMC",)
OUTPUT_DIR
:
projects/FastDistill/logs/dukemtmc/r101_ibn
\ No newline at end of file
projects/FastDistill/configs/sbs_r34.yml
0 → 100644
View file @
b6c19984
_BASE_
:
Base-kd.yml
MODEL
:
BACKBONE
:
DEPTH
:
34x
FEAT_DIM
:
512
WITH_IBN
:
False
DATASETS
:
NAMES
:
("DukeMTMC",)
TESTS
:
("DukeMTMC",)
OUTPUT_DIR
:
projects/FastDistill/logs/dukemtmc/r34
\ No newline at end of file
projects/FastDistill/fastdistill/__init__.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
from
.overhaul
import
DistillerOverhaul
from
.resnet_distill
import
build_resnet_backbone_distill
projects/FastDistill/fastdistill/overhaul.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import
logging
import
math
import
torch
import
torch.nn.functional
as
F
from
scipy.stats
import
norm
from
torch
import
nn
from
fastreid.modeling.meta_arch
import
META_ARCH_REGISTRY
,
Distiller
logger
=
logging
.
getLogger
(
"fastreid.meta_arch.overhaul_distiller"
)
def
distillation_loss
(
source
,
target
,
margin
):
target
=
torch
.
max
(
target
,
margin
)
loss
=
F
.
mse_loss
(
source
,
target
,
reduction
=
"none"
)
loss
=
loss
*
((
source
>
target
)
|
(
target
>
0
)).
float
()
return
loss
.
sum
()
def
build_feature_connector
(
t_channel
,
s_channel
):
C
=
[
nn
.
Conv2d
(
s_channel
,
t_channel
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias
=
False
),
nn
.
BatchNorm2d
(
t_channel
)]
for
m
in
C
:
if
isinstance
(
m
,
nn
.
Conv2d
):
n
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.
/
n
))
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
weight
.
data
.
fill_
(
1
)
m
.
bias
.
data
.
zero_
()
return
nn
.
Sequential
(
*
C
)
def
get_margin_from_BN
(
bn
):
margin
=
[]
std
=
bn
.
weight
.
data
mean
=
bn
.
bias
.
data
for
(
s
,
m
)
in
zip
(
std
,
mean
):
s
=
abs
(
s
.
item
())
m
=
m
.
item
()
if
norm
.
cdf
(
-
m
/
s
)
>
0.001
:
margin
.
append
(
-
s
*
math
.
exp
(
-
(
m
/
s
)
**
2
/
2
)
/
\
math
.
sqrt
(
2
*
math
.
pi
)
/
norm
.
cdf
(
-
m
/
s
)
+
m
)
else
:
margin
.
append
(
-
3
*
s
)
return
torch
.
tensor
(
margin
,
dtype
=
torch
.
float32
,
device
=
mean
.
device
)
@
META_ARCH_REGISTRY
.
register
()
class
DistillerOverhaul
(
Distiller
):
def
__init__
(
self
,
cfg
):
super
().
__init__
(
cfg
)
s_channels
=
self
.
backbone
.
get_channel_nums
()
for
i
in
range
(
len
(
self
.
model_ts
)):
t_channels
=
self
.
model_ts
[
i
].
backbone
.
get_channel_nums
()
setattr
(
self
,
"connectors_{}"
.
format
(
i
),
nn
.
ModuleList
(
[
build_feature_connector
(
t
,
s
)
for
t
,
s
in
zip
(
t_channels
,
s_channels
)]))
teacher_bns
=
self
.
model_ts
[
i
].
backbone
.
get_bn_before_relu
()
margins
=
[
get_margin_from_BN
(
bn
)
for
bn
in
teacher_bns
]
for
j
,
margin
in
enumerate
(
margins
):
self
.
register_buffer
(
"margin{}_{}"
.
format
(
i
,
j
+
1
),
margin
.
unsqueeze
(
1
).
unsqueeze
(
2
).
unsqueeze
(
0
).
detach
())
def
forward
(
self
,
batched_inputs
):
if
self
.
training
:
images
=
self
.
preprocess_image
(
batched_inputs
)
# student model forward
s_feats
,
s_feat
=
self
.
backbone
.
extract_feature
(
images
,
preReLU
=
True
)
assert
"targets"
in
batched_inputs
,
"Labels are missing in training!"
targets
=
batched_inputs
[
"targets"
].
to
(
self
.
device
)
if
targets
.
sum
()
<
0
:
targets
.
zero_
()
s_outputs
=
self
.
heads
(
s_feat
,
targets
)
t_feats_list
=
[]
t_outputs
=
[]
# teacher model forward
with
torch
.
no_grad
():
if
self
.
ema_enabled
:
self
.
_momentum_update_key_encoder
(
self
.
ema_momentum
)
for
model_t
in
self
.
model_ts
:
t_feats
,
t_feat
=
model_t
.
backbone
.
extract_feature
(
images
,
preReLU
=
True
)
t_output
=
model_t
.
heads
(
t_feat
,
targets
)
t_feats_list
.
append
(
t_feats
)
t_outputs
.
append
(
t_output
)
losses
=
self
.
losses
(
s_outputs
,
s_feats
,
t_outputs
,
t_feats_list
,
targets
)
return
losses
else
:
outputs
=
super
(
DistillerOverhaul
,
self
).
forward
(
batched_inputs
)
return
outputs
def
losses
(
self
,
s_outputs
,
s_feats
,
t_outputs
,
t_feats_list
,
gt_labels
):
"""
Compute loss from modeling's outputs, the loss function input arguments
must be the same as the outputs of the model forwarding.
"""
loss_dict
=
super
().
losses
(
s_outputs
,
t_outputs
,
gt_labels
)
# Overhaul distillation loss
feat_num
=
len
(
s_feats
)
loss_distill
=
0
for
i
in
range
(
len
(
t_feats_list
)):
for
j
in
range
(
feat_num
):
s_feats_connect
=
getattr
(
self
,
"connectors_{}"
.
format
(
i
))[
j
](
s_feats
[
j
])
loss_distill
+=
distillation_loss
(
s_feats_connect
,
t_feats_list
[
i
][
j
].
detach
(),
getattr
(
self
,
"margin{}_{}"
.
format
(
i
,
j
+
1
)).
to
(
s_feats_connect
.
dtype
))
/
2
**
(
feat_num
-
j
-
1
)
loss_dict
[
"loss_overhaul"
]
=
loss_distill
/
len
(
t_feats_list
)
/
len
(
gt_labels
)
/
10000
return
loss_dict
projects/FastDistill/fastdistill/resnet_distill.py
0 → 100644
View file @
b6c19984
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import
logging
import
math
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
fastreid.layers
import
(
IBN
,
SELayer
,
get_norm
,
)
from
fastreid.modeling.backbones
import
BACKBONE_REGISTRY
from
fastreid.utils
import
comm
from
fastreid.utils.checkpoint
import
get_missing_parameters_message
,
get_unexpected_parameters_message
logger
=
logging
.
getLogger
(
"fastreid.overhaul.backbone"
)
model_urls
=
{
'18x'
:
'https://download.pytorch.org/models/resnet18-5c106cde.pth'
,
'34x'
:
'https://download.pytorch.org/models/resnet34-333f7ec4.pth'
,
'50x'
:
'https://download.pytorch.org/models/resnet50-19c8e357.pth'
,
'101x'
:
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'
,
'ibn_18x'
:
'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth'
,
'ibn_34x'
:
'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth'
,
'ibn_50x'
:
'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth'
,
'ibn_101x'
:
'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth'
,
'se_ibn_101x'
:
'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth'
,
}
class
BasicBlock
(
nn
.
Module
):
expansion
=
1
def
__init__
(
self
,
inplanes
,
planes
,
bn_norm
,
with_ibn
=
False
,
with_se
=
False
,
stride
=
1
,
downsample
=
None
,
reduction
=
16
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
inplanes
,
planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias
=
False
)
if
with_ibn
:
self
.
bn1
=
IBN
(
planes
,
bn_norm
)
else
:
self
.
bn1
=
get_norm
(
bn_norm
,
planes
)
self
.
conv2
=
nn
.
Conv2d
(
planes
,
planes
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
)
self
.
bn2
=
get_norm
(
bn_norm
,
planes
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
if
with_se
:
self
.
se
=
SELayer
(
planes
,
reduction
)
else
:
self
.
se
=
nn
.
Identity
()
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
x
=
self
.
relu
(
x
)
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
se
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
+=
identity
# out = self.relu(out)
return
out
class
Bottleneck
(
nn
.
Module
):
expansion
=
4
def
__init__
(
self
,
inplanes
,
planes
,
bn_norm
,
with_ibn
=
False
,
with_se
=
False
,
stride
=
1
,
downsample
=
None
,
reduction
=
16
):
super
(
Bottleneck
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
inplanes
,
planes
,
kernel_size
=
1
,
bias
=
False
)
if
with_ibn
:
self
.
bn1
=
IBN
(
planes
,
bn_norm
)
else
:
self
.
bn1
=
get_norm
(
bn_norm
,
planes
)
self
.
conv2
=
nn
.
Conv2d
(
planes
,
planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias
=
False
)
self
.
bn2
=
get_norm
(
bn_norm
,
planes
)
self
.
conv3
=
nn
.
Conv2d
(
planes
,
planes
*
self
.
expansion
,
kernel_size
=
1
,
bias
=
False
)
self
.
bn3
=
get_norm
(
bn_norm
,
planes
*
self
.
expansion
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
if
with_se
:
self
.
se
=
SELayer
(
planes
*
self
.
expansion
,
reduction
)
else
:
self
.
se
=
nn
.
Identity
()
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
x
=
self
.
relu
(
x
)
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
out
=
self
.
se
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
out
+=
residual
# out = self.relu(out)
return
out
class
ResNet
(
nn
.
Module
):
def
__init__
(
self
,
last_stride
,
bn_norm
,
with_ibn
,
with_se
,
with_nl
,
block
,
layers
,
non_layers
):
self
.
channel_nums
=
[]
self
.
inplanes
=
64
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
get_norm
(
bn_norm
,
64
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
ceil_mode
=
True
)
self
.
layer1
=
self
.
_make_layer
(
block
,
64
,
layers
[
0
],
1
,
bn_norm
,
with_ibn
,
with_se
)
self
.
layer2
=
self
.
_make_layer
(
block
,
128
,
layers
[
1
],
2
,
bn_norm
,
with_ibn
,
with_se
)
self
.
layer3
=
self
.
_make_layer
(
block
,
256
,
layers
[
2
],
2
,
bn_norm
,
with_ibn
,
with_se
)
self
.
layer4
=
self
.
_make_layer
(
block
,
512
,
layers
[
3
],
last_stride
,
bn_norm
,
with_se
=
with_se
)
self
.
random_init
()
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
,
bn_norm
=
"BN"
,
with_ibn
=
False
,
with_se
=
False
):
downsample
=
None
if
stride
!=
1
or
self
.
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
nn
.
Conv2d
(
self
.
inplanes
,
planes
*
block
.
expansion
,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
),
get_norm
(
bn_norm
,
planes
*
block
.
expansion
),
)
layers
=
[]
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
bn_norm
,
with_ibn
,
with_se
,
stride
,
downsample
))
self
.
inplanes
=
planes
*
block
.
expansion
for
i
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
bn_norm
,
with_ibn
,
with_se
))
self
.
channel_nums
.
append
(
self
.
inplanes
)
return
nn
.
Sequential
(
*
layers
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
maxpool
(
x
)
x
=
self
.
layer1
(
x
)
x
=
self
.
layer2
(
x
)
x
=
self
.
layer3
(
x
)
x
=
self
.
layer4
(
x
)
x
=
F
.
relu
(
x
,
inplace
=
True
)
return
x
def
random_init
(
self
):
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
n
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
nn
.
init
.
normal_
(
m
.
weight
,
0
,
math
.
sqrt
(
2.
/
n
))
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
nn
.
init
.
constant_
(
m
.
weight
,
1
)
nn
.
init
.
constant_
(
m
.
bias
,
0
)
def
get_bn_before_relu
(
self
):
if
isinstance
(
self
.
layer1
[
0
],
Bottleneck
):
bn1
=
self
.
layer1
[
-
1
].
bn3
bn2
=
self
.
layer2
[
-
1
].
bn3
bn3
=
self
.
layer3
[
-
1
].
bn3
bn4
=
self
.
layer4
[
-
1
].
bn3
elif
isinstance
(
self
.
layer1
[
0
],
BasicBlock
):
bn1
=
self
.
layer1
[
-
1
].
bn2
bn2
=
self
.
layer2
[
-
1
].
bn2
bn3
=
self
.
layer3
[
-
1
].
bn2
bn4
=
self
.
layer4
[
-
1
].
bn2
else
:
logger
.
info
(
"ResNet unknown block error!"
)
return
[
bn1
,
bn2
,
bn3
,
bn4
]
def
extract_feature
(
self
,
x
,
preReLU
=
False
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
maxpool
(
x
)
feat1
=
self
.
layer1
(
x
)
feat2
=
self
.
layer2
(
feat1
)
feat3
=
self
.
layer3
(
feat2
)
feat4
=
self
.
layer4
(
feat3
)
if
not
preReLU
:
feat1
=
F
.
relu
(
feat1
)
feat2
=
F
.
relu
(
feat2
)
feat3
=
F
.
relu
(
feat3
)
feat4
=
F
.
relu
(
feat4
)
return
[
feat1
,
feat2
,
feat3
,
feat4
],
F
.
relu
(
feat4
)
def
get_channel_nums
(
self
):
return
self
.
channel_nums
def
init_pretrained_weights
(
key
):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
import
os
import
errno
import
gdown
def
_get_torch_home
():
ENV_TORCH_HOME
=
'TORCH_HOME'
ENV_XDG_CACHE_HOME
=
'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR
=
'~/.cache'
torch_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
ENV_TORCH_HOME
,
os
.
path
.
join
(
os
.
getenv
(
ENV_XDG_CACHE_HOME
,
DEFAULT_CACHE_DIR
),
'torch'
)
)
)
return
torch_home
torch_home
=
_get_torch_home
()
model_dir
=
os
.
path
.
join
(
torch_home
,
'checkpoints'
)
try
:
os
.
makedirs
(
model_dir
)
except
OSError
as
e
:
if
e
.
errno
==
errno
.
EEXIST
:
# Directory already exists, ignore.
pass
else
:
# Unexpected OSError, re-raise.
raise
filename
=
model_urls
[
key
].
split
(
'/'
)[
-
1
]
cached_file
=
os
.
path
.
join
(
model_dir
,
filename
)
if
not
os
.
path
.
exists
(
cached_file
):
if
comm
.
is_main_process
():
gdown
.
download
(
model_urls
[
key
],
cached_file
,
quiet
=
False
)
comm
.
synchronize
()
logger
.
info
(
f
"Loading pretrained model from
{
cached_file
}
"
)
state_dict
=
torch
.
load
(
cached_file
,
map_location
=
torch
.
device
(
'cpu'
))
return
state_dict
@
BACKBONE_REGISTRY
.
register
()
def
build_resnet_backbone_distill
(
cfg
):
"""
Create a ResNet instance from config.
Returns:
ResNet: a :class:`ResNet` instance.
"""
# fmt: off
pretrain
=
cfg
.
MODEL
.
BACKBONE
.
PRETRAIN
pretrain_path
=
cfg
.
MODEL
.
BACKBONE
.
PRETRAIN_PATH
last_stride
=
cfg
.
MODEL
.
BACKBONE
.
LAST_STRIDE
bn_norm
=
cfg
.
MODEL
.
BACKBONE
.
NORM
with_ibn
=
cfg
.
MODEL
.
BACKBONE
.
WITH_IBN
with_se
=
cfg
.
MODEL
.
BACKBONE
.
WITH_SE
with_nl
=
cfg
.
MODEL
.
BACKBONE
.
WITH_NL
depth
=
cfg
.
MODEL
.
BACKBONE
.
DEPTH
# fmt: on
num_blocks_per_stage
=
{
'18x'
:
[
2
,
2
,
2
,
2
],
'34x'
:
[
3
,
4
,
6
,
3
],
'50x'
:
[
3
,
4
,
6
,
3
],
'101x'
:
[
3
,
4
,
23
,
3
],
}[
depth
]
nl_layers_per_stage
=
{
'18x'
:
[
0
,
0
,
0
,
0
],
'34x'
:
[
0
,
0
,
0
,
0
],
'50x'
:
[
0
,
2
,
3
,
0
],
'101x'
:
[
0
,
2
,
9
,
0
]
}[
depth
]
block
=
{
'18x'
:
BasicBlock
,
'34x'
:
BasicBlock
,
'50x'
:
Bottleneck
,
'101x'
:
Bottleneck
}[
depth
]
model
=
ResNet
(
last_stride
,
bn_norm
,
with_ibn
,
with_se
,
with_nl
,
block
,
num_blocks_per_stage
,
nl_layers_per_stage
)
if
pretrain
:
# Load pretrain path if specifically
if
pretrain_path
:
try
:
state_dict
=
torch
.
load
(
pretrain_path
,
map_location
=
torch
.
device
(
'cpu'
))
logger
.
info
(
f
"Loading pretrained model from
{
pretrain_path
}
"
)
except
FileNotFoundError
as
e
:
logger
.
info
(
f
'
{
pretrain_path
}
is not found! Please check this path.'
)
raise
e
except
KeyError
as
e
:
logger
.
info
(
"State dict keys error! Please check the state dict."
)
raise
e
else
:
key
=
depth
if
with_ibn
:
key
=
'ibn_'
+
key
if
with_se
:
key
=
'se_'
+
key
state_dict
=
init_pretrained_weights
(
key
)
incompatible
=
model
.
load_state_dict
(
state_dict
,
strict
=
False
)
if
incompatible
.
missing_keys
:
logger
.
info
(
get_missing_parameters_message
(
incompatible
.
missing_keys
)
)
if
incompatible
.
unexpected_keys
:
logger
.
info
(
get_unexpected_parameters_message
(
incompatible
.
unexpected_keys
)
)
return
model
Prev
1
…
10
11
12
13
14
15
16
17
18
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment