ModelZoo / OmniMotion_pytorch · Commits

Commit 3d92aebb, authored Jul 16, 2024 by bailuo
add preprocessing
Parent: fcc0bcf3
Pipeline #1379: canceled with stages
Showing 8 changed files with 2589 additions and 0 deletions (+2589 −0)
preprocessing/dino/extract_dino_features.py   +124  −0
preprocessing/dino/hubconf.py                 +151  −0
preprocessing/dino/main_dino.py               +471  −0
preprocessing/dino/run_with_submitit.py       +132  −0
preprocessing/dino/utils.py                   +829  −0
preprocessing/dino/video_generation.py        +378  −0
preprocessing/dino/vision_transformer.py      +291  −0
preprocessing/dino/visualize_attention.py     +213  −0
preprocessing/dino/extract_dino_features.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Some parts are taken from https://github.com/Liusifei/UVC
"""
import os
import glob
import argparse

import numpy as np
from tqdm import tqdm
import cv2
import torch

import utils
import vision_transformer as vits


def extract_feature(model, frame, return_h_w=False):
    """Extract one frame feature everytime."""
    out = model.get_intermediate_layers(frame.unsqueeze(0).cuda(), n=1)[0]
    out = out[:, 1:, :]  # we discard the [CLS] token
    h, w = int(frame.shape[1] / model.patch_embed.patch_size), int(frame.shape[2] / model.patch_embed.patch_size)
    dim = out.shape[-1]
    out = out[0].reshape(h, w, dim)
    out = out.reshape(-1, dim)
    if return_h_w:
        return out, h, w
    return out


def read_frame(frame_dir, scale_size=[480]):
    """
    read a single frame & preprocess
    """
    img = cv2.imread(frame_dir)
    ori_h, ori_w, _ = img.shape
    if len(scale_size) == 1:
        if (ori_h > ori_w):
            tw = scale_size[0]
            th = (tw * ori_h) / ori_w
            th = int((th // 64) * 64)
        else:
            th = scale_size[0]
            tw = (th * ori_w) / ori_h
            tw = int((tw // 64) * 64)
    else:
        th, tw = scale_size
    img = cv2.resize(img, (tw, th))
    img = img.astype(np.float32)
    img = img / 255.0
    img = img[:, :, ::-1]
    img = np.transpose(img.copy(), (2, 0, 1))
    img = torch.from_numpy(img).float()
    img = color_normalize(img)
    return img, ori_h, ori_w


def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
    for t, m, s in zip(x, mean, std):
        t.sub_(m)
        t.div_(s)
    return x


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Evaluation with video object segmentation on DAVIS 2017')
    parser.add_argument('--pretrained_weights', default='.', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--arch', default='vit_small', type=str,
        choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--output_dir', default=".", help='Path where to save segmentations')
    parser.add_argument('--data_path', default='/path/to/davis/', type=str)
    parser.add_argument("--n_last_frames", type=int, default=7, help="number of preceeding frames")
    parser.add_argument("--size_mask_neighborhood", default=12, type=int,
        help="We restrict the set of source nodes considered to a spatial neighborhood of the query node")
    parser.add_argument("--topk", type=int, default=5, help="accumulate label from top k neighbors")
    parser.add_argument("--bs", type=int, default=6, help="Batch size, try to reduce if OOM")
    parser.add_argument('--data_dir', type=str, default='', help='dataset dir')
    args = parser.parse_args()

    print("git:\n{}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))

    # building network
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
    print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    model.cuda()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    scene_dir = args.data_dir
    frame_list = sorted(glob.glob(os.path.join(scene_dir, 'color', '*')))
    save_dir = os.path.join(scene_dir, 'features', 'dino')
    print('computing dino features for {}...'.format(scene_dir))
    os.makedirs(save_dir, exist_ok=True)

    for frame_path in tqdm(frame_list):
        frame, ori_h, ori_w = read_frame(frame_path)
        frame_feat, h, w = extract_feature(model, frame, return_h_w=True)  # dim x h*w
        frame_feat = frame_feat.reshape(h, w, -1)
        frame_feat = frame_feat.cpu().numpy()

        frame_name = os.path.basename(frame_path)
        np.save(os.path.join(save_dir, frame_name + '.npy'), frame_feat)

    print('computing dino features for {} is done\n'.format(scene_dir))
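The script above writes one .npy file per input frame under <data_dir>/features/dino/, each holding an h x w x dim grid of patch features (h and w are the resized frame dimensions divided by the patch size). A minimal sketch of how a downstream step might read one of these files back; the scene path and frame name 00000.png are hypothetical placeholders, not taken from the commit:

import os
import numpy as np

# Hypothetical layout for illustration: inputs live in <data_dir>/color/<frame>,
# outputs in <data_dir>/features/dino/<frame>.npy (as written by the script above).
scene_dir = "/path/to/scene"
feat_path = os.path.join(scene_dir, "features", "dino", "00000.png.npy")

feat = np.load(feat_path)        # shape: (h, w, dim), float32
h, w, dim = feat.shape
flat = feat.reshape(-1, dim)     # one row per patch, matching extract_feature's flattened output
print(f"{h}x{w} patch grid, {dim}-d DINO features")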
preprocessing/dino/hubconf.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.models.resnet import resnet50

import vision_transformer as vits

dependencies = ["torch", "torchvision"]


def dino_vits16(pretrained=True, **kwargs):
    """
    ViT-Small/16x16 pre-trained with DINO.
    Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
    """
    model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model


def dino_vits8(pretrained=True, **kwargs):
    """
    ViT-Small/8x8 pre-trained with DINO.
    Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
    """
    model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model


def dino_vitb16(pretrained=True, **kwargs):
    """
    ViT-Base/16x16 pre-trained with DINO.
    Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
    """
    model = vits.__dict__["vit_base"](patch_size=16, num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model


def dino_vitb8(pretrained=True, **kwargs):
    """
    ViT-Base/8x8 pre-trained with DINO.
    Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
    """
    model = vits.__dict__["vit_base"](patch_size=8, num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model


def dino_resnet50(pretrained=True, **kwargs):
    """
    ResNet-50 pre-trained with DINO.
    Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
    """
    model = resnet50(pretrained=False, **kwargs)
    model.fc = torch.nn.Identity()
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=False)
    return model


def dino_xcit_small_12_p16(pretrained=True, **kwargs):
    """
    XCiT-Small-12/16 pre-trained with DINO.
    """
    model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p16", num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model


def dino_xcit_small_12_p8(pretrained=True, **kwargs):
    """
    XCiT-Small-12/8 pre-trained with DINO.
    """
    model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p8", num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model


def dino_xcit_medium_24_p16(pretrained=True, **kwargs):
    """
    XCiT-Medium-24/16 pre-trained with DINO.
    """
    model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p16", num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model


def dino_xcit_medium_24_p8(pretrained=True, **kwargs):
    """
    XCiT-Medium-24/8 pre-trained with DINO.
    """
    model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p8", num_classes=0, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
    return model
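hubconf.py makes these constructors available as torch.hub entry points, so a pretrained DINO backbone can be pulled without importing the repo by hand. A minimal usage sketch, assuming network access to the upstream facebookresearch/dino hub repo and the checkpoint URLs hard-coded above (swap the entry point name for dino_vits8, dino_vitb16, dino_vitb8, dino_resnet50, or one of the XCiT variants):

import torch

# Load a DINO ViT-S/16 backbone through torch.hub; the weights come from the
# dl.fbaipublicfiles.com URLs listed in hubconf.py.
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=True)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # [CLS] embedding; 384-d for ViT-Small
print(feats.shape)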
preprocessing/dino/main_dino.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import datetime
import time
import math
import json
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision import models as torchvision_models

import utils
import vision_transformer as vits
from vision_transformer import DINOHead

torchvision_archs = sorted(name for name in torchvision_models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(torchvision_models.__dict__[name]))


def get_args_parser():
    parser = argparse.ArgumentParser('DINO', add_help=False)

    # Model parameters
    parser.add_argument('--arch', default='vit_small', type=str,
        choices=['vit_tiny', 'vit_small', 'vit_base', 'xcit', 'deit_tiny', 'deit_small'] \
                + torchvision_archs + torch.hub.list("facebookresearch/xcit:main"),
        help="""Name of architecture to train. For quick experiments with ViTs,
        we recommend using vit_tiny or vit_small.""")
    parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
        of input square patches - default 16 (for 16x16 patches). Using smaller
        values leads to better performance but requires more memory. Applies only
        for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
        mixed precision training (--use_fp16 false) to avoid unstabilities.""")
    parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
        the DINO head output. For complex and large datasets large values (like 65k) work well.""")
    parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
        help="""Whether or not to weight normalize the last layer of the DINO head.
        Not normalizing leads to better performance but can make the training unstable.
        In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""")
    parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
        parameter for teacher update. The value is increased to 1 during training with cosine schedule.
        We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
    parser.add_argument('--use_bn_in_head', default=False, type=utils.bool_flag,
        help="Whether to use batch normalizations in projection head (Default: False)")

    # Temperature teacher parameters
    parser.add_argument('--warmup_teacher_temp', default=0.04, type=float,
        help="""Initial value for the teacher temperature: 0.04 works well in most cases.
        Try decreasing it if the training loss does not decrease.""")
    parser.add_argument('--teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
        of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
        starting with the default value of 0.04 and increase this slightly if needed.""")
    parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int,
        help='Number of warmup epochs for the teacher temperature (Default: 30).')

    # Training/Optimization parameters
    parser.add_argument('--use_fp16', type=utils.bool_flag, default=True, help="""Whether or not
        to use half precision for training. Improves training time and memory requirements,
        but can provoke instability and slight decay of performance. We recommend disabling
        mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs.""")
    parser.add_argument('--weight_decay', type=float, default=0.04, help="""Initial value of the
        weight decay. With ViT, a smaller value at the beginning of training works well.""")
    parser.add_argument('--weight_decay_end', type=float, default=0.4, help="""Final value of the
        weight decay. We use a cosine schedule for WD and using a larger decay by
        the end of training improves performance for ViTs.""")
    parser.add_argument('--clip_grad', type=float, default=3.0, help="""Maximal parameter
        gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can
        help optimization for larger ViT architectures. 0 for disabling.""")
    parser.add_argument('--batch_size_per_gpu', default=64, type=int,
        help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
    parser.add_argument('--epochs', default=100, type=int, help='Number of epochs of training.')
    parser.add_argument('--freeze_last_layer', default=1, type=int, help="""Number of epochs
        during which we keep the output layer fixed. Typically doing so during
        the first epoch helps training. Try increasing this value if the loss does not decrease.""")
    parser.add_argument("--lr", default=0.0005, type=float, help="""Learning rate at the end of
        linear warmup (highest LR used during training). The learning rate is linearly scaled
        with the batch size, and specified here for a reference batch size of 256.""")
    parser.add_argument("--warmup_epochs", default=10, type=int,
        help="Number of epochs for the linear learning-rate warm up.")
    parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
        end of optimization. We use a cosine LR schedule with linear warmup.""")
    parser.add_argument('--optimizer', default='adamw', type=str,
        choices=['adamw', 'sgd', 'lars'], help="""Type of optimizer. We recommend using adamw with ViTs.""")
    parser.add_argument('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")

    # Multi-crop parameters
    parser.add_argument('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we
        recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)""")
    parser.add_argument('--local_crops_number', type=int, default=8, help="""Number of small
        local views to generate. Set this parameter to 0 to disable multi-crop training.
        When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
    parser.add_argument('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
        help="""Scale range of the cropped image before resizing, relatively to the origin image.
        Used for small local view cropping of multi-crop.""")

    # Misc
    parser.add_argument('--data_path', default='/path/to/imagenet/train/', type=str,
        help='Please specify path to the ImageNet training data.')
    parser.add_argument('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
    parser.add_argument('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.')
    parser.add_argument('--seed', default=0, type=int, help='Random seed.')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    return parser


def train_dino(args):
    utils.init_distributed_mode(args)
    utils.fix_random_seeds(args.seed)
    print("git:\n{}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = DataAugmentationDINO(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
    )
    dataset = datasets.ImageFolder(args.data_path, transform=transform)
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    print(f"Data loaded: there are {len(dataset)} images.")

    # ============ building student and teacher networks ... ============
    # we changed the name DeiT-S for ViT-S to avoid confusions
    args.arch = args.arch.replace("deit", "vit")
    # if the network is a Vision Transformer (i.e. vit_tiny, vit_small, vit_base)
    if args.arch in vits.__dict__.keys():
        student = vits.__dict__[args.arch](
            patch_size=args.patch_size,
            drop_path_rate=args.drop_path_rate,  # stochastic depth
        )
        teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
        embed_dim = student.embed_dim
    # if the network is a XCiT
    elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
        student = torch.hub.load('facebookresearch/xcit:main', args.arch,
                                 pretrained=False, drop_path_rate=args.drop_path_rate)
        teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False)
        embed_dim = student.embed_dim
    # otherwise, we check if the architecture is in torchvision models
    elif args.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[args.arch]()
        teacher = torchvision_models.__dict__[args.arch]()
        embed_dim = student.fc.weight.shape[1]
    else:
        print(f"Unknow architecture: {args.arch}")

    # multi-crop wrapper handles forward with inputs of different resolutions
    student = utils.MultiCropWrapper(student, DINOHead(
        embed_dim,
        args.out_dim,
        use_bn=args.use_bn_in_head,
        norm_last_layer=args.norm_last_layer,
    ))
    teacher = utils.MultiCropWrapper(
        teacher,
        DINOHead(embed_dim, args.out_dim, args.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)

        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[args.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[args.gpu])
    # teacher and student start with the same weights
    teacher_without_ddp.load_state_dict(student.module.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {args.arch} network.")

    # ============ preparing loss ... ============
    dino_loss = DINOLoss(
        args.out_dim,
        args.local_crops_number + 2,  # total number of crops = 2 global crops + local_crops_number
        args.warmup_teacher_temp,
        args.teacher_temp,
        args.warmup_teacher_temp_epochs,
        args.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif args.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif args.optimizer == "lars":
        optimizer = utils.LARS(params_groups)  # to use with convnet and large batches
    # for mixed precision training
    fp16_scaler = None
    if args.use_fp16:
        fp16_scaler = torch.cuda.amp.GradScaler()

    # ============ init schedulers ... ============
    lr_schedule = utils.cosine_scheduler(
        args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256.,  # linear scaling rule
        args.min_lr,
        args.epochs, len(data_loader),
        warmup_epochs=args.warmup_epochs,
    )
    wd_schedule = utils.cosine_scheduler(
        args.weight_decay,
        args.weight_decay_end,
        args.epochs, len(data_loader),
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_scheduler(args.momentum_teacher, 1,
                                               args.epochs, len(data_loader))
    print(f"Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0}
    utils.restart_from_checkpoint(
        os.path.join(args.output_dir, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    start_epoch = to_restore["epoch"]

    start_time = time.time()
    print("Starting DINO training !")
    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)

        # ============ training one epoch of DINO ... ============
        train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
            data_loader, optimizer, lr_schedule, wd_schedule, momentum_schedule,
            epoch, fp16_scaler, args)

        # ============ writing logs ... ============
        save_dict = {
            'student': student.state_dict(),
            'teacher': teacher.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch + 1,
            'args': args,
            'dino_loss': dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            save_dict['fp16_scaler'] = fp16_scaler.state_dict()
        utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
        if args.saveckp_freq and epoch % args.saveckp_freq == 0:
            utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{epoch:04}.pth'))
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch}
        if utils.is_main_process():
            with (Path(args.output_dir) / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


def train_one_epoch(student, teacher, teacher_without_ddp, dino_loss, data_loader,
                    optimizer, lr_schedule, wd_schedule, momentum_schedule, epoch,
                    fp16_scaler, args):
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)
    for it, (images, _) in enumerate(metric_logger.log_every(data_loader, 10, header)):
        # update weight decay and learning rate according to their schedule
        it = len(data_loader) * epoch + it  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        # move images to gpu
        images = [im.cuda(non_blocking=True) for im in images]
        # teacher and student forward passes + compute dino loss
        with torch.cuda.amp.autocast(fp16_scaler is not None):
            teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
            student_output = student(images)
            loss = dino_loss(student_output, teacher_output, epoch)

        if not math.isfinite(loss.item()):
            print("Loss is {}, stopping training".format(loss.item()), force=True)
            sys.exit(1)

        # student update
        optimizer.zero_grad()
        param_norms = None
        if fp16_scaler is None:
            loss.backward()
            if args.clip_grad:
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            if args.clip_grad:
                fp16_scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                param_norms = utils.clip_gradients(student, args.clip_grad)
            utils.cancel_gradients_last_layer(epoch, student, args.freeze_last_layer)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

        # logging
        torch.cuda.synchronize()
        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


class DINOLoss(nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training instable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        dist.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * dist.get_world_size())

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


class DataAugmentationDINO(object):
    def __init__(self, global_crops_scale, local_crops_scale, local_crops_number):
        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # first global crop
        self.global_transfo1 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(1.0),
            normalize,
        ])
        # second global crop
        self.global_transfo2 = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=global_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(0.1),
            utils.Solarization(0.2),
            normalize,
        ])
        # transformation for the local small crops
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(96, scale=local_crops_scale, interpolation=Image.BICUBIC),
            flip_and_color_jitter,
            utils.GaussianBlur(p=0.5),
            normalize,
        ])

    def __call__(self, image):
        crops = []
        crops.append(self.global_transfo1(image))
        crops.append(self.global_transfo2(image))
        for _ in range(self.local_crops_number):
            crops.append(self.local_transfo(image))
        return crops


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DINO', parents=[get_args_parser()])
    args = parser.parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    train_dino(args)
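The heart of main_dino.py is DINOLoss.forward together with the teacher EMA step in train_one_epoch: the teacher's centered, temperature-sharpened softmax over the two global views supervises the student's log-softmax over every crop (same-view pairs are skipped), and the teacher parameters then track the student by exponential moving average. A standalone sketch of those two updates on random tensors, single process and without the distributed centering all_reduce, for illustration only; all sizes below are toy values, not the training defaults:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_dim, batch, ncrops = 16, 4, 4           # toy sizes; a real run uses a 65536-d head and 2+8 crops
student_temp, teacher_temp, m = 0.1, 0.04, 0.996

# teacher sees only the 2 global views, student sees every crop
teacher_output = torch.randn(2 * batch, out_dim)
student_output = torch.randn(ncrops * batch, out_dim)

teacher_out = F.softmax(teacher_output / teacher_temp, dim=-1).chunk(2)
student_out = (student_output / student_temp).chunk(ncrops)

total_loss, n_loss_terms = 0.0, 0
for iq, q in enumerate(teacher_out):
    for v in range(len(student_out)):
        if v == iq:
            continue                         # never compare a view with itself
        total_loss += torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1).mean()
        n_loss_terms += 1
print("toy DINO loss:", (total_loss / n_loss_terms).item())

# EMA teacher update: the same rule train_one_epoch applies parameter by parameter
student_param = torch.randn(8)
teacher_param = torch.randn(8)
teacher_param.mul_(m).add_((1 - m) * student_param)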
preprocessing/dino/run_with_submitit.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A script to run multinode training with submitit.
Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
"""
import argparse
import os
import uuid
from pathlib import Path

import main_dino
import submitit


def parse_args():
    parser = argparse.ArgumentParser("Submitit for DINO", parents=[main_dino.get_args_parser()])
    parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
    parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
    parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
    parser.add_argument("--partition", default="learnfair", type=str, help="Partition where to submit")
    parser.add_argument("--use_volta32", action='store_true', help="Big models? Use this")
    parser.add_argument('--comment', default="", type=str,
                        help='Comment to pass to scheduler, e.g. priority message')
    return parser.parse_args()


def get_shared_folder() -> Path:
    user = os.getenv("USER")
    if Path("/checkpoint/").is_dir():
        p = Path(f"/checkpoint/{user}/experiments")
        p.mkdir(exist_ok=True)
        return p
    raise RuntimeError("No shared folder available")


def get_init_file():
    # Init file must not exist, but it's parent dir must exist.
    os.makedirs(str(get_shared_folder()), exist_ok=True)
    init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
    if init_file.exists():
        os.remove(str(init_file))
    return init_file


class Trainer(object):
    def __init__(self, args):
        self.args = args

    def __call__(self):
        import main_dino

        self._setup_gpu_args()
        main_dino.train_dino(self.args)

    def checkpoint(self):
        import os
        import submitit

        self.args.dist_url = get_init_file().as_uri()
        print("Requeuing ", self.args)
        empty_trainer = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(empty_trainer)

    def _setup_gpu_args(self):
        import submitit
        from pathlib import Path

        job_env = submitit.JobEnvironment()
        self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id)))
        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
    args = parse_args()
    if args.output_dir == "":
        args.output_dir = get_shared_folder() / "%j"
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)

    num_gpus_per_node = args.ngpus
    nodes = args.nodes
    timeout_min = args.timeout

    partition = args.partition
    kwargs = {}
    if args.use_volta32:
        kwargs['slurm_constraint'] = 'volta32gb'
    if args.comment:
        kwargs['slurm_comment'] = args.comment

    executor.update_parameters(
        mem_gb=40 * num_gpus_per_node,
        gpus_per_node=num_gpus_per_node,
        tasks_per_node=num_gpus_per_node,  # one task per GPU
        cpus_per_task=10,
        nodes=nodes,
        timeout_min=timeout_min,  # max is 60 * 72
        # Below are cluster dependent parameters
        slurm_partition=partition,
        slurm_signal_delay_s=120,
        **kwargs
    )

    executor.update_parameters(name="dino")

    args.dist_url = get_init_file().as_uri()

    trainer = Trainer(args)
    job = executor.submit(trainer)

    print(f"Submitted job_id: {job.job_id}")
    print(f"Logs and checkpoints will be saved at: {args.output_dir}")


if __name__ == "__main__":
    main()
preprocessing/dino/utils.py (new file, mode 100644; diff truncated after compute_map in this view)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Misc functions.

Mostly copy-paste from torchvision references or other public repos like DETR:
https://github.com/facebookresearch/detr/blob/master/util/misc.py
"""
import os
import sys
import time
import math
import random
import datetime
import subprocess
import argparse   # needed by bool_flag (not in the imports shown in the diff)
import warnings   # needed by _no_grad_trunc_normal_ (not in the imports shown in the diff)
from collections import defaultdict, deque

import numpy as np
import torch
from torch import nn
import torch.distributed as dist
from PIL import ImageFilter, ImageOps


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.
    """
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        return img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )


class Solarization(object):
    """
    Apply Solarization to the PIL image.
    """
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
    else:
        print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
        url = None
        if model_name == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif model_name == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        elif model_name == "xcit_small_12_p16":
            url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
        elif model_name == "xcit_small_12_p8":
            url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
        elif model_name == "xcit_medium_24_p16":
            url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
        elif model_name == "xcit_medium_24_p8":
            url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
        elif model_name == "resnet50":
            url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
        if url is not None:
            print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            model.load_state_dict(state_dict, strict=True)
        else:
            print("There is no reference weights available for this model => We use random weights.")


def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
    url = None
    if model_name == "vit_small" and patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
    elif model_name == "vit_small" and patch_size == 8:
        url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
    elif model_name == "vit_base" and patch_size == 16:
        url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
    elif model_name == "vit_base" and patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
    elif model_name == "resnet50":
        url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
    if url is not None:
        print("We load the reference pretrained linear weights.")
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
        linear_classifier.load_state_dict(state_dict, strict=True)
    else:
        print("We use random linear weights.")


def clip_gradients(model, clip):
    norms = []
    for name, p in model.named_parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            norms.append(param_norm.item())
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                p.grad.data.mul_(clip_coef)
    return norms


def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None


def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start from checkpoint
    """
    if not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    # open checkpoint file
    checkpoint = torch.load(ckp_path, map_location="cpu")

    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
                print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
            except TypeError:
                try:
                    msg = value.load_state_dict(checkpoint[key])
                    print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
                except ValueError:
                    print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
        else:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))

    # re load variable important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule


def bool_flag(s):
    """
    Parse boolean arguments from the command line.
    """
    FALSY_STRINGS = {"off", "false", "0"}
    TRUTHY_STRINGS = {"on", "true", "1"}
    if s.lower() in FALSY_STRINGS:
        return False
    elif s.lower() in TRUTHY_STRINGS:
        return True
    else:
        raise argparse.ArgumentTypeError("invalid value for a boolean flag")


def fix_random_seeds(seed=31):
    """
    Fix random seeds.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.6f} ({global_avg:.6f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.6f}')
        data_time = SmoothedValue(fmt='{avg:.6f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommited changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def init_distributed_mode(args):
    # launched with torch.distributed.launch
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    # launched naively with `python main_dino.py`
    # we manually add MASTER_ADDR and MASTER_PORT to env variables
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
    else:
        print('Does not support training without GPU.')
        sys.exit(1)

    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim != 1:
                    dp = dp.add(p, alpha=g['weight_decay'])

                if p.ndim != 1:
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])


class MultiCropWrapper(nn.Module):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.
    """
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)
        start_idx, output = 0, torch.empty(0).to(x[0].device)
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx: end_idx]))
            # The output is a tuple with XCiT model. See:
            # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
            if isinstance(_out, tuple):
                _out = _out[0]
            # accumulate outputs
            output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        return self.head(output)


def get_params_groups(model):
    regularized = []
    not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # we do not regularize biases nor Norm parameters
        if name.endswith(".bias") or len(param.shape) == 1:
            not_regularized.append(param)
        else:
            regularized.append(param)
    return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]


def has_batchnorms(model):
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for name, module in model.named_modules():
        if isinstance(module, bn_types):
            return True
    return False


class PCA():
    """
    Class to compute and apply PCA.
    """
    def __init__(self, dim=256, whit=0.5):
        self.dim = dim
        self.whit = whit
        self.mean = None

    def train_pca(self, cov):
        """
        Takes a covariance matrix (np.ndarray) as input.
        """
        d, v = np.linalg.eigh(cov)
        eps = d.max() * 1e-5
        n_0 = (d < eps).sum()
        if n_0 > 0:
            d[d < eps] = eps

        # total energy
        totenergy = d.sum()

        # sort eigenvectors with eigenvalues order
        idx = np.argsort(d)[::-1][:self.dim]
        d = d[idx]
        v = v[:, idx]

        print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))

        # for the whitening
        d = np.diag(1. / d ** self.whit)

        # principal components
        self.dvt = np.dot(d, v.T)

    def apply(self, x):
        # input is from numpy
        if isinstance(x, np.ndarray):
            if self.mean is not None:
                x -= self.mean
            return np.dot(self.dvt, x.T).T

        # input is from torch and is on GPU
        if x.is_cuda:
            if self.mean is not None:
                x -= torch.cuda.FloatTensor(self.mean)
            return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)

        # input if from torch, on CPU
        if self.mean is not None:
            x -= torch.FloatTensor(self.mean)
        return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)


def compute_ap(ranks, nres):
    """
    Computes average precision for given ranked indexes.
    Arguments
    ---------
    ranks : zerro-based ranks of positive images
    nres  : number of positive images
    Returns
    -------
    ap    : average precision
    """
    # number of images ranked by the system
    nimgranks = len(ranks)

    # accumulate trapezoids in PR-plot
    ap = 0

    recall_step = 1. / nres

    for j in np.arange(nimgranks):
        rank = ranks[j]

        if rank == 0:
            precision_0 = 1.
        else:
            precision_0 = float(j) / rank

        precision_1 = float(j + 1) / (rank + 1)

        ap += (precision_0 + precision_1) * recall_step / 2.

    return ap


def compute_map(ranks, gnd, kappas=[]):
    """
    Computes the mAP for a given set of returned results.
    Usage:
        map = compute_map (ranks, gnd)
              computes mean average precsion (map) only
        map, aps, pr, prs = compute_map (ranks, gnd, kappas)
              computes mean average precision (map), average precision (aps) for each query
              computes mean precision at kappas (pr), precision at kappas (prs) for each query
Notes:
1) ranks starts from 0, ranks.shape = db_size X #queries
2) The junk results (e.g., the query itself) should be declared in the gnd stuct array
3) If there are no positive images for some query, that query is excluded from the evaluation
"""
map
=
0.
nq
=
len
(
gnd
)
# number of queries
aps
=
np
.
zeros
(
nq
)
pr
=
np
.
zeros
(
len
(
kappas
))
prs
=
np
.
zeros
((
nq
,
len
(
kappas
)))
nempty
=
0
for
i
in
np
.
arange
(
nq
):
qgnd
=
np
.
array
(
gnd
[
i
][
'ok'
])
# no positive images, skip from the average
if
qgnd
.
shape
[
0
]
==
0
:
aps
[
i
]
=
float
(
'nan'
)
prs
[
i
,
:]
=
float
(
'nan'
)
nempty
+=
1
continue
try
:
qgndj
=
np
.
array
(
gnd
[
i
][
'junk'
])
except
:
qgndj
=
np
.
empty
(
0
)
# sorted positions of positive and junk images (0 based)
pos
=
np
.
arange
(
ranks
.
shape
[
0
])[
np
.
in1d
(
ranks
[:,
i
],
qgnd
)]
junk
=
np
.
arange
(
ranks
.
shape
[
0
])[
np
.
in1d
(
ranks
[:,
i
],
qgndj
)]
k
=
0
;
ij
=
0
;
if
len
(
junk
):
# decrease positions of positives based on the number of
# junk images appearing before them
ip
=
0
while
(
ip
<
len
(
pos
)):
while
(
ij
<
len
(
junk
)
and
pos
[
ip
]
>
junk
[
ij
]):
k
+=
1
ij
+=
1
pos
[
ip
]
=
pos
[
ip
]
-
k
ip
+=
1
# compute ap
ap
=
compute_ap
(
pos
,
len
(
qgnd
))
map
=
map
+
ap
aps
[
i
]
=
ap
# compute precision @ k
pos
+=
1
# get it to 1-based
for
j
in
np
.
arange
(
len
(
kappas
)):
kq
=
min
(
max
(
pos
),
kappas
[
j
]);
prs
[
i
,
j
]
=
(
pos
<=
kq
).
sum
()
/
kq
pr
=
pr
+
prs
[
i
,
:]
map
=
map
/
(
nq
-
nempty
)
pr
=
pr
/
(
nq
-
nempty
)
return
map
,
aps
,
pr
,
prs
def
multi_scale
(
samples
,
model
):
v
=
None
for
s
in
[
1
,
1
/
2
**
(
1
/
2
),
1
/
2
]:
# we use 3 different scales
if
s
==
1
:
inp
=
samples
.
clone
()
else
:
inp
=
nn
.
functional
.
interpolate
(
samples
,
scale_factor
=
s
,
mode
=
'bilinear'
,
align_corners
=
False
)
feats
=
model
(
inp
).
clone
()
if
v
is
None
:
v
=
feats
else
:
v
+=
feats
v
/=
3
v
/=
v
.
norm
()
return
v
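A minimal usage sketch of the two helpers above, assuming it runs alongside the definitions in this utils.py; the ResNet-18 backbone and the head dimensions are illustrative choices only, not values fixed by this file.

import torch
import torch.nn as nn
import torchvision

layer = nn.Linear(16, 8)
trunc_normal_(layer.weight, std=.02)      # init clipped to mean +/- 2 std, as the ViT code below uses it

backbone = torchvision.models.resnet18()  # illustrative backbone; its .fc / .head are swapped for Identity
head = nn.Linear(512, 4096)               # illustrative projection head
student = MultiCropWrapper(backbone, head)

# two global 224px crops and four local 96px crops, batch of 2:
# one backbone forward per distinct resolution, then a single head forward
crops = [torch.randn(2, 3, 224, 224)] * 2 + [torch.randn(2, 3, 96, 96)] * 4
out = student(crops)                      # shape (12, 4096): 6 crops x batch of 2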
preprocessing/dino/video_generation.py
0 → 100644
View file @
3d92aebb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import glob
import sys
import argparse

import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
import vision_transformer as vits

FOURCC = {
    "mp4": cv2.VideoWriter_fourcc(*"MP4V"),
    "avi": cv2.VideoWriter_fourcc(*"XVID"),
}
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class VideoGenerator:
    def __init__(self, args):
        self.args = args
        # self.model = None
        # Don't need to load model if you only want a video
        if not self.args.video_only:
            self.model = self.__load_model()

    def run(self):
        if self.args.input_path is None:
            print(f"Provided input path {self.args.input_path} is non valid.")
            sys.exit(1)
        else:
            if self.args.video_only:
                self._generate_video_from_images(self.args.input_path, self.args.output_path)
            else:
                # If input path exists
                if os.path.exists(self.args.input_path):
                    # If input is a video file
                    if os.path.isfile(self.args.input_path):
                        frames_folder = os.path.join(self.args.output_path, "frames")
                        attention_folder = os.path.join(self.args.output_path, "attention")

                        os.makedirs(frames_folder, exist_ok=True)
                        os.makedirs(attention_folder, exist_ok=True)

                        self._extract_frames_from_video(self.args.input_path, frames_folder)

                        self._inference(frames_folder, attention_folder)

                        self._generate_video_from_images(attention_folder, self.args.output_path)

                    # If input is a folder of already extracted frames
                    if os.path.isdir(self.args.input_path):
                        attention_folder = os.path.join(self.args.output_path, "attention")

                        os.makedirs(attention_folder, exist_ok=True)

                        self._inference(self.args.input_path, attention_folder)

                        self._generate_video_from_images(attention_folder, self.args.output_path)

                # If input path doesn't exists
                else:
                    print(f"Provided input path {self.args.input_path} doesn't exists.")
                    sys.exit(1)

    def _extract_frames_from_video(self, inp: str, out: str):
        vidcap = cv2.VideoCapture(inp)
        self.args.fps = vidcap.get(cv2.CAP_PROP_FPS)

        print(f"Video: {inp} ({self.args.fps} fps)")
        print(f"Extracting frames to {out}")

        success, image = vidcap.read()
        count = 0
        while success:
            cv2.imwrite(
                os.path.join(out, f"frame-{count:04}.jpg"),
                image,
            )
            success, image = vidcap.read()
            count += 1

    def _generate_video_from_images(self, inp: str, out: str):
        img_array = []
        attention_images_list = sorted(glob.glob(os.path.join(inp, "attn-*.jpg")))

        # Get size of the first image
        with open(attention_images_list[0], "rb") as f:
            img = Image.open(f)
            img = img.convert("RGB")
            size = (img.width, img.height)
            img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

        print(f"Generating video {size} to {out}")

        for filename in tqdm(attention_images_list[1:]):
            with open(filename, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")
                img_array.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))

        out = cv2.VideoWriter(
            os.path.join(out, "video." + self.args.video_format),
            FOURCC[self.args.video_format],
            self.args.fps,
            size,
        )

        for i in range(len(img_array)):
            out.write(img_array[i])
        out.release()
        print("Done")

    def _inference(self, inp: str, out: str):
        print(f"Generating attention images to {out}")

        for img_path in tqdm(sorted(glob.glob(os.path.join(inp, "*.jpg")))):
            with open(img_path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")

            if self.args.resize is not None:
                transform = pth_transforms.Compose([
                    pth_transforms.ToTensor(),
                    pth_transforms.Resize(self.args.resize),
                    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ])
            else:
                transform = pth_transforms.Compose([
                    pth_transforms.ToTensor(),
                    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ])

            img = transform(img)

            # make the image divisible by the patch size
            w, h = (
                img.shape[1] - img.shape[1] % self.args.patch_size,
                img.shape[2] - img.shape[2] % self.args.patch_size,
            )
            img = img[:, :w, :h].unsqueeze(0)

            w_featmap = img.shape[-2] // self.args.patch_size
            h_featmap = img.shape[-1] // self.args.patch_size

            attentions = self.model.get_last_selfattention(img.to(DEVICE))

            nh = attentions.shape[1]  # number of head

            # we keep only the output patch attention
            attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

            # we keep only a certain percentage of the mass
            val, idx = torch.sort(attentions)
            val /= torch.sum(val, dim=1, keepdim=True)
            cumval = torch.cumsum(val, dim=1)
            th_attn = cumval > (1 - self.args.threshold)
            idx2 = torch.argsort(idx)
            for head in range(nh):
                th_attn[head] = th_attn[head][idx2[head]]
            th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
            # interpolate
            th_attn = nn.functional.interpolate(
                th_attn.unsqueeze(0), scale_factor=self.args.patch_size, mode="nearest",
            )[0].cpu().numpy()

            attentions = attentions.reshape(nh, w_featmap, h_featmap)
            attentions = nn.functional.interpolate(
                attentions.unsqueeze(0), scale_factor=self.args.patch_size, mode="nearest",
            )[0].cpu().numpy()

            # save attentions heatmaps
            fname = os.path.join(out, "attn-" + os.path.basename(img_path))
            plt.imsave(
                fname=fname,
                arr=sum(attentions[i] * 1 / attentions.shape[0] for i in range(attentions.shape[0])),
                cmap="inferno",
                format="jpg",
            )

    def __load_model(self):
        # build model
        model = vits.__dict__[self.args.arch](patch_size=self.args.patch_size, num_classes=0)
        for p in model.parameters():
            p.requires_grad = False
        model.eval()
        model.to(DEVICE)

        if os.path.isfile(self.args.pretrained_weights):
            state_dict = torch.load(self.args.pretrained_weights, map_location="cpu")
            if (
                self.args.checkpoint_key is not None
                and self.args.checkpoint_key in state_dict
            ):
                print(f"Take key {self.args.checkpoint_key} in provided checkpoint dict")
                state_dict = state_dict[self.args.checkpoint_key]
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            # remove `backbone.` prefix induced by multicrop wrapper
            state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
            msg = model.load_state_dict(state_dict, strict=False)
            print(
                "Pretrained weights found at {} and loaded with msg: {}".format(
                    self.args.pretrained_weights, msg
                )
            )
        else:
            print(
                "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
            )
            url = None
            if self.args.arch == "vit_small" and self.args.patch_size == 16:
                url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
            elif self.args.arch == "vit_small" and self.args.patch_size == 8:
                # model used for visualizations in our paper
                url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
            elif self.args.arch == "vit_base" and self.args.patch_size == 16:
                url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
            elif self.args.arch == "vit_base" and self.args.patch_size == 8:
                url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
            if url is not None:
                print(
                    "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
                )
                state_dict = torch.hub.load_state_dict_from_url(
                    url="https://dl.fbaipublicfiles.com/dino/" + url
                )
                model.load_state_dict(state_dict, strict=True)
            else:
                print(
                    "There is no reference weights available for this model => We use random weights."
                )
        return model

def parse_args():
    parser = argparse.ArgumentParser("Generation self-attention video")
    parser.add_argument(
        "--arch",
        default="vit_small",
        type=str,
        choices=["vit_tiny", "vit_small", "vit_base"],
        help="Architecture (support only ViT atm).",
    )
    parser.add_argument(
        "--patch_size", default=8, type=int, help="Patch resolution of the self.model."
    )
    parser.add_argument(
        "--pretrained_weights",
        default="",
        type=str,
        help="Path to pretrained weights to load.",
    )
    parser.add_argument(
        "--checkpoint_key",
        default="teacher",
        type=str,
        help='Key to use in the checkpoint (example: "teacher")',
    )
    parser.add_argument(
        "--input_path",
        required=True,
        type=str,
        help="""Path to a video file if you want to extract frames
            or to a folder of images already extracted by yourself.
            or to a folder of attention images.""",
    )
    parser.add_argument(
        "--output_path",
        default="./",
        type=str,
        help="""Path to store a folder of frames and / or a folder of attention images.
            and / or a final video. Default to current directory.""",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.6,
        help="""We visualize masks
        obtained by thresholding the self-attention maps to keep xx percent of the mass.""",
    )
    parser.add_argument(
        "--resize",
        default=None,
        type=int,
        nargs="+",
        help="""Apply a resize transformation to input image(s). Use if OOM error.
        Usage (single or W H): --resize 512, --resize 720 1280""",
    )
    parser.add_argument(
        "--video_only",
        action="store_true",
        help="""Use this flag if you only want to generate a video and not all attention images.
            If used, --input_path must be set to the folder of attention images. Ex: ./attention/""",
    )
    parser.add_argument(
        "--fps",
        default=30.0,
        type=float,
        help="FPS of input / output video. Automatically set if you extract frames from a video.",
    )
    parser.add_argument(
        "--video_format",
        default="mp4",
        type=str,
        choices=["mp4", "avi"],
        help="Format of generated video (mp4 or avi).",
    )

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    vg = VideoGenerator(args)
    vg.run()
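The script is normally launched through parse_args(); as a sketch, the class can also be driven directly by building an equivalent namespace. The paths below are placeholders, and the attribute names simply mirror the CLI options defined above.

from argparse import Namespace

# Illustrative only: "my_video.mp4" and "./out" are placeholder paths.
args = Namespace(
    arch="vit_small", patch_size=8, pretrained_weights="", checkpoint_key="teacher",
    input_path="my_video.mp4", output_path="./out", threshold=0.6, resize=None,
    video_only=False, fps=30.0, video_format="mp4",
)
VideoGenerator(args).run()  # extract frames, write per-frame attention maps, then assemble a video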
preprocessing/dino/vision_transformer.py
0 → 100644
View file @
3d92aebb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mostly copy-paste from timm library.
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import math
from functools import partial

import torch
import torch.nn as nn

from utils import trunc_normal_

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class VisionTransformer(nn.Module):
    """ Vision Transformer """
    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output

def vit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

def vit_small(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
                 bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
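A minimal sketch of how these models are queried by the preprocessing scripts, assuming random weights and a dummy input; the shapes noted assume vit_small with patch size 8 on a 224x224 image.

import torch

model = vit_small(patch_size=8)     # random weights, enough to check shapes
img = torch.randn(1, 3, 224, 224)   # dummy image whose sides are divisible by the patch size

cls_feat = model(img)                                 # (1, 384) [CLS] feature
tokens = model.get_intermediate_layers(img, n=1)[0]   # (1, 1 + 28*28, 384) tokens of the last block
attn = model.get_last_selfattention(img)              # (1, 6, 785, 785) last-block attention
print(cls_feat.shape, tokens.shape, attn.shape)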
preprocessing/dino/visualize_attention.py
0 → 100644
View file @
3d92aebb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import cv2
import random
import colorsys
import requests
from io import BytesIO

import skimage.io
from skimage.measure import find_contours
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
import vision_transformer as vits

def apply_mask(image, mask, color, alpha=0.5):
    for c in range(3):
        image[:, :, c] = image[:, :, c] * (1 - alpha * mask) + alpha * mask * color[c] * 255
    return image

def random_colors(N, bright=True):
    """
    Generate random colors.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors

def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, contour=True, alpha=0.5):
    fig = plt.figure(figsize=figsize, frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax = plt.gca()

    N = 1
    mask = mask[None, :, :]
    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    margin = 0
    ax.set_ylim(height + margin, -margin)
    ax.set_xlim(-margin, width + margin)
    ax.axis('off')
    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]
        _mask = mask[i]
        if blur:
            _mask = cv2.blur(_mask, (10, 10))
        # Mask
        masked_image = apply_mask(masked_image, _mask, color, alpha)
        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        if contour:
            padded_mask = np.zeros((_mask.shape[0] + 2, _mask.shape[1] + 2))
            padded_mask[1:-1, 1:-1] = _mask
            contours = find_contours(padded_mask, 0.5)
            for verts in contours:
                # Subtract the padding and flip (y, x) to (x, y)
                verts = np.fliplr(verts) - 1
                p = Polygon(verts, facecolor="none", edgecolor=color)
                ax.add_patch(p)
    ax.imshow(masked_image.astype(np.uint8), aspect='auto')
    fig.savefig(fname)
    print(f"{fname} saved.")
    return

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Visualize Self-Attention maps')
    parser.add_argument('--arch', default='vit_small', type=str,
        choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
    parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
    parser.add_argument('--pretrained_weights', default='', type=str,
        help="Path to pretrained weights to load.")
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
    parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.")
    parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')
    parser.add_argument("--threshold", type=float, default=None, help="""We visualize masks
        obtained by thresholding the self-attention maps to keep xx% of the mass.""")
    args = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # build model
    model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.to(device)
    if os.path.isfile(args.pretrained_weights):
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
    else:
        print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
        url = None
        if args.arch == "vit_small" and args.patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif args.arch == "vit_small" and args.patch_size == 8:
            # model used for visualizations in our paper
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
        elif args.arch == "vit_base" and args.patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif args.arch == "vit_base" and args.patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        if url is not None:
            print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            model.load_state_dict(state_dict, strict=True)
        else:
            print("There is no reference weights available for this model => We use random weights.")

    # open image
    if args.image_path is None:
        # user has not specified any image - we use our own image
        print("Please use the `--image_path` argument to indicate the path of the image you wish to visualize.")
        print("Since no image path have been provided, we take the first image in our paper.")
        response = requests.get("https://dl.fbaipublicfiles.com/dino/img.png")
        img = Image.open(BytesIO(response.content))
        img = img.convert('RGB')
    elif os.path.isfile(args.image_path):
        with open(args.image_path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
    else:
        print(f"Provided image path {args.image_path} is non valid.")
        sys.exit(1)
    transform = pth_transforms.Compose([
        pth_transforms.Resize(args.image_size),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = transform(img)

    # make the image divisible by the patch size
    w, h = img.shape[1] - img.shape[1] % args.patch_size, img.shape[2] - img.shape[2] % args.patch_size
    img = img[:, :w, :h].unsqueeze(0)

    w_featmap = img.shape[-2] // args.patch_size
    h_featmap = img.shape[-1] // args.patch_size

    attentions = model.get_last_selfattention(img.to(device))

    nh = attentions.shape[1]  # number of head

    # we keep only the output patch attention
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    if args.threshold is not None:
        # we keep only a certain percentage of the mass
        val, idx = torch.sort(attentions)
        val /= torch.sum(val, dim=1, keepdim=True)
        cumval = torch.cumsum(val, dim=1)
        th_attn = cumval > (1 - args.threshold)
        idx2 = torch.argsort(idx)
        for head in range(nh):
            th_attn[head] = th_attn[head][idx2[head]]
        th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
        # interpolate
        th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

    # save attentions heatmaps
    os.makedirs(args.output_dir, exist_ok=True)
    torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(args.output_dir, "img.png"))
    for j in range(nh):
        fname = os.path.join(args.output_dir, "attn-head" + str(j) + ".png")
        plt.imsave(fname=fname, arr=attentions[j], format='png')
        print(f"{fname} saved.")

    if args.threshold is not None:
        image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
        for j in range(nh):
            display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) + ".png"), blur=False)
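The thresholding block above (shared with video_generation.py) keeps, per head, the smallest set of patches whose attention accounts for the requested fraction of the mass; a toy sketch with made-up numbers, assuming a single head over five patches.

import torch

attn = torch.tensor([[0.05, 0.50, 0.05, 0.30, 0.10]])  # one head, five patches, already normalized
threshold = 0.6

val, idx = torch.sort(attn)                   # ascending sort per head
val = val / torch.sum(val, dim=1, keepdim=True)
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - threshold)            # drop the low tail holding (1 - threshold) of the mass
idx2 = torch.argsort(idx)
mask = th_attn[0][idx2[0]]                    # undo the sort, back to the original patch order
print(mask)                                   # patches 1 (0.50) and 3 (0.30) are kept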