OpenDAS / Megatron-LM

Commit 01a82723, authored Jan 31, 2022 by Vijay Korthikanti

    vision third phase merge: pretraining methods + mit,swin backbones

Parent: 2b628f96

Showing 16 changed files with 3127 additions and 9 deletions.
megatron/arguments.py                           +49   -4
megatron/data/vit_dataset.py                   +193   -1
megatron/model/vision/classification.py         +34   -0
megatron/model/vision/dino.py                  +290   -0
megatron/model/vision/esvit_swin_backbone.py   +848   -0
megatron/model/vision/inpainting.py            +161   -0
megatron/model/vision/knn_monitor.py           +118   -0
megatron/model/vision/mit_backbone.py          +417   -0
megatron/model/vision/swin_backbone.py         +625   -0
megatron/model/vision/utils.py                  +88   -0
megatron/training.py                            +17   -1
pretrain_vision_classify.py                     +16   -3
pretrain_vision_dino.py                        +122   -0
pretrain_vision_inpaint.py                     +149   -0
tasks/vision/classification/classification.py    +0   -0
tasks/vision/classification/eval_utils.py        +0   -0
megatron/arguments.py

```diff
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
     parser = _add_biencoder_args(parser)
-    parser = _add_vit_args(parser)
+    parser = _add_vision_args(parser)
     parser = _add_logging_args(parser)
     parser = _add_inference_args(parser)
@@ -849,9 +849,10 @@ def _add_biencoder_args(parser):
     return parser
 
-def _add_vit_args(parser):
-    group = parser.add_argument_group(title="vit")
+def _add_vision_args(parser):
+    group = parser.add_argument_group(title="vision")
 
+    # general vision arguments
     group.add_argument('--num-classes', type=int, default=1000,
                        help='num of classes in vision classification task')
     group.add_argument('--img-h', type=int, default=224,
@@ -861,7 +862,7 @@ def _add_vit_args(parser):
     group.add_argument('--num-channels', type=int, default=3,
                        help='Number of channels in input image data')
     group.add_argument('--patch-dim', type=int, default=16,
-                       help='patch dimension used in vit')
+                       help='patch dimension')
     group.add_argument('--classes-fraction', type=float, default=1.0,
                        help='training with fraction of classes.')
     group.add_argument('--data-per-class-fraction', type=float, default=1.0,
@@ -869,5 +870,49 @@ def _add_vit_args(parser):
     group.add_argument('--no-data-sharding', action='store_false',
                        help='Disable data sharding.',
                        dest='data_sharding')
+    group.add_argument('--head-lr-mult', type=float, default=1.0,
+                       help='learning rate multiplier for head during finetuning')
+
+    # pretraining type and backbone selection
+    group.add_argument('--vision-pretraining-type', type=str, default='classify',
+                       choices=['classify', 'inpaint', 'contrast'],
+                       help='pretraining objectives')
+    group.add_argument('--vision-backbone-type', type=str, default='vit',
+                       choices=['vit', 'mit', 'swin'],
+                       help='backbone types')
+    group.add_argument('--swin-backbone-type', type=str, default='tiny',
+                       choices=['tiny', 'base', 'h3'],
+                       help='swin backbone type')
+
+    # inpainting arguments
+    group.add_argument('--mask-type', type=str, default='random',
+                       choices=['random', 'row'],
+                       help='mask types')
+    group.add_argument('--mask-factor', type=float, default=1.0,
+                       help='mask size scaling parameter')
+
+    # dino arguments
+    group.add_argument('--iter-per-epoch', type=int, default=1250,
+                       help='iterations per epoch')
+    group.add_argument('--dino-local-img-size', type=int, default=96,
+                       help='image size for dino local crops')
+    group.add_argument('--dino-local-crops-number', type=int, default=10,
+                       help='number of local crops')
+    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
+                       help='hidden dimension size in dino head')
+    group.add_argument('--dino-bottleneck-size', type=int, default=256,
+                       help='bottleneck dimension in dino head')
+    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
+                       help='number of epochs to freeze last layer weights')
+    group.add_argument('--dino-norm-last-layer', action='store_true',
+                       help='disable norm in last layer')
+    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
+                       help='warmup teacher temperature')
+    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
+                       help='teacher temperature')
+    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
+                       help='warmup teacher temperature epochs')
 
     return parser
```
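For readers skimming the new argument group, here is a minimal standalone sketch (not part of the commit) that exercises a few of the renamed `_add_vision_args` flags; the trimmed-down parser below is a stand-in for the full Megatron one.

```python
# Hypothetical usage sketch: a cut-down _add_vision_args with three of the
# new flags, just to show how the choices and defaults behave.
import argparse

def _add_vision_args(parser):
    group = parser.add_argument_group(title="vision")
    group.add_argument('--vision-pretraining-type', type=str, default='classify',
                       choices=['classify', 'inpaint', 'contrast'])
    group.add_argument('--vision-backbone-type', type=str, default='vit',
                       choices=['vit', 'mit', 'swin'])
    group.add_argument('--dino-local-crops-number', type=int, default=10)
    return parser

parser = _add_vision_args(argparse.ArgumentParser())
args = parser.parse_args(['--vision-pretraining-type', 'contrast',
                          '--vision-backbone-type', 'swin',
                          '--dino-local-crops-number', '8'])
assert args.vision_backbone_type == 'swin'
print(args.vision_pretraining_type, args.dino_local_crops_number)  # contrast 8
```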
megatron/data/vit_dataset.py

```diff
--- a/megatron/data/vit_dataset.py
+++ b/megatron/data/vit_dataset.py
@@ -22,6 +22,43 @@ from megatron import get_args
 from megatron.data.image_folder import ImageFolder
 from megatron.data.autoaugment import ImageNetPolicy
 from megatron.data.data_samplers import RandomSeedDataset
+from PIL import Image, ImageFilter, ImageOps
+
+
+class GaussianBlur(object):
+    """
+    Apply Gaussian Blur to the PIL image.
+    """
+    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
+        self.prob = p
+        self.radius_min = radius_min
+        self.radius_max = radius_max
+
+    def __call__(self, img):
+        do_it = random.random() <= self.prob
+        if not do_it:
+            return img
+
+        return img.filter(
+            ImageFilter.GaussianBlur(
+                radius=random.uniform(self.radius_min, self.radius_max)
+            )
+        )
+
+
+class Solarization(object):
+    """
+    Apply Solarization to the PIL image.
+    """
+    def __init__(self, p):
+        self.p = p
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return ImageOps.solarize(img)
+        else:
+            return img
+
 
 class ClassificationTransform():
     def __init__(self, image_size, train=True):
@@ -52,14 +89,169 @@ class ClassificationTransform():
         return output
 
 
+class InpaintingTransform():
+    def __init__(self, image_size, train=True):
+        args = get_args()
+        self.mask_factor = args.mask_factor
+        self.mask_type = args.mask_type
+        self.image_size = image_size
+        self.patch_size = args.patch_dim
+        self.mask_size = int(self.mask_factor *
+                             (image_size[0] / self.patch_size) *
+                             (image_size[1] / self.patch_size))
+        self.train = train
+        assert args.fp16 or args.bf16
+        self.data_type = torch.half if args.fp16 else torch.bfloat16
+
+        if self.train:
+            self.transform = T.Compose([
+                T.RandomResizedCrop(self.image_size),
+                T.RandomHorizontalFlip(),
+                T.ColorJitter(0.4, 0.4, 0.4, 0.1),
+                ImageNetPolicy(),
+                T.ToTensor(),
+                T.ConvertImageDtype(self.data_type)
+            ])
+        else:
+            self.transform = T.Compose([
+                T.Resize(self.image_size, interpolation=2),
+                T.CenterCrop(self.image_size),
+                T.ToTensor(),
+                T.ConvertImageDtype(self.data_type)
+            ])
+
+    def gen_mask(self, image_size, mask_size, mask_type, patch_size):
+        # output: mask as a list with indices for missing patches
+        action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
+        assert image_size[0] == image_size[1]
+        img_size_patch = image_size[0] // patch_size
+
+        # drop masked patches
+        mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)
+
+        if mask_type == 'random':
+            x = torch.randint(0, img_size_patch, ())
+            y = torch.randint(0, img_size_patch, ())
+            for i in range(mask_size):
+                r = torch.randint(0, len(action_list), ())
+                x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
+                y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
+                x_offset = x * patch_size
+                y_offset = y * patch_size
+                mask[x_offset:x_offset + patch_size,
+                     y_offset:y_offset + patch_size] = 1
+        else:
+            assert mask_type == 'row'
+            count = 0
+            for x in reversed(range(img_size_patch)):
+                for y in reversed(range(img_size_patch)):
+                    if count < mask_size:
+                        count += 1
+                        x_offset = x * patch_size
+                        y_offset = y * patch_size
+                        mask[x_offset:x_offset + patch_size,
+                             y_offset:y_offset + patch_size] = 1
+        return mask
+
+    def __call__(self, input):
+        trans_input = self.transform(input)
+        mask = self.gen_mask(self.image_size, self.mask_size,
+                             self.mask_type, self.patch_size)
+        mask = mask.unsqueeze(dim=0)
+        return trans_input, mask
+
+
+class DinoTransform(object):
+    def __init__(self, image_size, train=True):
+        args = get_args()
+        self.data_type = torch.half if args.fp16 else torch.bfloat16
+
+        flip_and_color_jitter = T.Compose([
+            T.RandomHorizontalFlip(p=0.5),
+            T.RandomApply(
+                [T.ColorJitter(brightness=0.4, contrast=0.4,
+                               saturation=0.2, hue=0.1)],
+                p=0.8),
+            T.RandomGrayscale(p=0.2),
+        ])
+
+        if args.fp16 or args.bf16:
+            normalize = T.Compose([
+                T.ToTensor(),
+                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+                T.ConvertImageDtype(self.data_type)
+            ])
+        else:
+            normalize = T.Compose([
+                T.ToTensor(),
+                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+            ])
+
+        # first global crop
+        scale_const = 0.4
+        self.global_transform1 = T.Compose([
+            T.RandomResizedCrop(image_size,
+                                scale=(scale_const, 1),
+                                interpolation=Image.BICUBIC),
+            flip_and_color_jitter,
+            GaussianBlur(1.0),
+            normalize
+        ])
+        # second global crop
+        self.global_transform2 = T.Compose([
+            T.RandomResizedCrop(image_size,
+                                scale=(scale_const, 1),
+                                interpolation=Image.BICUBIC),
+            flip_and_color_jitter,
+            GaussianBlur(0.1),
+            Solarization(0.2),
+            normalize
+        ])
+        # transformation for the local small crops
+        self.local_crops_number = args.dino_local_crops_number
+        self.local_transform = T.Compose([
+            T.RandomResizedCrop(args.dino_local_img_size,
+                                scale=(0.05, scale_const),
+                                interpolation=Image.BICUBIC),
+            flip_and_color_jitter,
+            GaussianBlur(p=0.5),
+            normalize
+        ])
+
+    def __call__(self, image):
+        crops = []
+        args = get_args()
+        if args.street_data:
+            crop_transform = T.RandomCrop(300)
+            image = crop_transform(image)
+        crops.append(self.global_transform1(image))
+        crops.append(self.global_transform2(image))
+        for _ in range(self.local_crops_number):
+            crops.append(self.local_transform(image))
+        return crops
+
+
 def build_train_valid_datasets(data_path, image_size=224):
     args = get_args()
 
+    if args.vision_pretraining_type == 'classify':
+        train_transform = ClassificationTransform(image_size)
+        val_transform = ClassificationTransform(image_size, train=False)
+    elif args.vision_pretraining_type == 'inpaint':
+        train_transform = InpaintingTransform(image_size, train=False)
+        val_transform = InpaintingTransform(image_size, train=False)
+    elif args.vision_pretraining_type == 'dino':
+        train_transform = DinoTransform(image_size, train=True)
+        val_transform = ClassificationTransform(image_size, train=False)
+    else:
+        raise Exception('{} vit pretraining type is not supported.'.format(
+                args.vision_pretraining_type))
+
     train_transform = ClassificationTransform(image_size)
     val_transform = ClassificationTransform(image_size, train=False)
 
     # training dataset
-    train_data_path = data_path[0]
+    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]  # TODO VIJAY
     train_data = ImageFolder(
         root=train_data_path,
         transform=train_transform,
         ...
```
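As a side note on the masking logic added above: the random-walk branch of `InpaintingTransform.gen_mask` can be sketched standalone as below. The sizes here (64px image, 16px patches) are illustrative, not the commit's defaults, and the helper name is hypothetical.

```python
# Standalone sketch of the random-walk patch masking: start at a random patch,
# take `mask_size` clamped steps, and mark every visited patch as masked.
import torch

def gen_random_mask(image_size=64, patch_size=16, mask_size=4):
    action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]   # down, up, right, left (in patch grid)
    img_size_patch = image_size // patch_size
    mask = torch.zeros((image_size, image_size))
    x = torch.randint(0, img_size_patch, ())
    y = torch.randint(0, img_size_patch, ())
    for _ in range(mask_size):
        r = torch.randint(0, len(action_list), ())
        x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
        y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
        mask[x * patch_size:(x + 1) * patch_size,
             y * patch_size:(y + 1) * patch_size] = 1
    return mask

# The walk may revisit patches, so at most mask_size distinct patches end up masked.
print(gen_random_mask().sum() / (16 * 16))
```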
megatron/model/vision/classification.py

```diff
--- a/megatron/model/vision/classification.py
+++ b/megatron/model/vision/classification.py
@@ -19,6 +19,8 @@ import torch
 from megatron import get_args
 from megatron.model.utils import get_linear_layer
 from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
+from megatron.model.vision.mit_backbone import mit_b3_avg
+from megatron.model.vision.utils import trunc_normal_
 from megatron.model.module import MegatronModule
 
 
 class VitClassificationModel(MegatronModule):
@@ -61,3 +63,35 @@ class VitClassificationModel(MegatronModule):
         hidden_states = self.head(hidden_states)
 
         return hidden_states
+
+
+class MitClassificationModel(MegatronModule):
+    """Mix vision Transformer Model."""
+
+    def __init__(self, num_classes,
+                 pre_process=True, post_process=True):
+        super(MitClassificationModel, self).__init__()
+        args = get_args()
+
+        self.hidden_size = args.hidden_size
+        self.num_classes = num_classes
+
+        self.backbone = mit_b3_avg()
+        self.head = torch.nn.Linear(512, num_classes)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, torch.nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, torch.nn.Linear) and m.bias is not None:
+                torch.nn.init.constant_(m.bias, 0)
+
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        pass
+
+    def forward(self, input):
+        hidden_states = self.backbone(input)
+        hidden_states = self.head(hidden_states)
+        return hidden_states
```
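The head-initialization pattern in `MitClassificationModel` can be tried in isolation. In this sketch, PyTorch's own `torch.nn.init.trunc_normal_` stands in for the `megatron.model.vision.utils.trunc_normal_` helper, and the 512-dim feature is the pooled `mit_b3_avg` output size assumed by the head.

```python
# Minimal sketch of the classification head wiring, under the assumptions above.
import torch

head = torch.nn.Linear(512, 1000)                 # pooled feature -> class logits
torch.nn.init.trunc_normal_(head.weight, std=.02) # truncated-normal weight init
torch.nn.init.constant_(head.bias, 0)             # zero bias, as in _init_weights

features = torch.randn(4, 512)                    # [batch, pooled backbone feature]
logits = head(features)
print(logits.shape)                               # torch.Size([4, 1000])
```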
megatron/model/vision/dino.py (new file)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.utils import print_tensor_min_max_norm as pt
from megatron.model.vision.utils import trunc_normal_
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin
from megatron.model.vision.av_cam_trunk import get_av_cam_trunk


class DINOLoss(torch.nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp,
                        warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
        self.teacher_temp = teacher_temp

    def forward(self, student_output, teacher_output, iteration):
        """
        Cross-entropy between softmax outputs of the teacher
        and student network.
        """
        args = get_args()
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        epoch = iteration // args.iter_per_epoch

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        torch.distributed.all_reduce(batch_center)
        batch_center = batch_center / \
            (len(teacher_output) * torch.distributed.get_world_size())
        self.center = self.center * self.center_momentum + \
            batch_center * (1 - self.center_momentum)


class DINOHead(torch.nn.Module):
    def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
        super().__init__()
        args = get_args()
        hidden_dim = args.dino_head_hidden_size
        bottleneck_dim = args.dino_bottleneck_size
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(torch.nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = torch.nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, torch.nn.Linear) and m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x


class MultiCropWrapper(MegatronModule):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.
    """
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx:end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        if self.training:
            return self.head(output)
        else:
            return output


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = \
            np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) \
        * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule


def get_student_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        student = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              drop_path_rate=0.1,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        student = mit_b5_avg(drop_path_rate=0.1)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        student = get_swin()
        num_features = student.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                args.vision_backbone_type))

    return student, num_features


def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        teacher = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        teacher = mit_b5_avg(drop_path_rate=0.0)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        teacher = get_swin(is_teacher=True)
        num_features = teacher.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                args.vision_backbone_type))

    return teacher, num_features


class DINOPretrainModel(MegatronModule):
    def __init__(self, pre_process=True, post_process=True):
        super(DINOPretrainModel, self).__init__()
        args = get_args()
        self.out_dim = 65536

        self.dino_loss = DINOLoss(
            self.out_dim,
            args.dino_local_crops_number + 2,
            args.dino_warmup_teacher_temp,
            args.dino_teacher_temp,
            args.dino_warmup_teacher_temp_epochs,
            300,
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.momentum_teacher = 0.996

        student_backbone, num_features = \
            get_student_backbone_and_num_features(pre_process, post_process)

        self.student = MultiCropWrapper(
            student_backbone,
            DINOHead(num_features, self.out_dim,
                     norm_last_layer=args.dino_norm_last_layer)
        )

        self.momentum_schedule = cosine_scheduler(
            self.momentum_teacher, 1,
            args.train_iters // args.iter_per_epoch,
            args.iter_per_epoch
        )

        teacher_backbone, num_features = \
            get_teacher_backbone_and_num_features(pre_process, post_process)
        self.teacher = MultiCropWrapper(
            teacher_backbone,
            DINOHead(num_features, self.out_dim)
        )
        self.teacher.load_state_dict(self.student.state_dict())

        for p in self.teacher.parameters():
            if hasattr(p, "requires_grad") and p.requires_grad is not None:
                p.requires_grad = False

    def set_input_tensor(self, tensor):
        pass

    def forward(self, input):
        student_output = None
        if self.training:
            student_output = self.student(input)
            teacher_output = self.teacher(input[:2])
        else:
            teacher_output = self.teacher(input)
        return student_output, teacher_output

    def cancel_gradients_last_layer(self, iteration):
        args = get_args()
        epoch = iteration // args.iter_per_epoch
        if epoch < args.dino_freeze_last_layer:
            for n, p in self.student.named_parameters():
                if "last_layer" in n:
                    p.grad = None

    def update_momentum(self, iteration):
        with torch.no_grad():
            m = self.momentum_schedule[iteration]
            for param_q, param_k in zip(self.student.parameters(),
                                        self.teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
```
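A quick standalone look at the teacher-EMA schedule that `cosine_scheduler` produces (values below are illustrative; the commit ramps the teacher momentum from a base of 0.996 toward 1 over `train_iters // iter_per_epoch` epochs):

```python
# Sketch of the no-warmup cosine schedule and the EMA update it drives.
import numpy as np

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep):
    iters = np.arange(epochs * niter_per_ep)
    return final_value + 0.5 * (base_value - final_value) * \
        (1 + np.cos(np.pi * iters / len(iters)))

m = cosine_scheduler(0.996, 1.0, epochs=3, niter_per_ep=4)
print(m[0], m[-1])   # ~0.996 at the start, approaching 1.0 at the end
# teacher update at step t (see update_momentum): p_k = m[t] * p_k + (1 - m[t]) * p_q
```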
megatron/model/vision/esvit_swin_backbone.py (new file)

```python
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Modified by Chunyuan Li (chunyl@microsoft.com)
# Swin Transformer
# --------------------------------------------------------

import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.distributed as dist
from megatron.model.vision.utils import DropPath, trunc_normal_
from megatron import get_args
from megatron.model import LayerNorm
import numpy as np
from math import sqrt


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True,
                 qk_scale=None, attn_drop=0., proj_drop=0.):
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + \
                mask.unsqueeze(1).unsqueeze(0).type(attn.type())
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn_out = attn
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_out

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

    @staticmethod
    def compute_macs(module, input, output):
        B, N, C = input[0].shape
        module.__flops__ += module.flops(N) * B


class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7,
                 shift_size=0, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, \
            "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size),
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        self.H = input_resolution[0]
        self.W = input_resolution[1]

        self.attn_mask_dict = {}

    def create_attn_mask(self, H, W):
        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1))  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) \
                             .masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        B, L, C = x.shape
        H = int(sqrt(L))
        W = H

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

            if H in self.attn_mask_dict.keys():
                attn_mask = self.attn_mask_dict[H]
            else:
                self.attn_mask_dict[H] = self.create_attn_mask(self.H, self.W).to(x.device)
                attn_mask = self.attn_mask_dict[H]
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(x_windows, attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, attn

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r"""Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        H = int(sqrt(L))
        W = H

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
                 downsample=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim, input_resolution=input_resolution,
                num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x, _ = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def forward_with_features(self, x):
        fea = []
        for blk in self.blocks:
            x, _ = blk(x)
            fea.append(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x, fea

    def forward_with_attention(self, x):
        attns = []
        for blk in self.blocks:
            x, attn = blk(x)
            attns.append(attn)
        if self.downsample is not None:
            x = self.downsample(x)
        return x, attns

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    """Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0],
                              img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * \
            (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class SwinTransformer(nn.Module):
    r"""Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size.
        patch_size (int | tuple(int)): Patch size.
        in_chans (int): Number of input channels.
        num_classes (int): Number of classes for classification head.
        embed_dim (int): Embedding dimension.
        depths (tuple(int)): Depth of Swin Transformer layers.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate.
        drop_path_rate (float): Stochastic depth rate.
        norm_layer (nn.Module): normalization layer.
        ape (bool): If True, add absolute position embedding to the patch embedding.
        patch_norm (bool): If True, add normalization after patch embedding.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs):
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                  patches_resolution[1] // (2 ** i_layer)),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # todo: to be implemented
        return {'relative_position_bias_table'}

    def forward(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x_region = self.norm(x)  # B L C
        x = self.avgpool(x_region.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)

        return x

    def forward_feature_maps(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x_grid = self.norm(x)  # B L C
        x = self.avgpool(x_grid.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)

        return x, x_grid

    def forward_selfattention(self, x, n=1):
        # n=1 return the last layer attn map; otherwise return attn maps in all layers
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        if n == 1:
            return self.forward_last_selfattention(x)
        else:
            return self.forward_all_selfattention(x)

    def forward_last_selfattention(self, x):
        for i, layer in enumerate(self.layers):
            if i < len(self.layers) - 1:
                x = layer(x)
            else:
                x, attns = layer.forward_with_attention(x)
                return attns[-1]

    def forward_all_selfattention(self, x):
        attn_out = []
        for layer in self.layers:
            x, attns = layer.forward_with_attention(x)
            attn_out += attns
        return attn_out

    def forward_return_n_last_blocks(self, x, n=1,
                                     return_patch_avgpool=False, depth=[]):
        num_blks = sum(depth)
        start_idx = num_blks - n

        sum_cur = 0
        for i, d in enumerate(depth):
            sum_cur_new = sum_cur + d
            if start_idx >= sum_cur and start_idx < sum_cur_new:
                start_stage = i
                start_blk = start_idx - sum_cur
            sum_cur = sum_cur_new

        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        # we will return the averaged token features from the `n` last blocks
        # note: there is no [CLS] token in Swin Transformer
        output = []
        s = 0
        for i, layer in enumerate(self.layers):
            x, fea = layer.forward_with_features(x)
            if i >= start_stage:
                for x_ in fea[start_blk:]:
                    if i == len(self.layers) - 1:
                        # use the norm in the last stage
                        x_ = self.norm(x_)
                    x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1)  # B C
                    # print(f'Stage {i}, x_avg {x_avg.shape}')
                    output.append(x_avg)
                start_blk = 0

        return torch.cat(output, dim=-1)

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
            if dist.get_rank() == 0:
                print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}")
        flops += self.num_features * \
            self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops

    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or pretrained_layers[0] == '*'
                    or 'relative_position_index' not in k
                    or 'attn_mask' not in k
                )

                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')

                    if 'relative_position_bias_table' in k and \
                            v.size() != model_dict[k].size():
                        relative_position_bias_table_pretrained = v
                        relative_position_bias_table_current = model_dict[k]
                        L1, nH1 = relative_position_bias_table_pretrained.size()
                        L2, nH2 = relative_position_bias_table_current.size()
                        if nH1 != nH2:
                            logging.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logging.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((L1, nH1), (L2, nH2))
                                )
                                S1 = int(L1 ** 0.5)
                                S2 = int(L2 ** 0.5)
                                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                                    size=(S2, S2),
                                    mode='bicubic')
                                v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

                    if 'absolute_pos_embed' in k and \
                            v.size() != model_dict[k].size():
                        absolute_pos_embed_pretrained = v
                        absolute_pos_embed_current = model_dict[k]
                        _, L1, C1 = absolute_pos_embed_pretrained.size()
                        _, L2, C2 = absolute_pos_embed_current.size()
                        if C1 != C2:
                            logging.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logging.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((1, L1, C1), (1, L2, C2))
                                )
                                S1 = int(L1 ** 0.5)
                                S2 = int(L2 ** 0.5)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                                v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)

                    need_init_state_dict[k] = v

            self.load_state_dict(need_init_state_dict, strict=False)

    def freeze_pretrained_layers(self, frozen_layers=[]):
        for name, module in self.named_modules():
            if (
                name.split('.')[0] in frozen_layers
                or '.'.join(name.split('.')[0:2]) in frozen_layers
                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
            ):
                for _name, param in module.named_parameters():
                    param.requires_grad = False
                logging.info(
                    '=> set param {} requires grad to False'.format(name))
        for name, param in self.named_parameters():
            if (
                name.split('.')[0] in frozen_layers
                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
                and param.requires_grad is True
            ):
                param.requires_grad = False
                logging.info(
                    '=> set param {} requires grad to False'.format(name))
        return self


def get_swin(is_teacher=False):
    args = get_args()

    if args.swin_backbone_type == "tiny":
        embed_dim = 96
        depths = [2, 2, 6, 2]
        num_heads = [3, 6, 12, 24]
        drop_path_rate = 0.1
    elif args.swin_backbone_type == 'h3':
        embed_dim = 384
        depths = [2, 2, 18, 2]
        num_heads = [6, 12, 24, 48]
        drop_path_rate = 0.2
    else:
        embed_dim = 128
        depths = [2, 2, 18, 2]
        num_heads = [4, 8, 16, 32]
        drop_path_rate = 0.2

    swin = SwinTransformer(
        img_size=224,
        in_chans=3,
        num_classes=1000,
        patch_size=4,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0,
        attn_drop_rate=0,
        drop_path_rate=(0.0 if is_teacher else drop_path_rate),
        norm_layer=partial(LayerNorm, eps=1e-6),
        ape=False,
        patch_norm=True,
    )

    return swin
```
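A small round-trip check of `window_partition` / `window_reverse` as defined above: partitioning into non-overlapping windows and reversing recovers the input exactly. Shapes below are illustrative.

```python
# Self-contained sketch; the two functions are copied from the file above.
import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 14, 14, 8)        # B H W C, with H and W divisible by window_size
w = window_partition(x, 7)           # 2 batches * 2*2 windows each -> [8, 7, 7, 8]
assert torch.equal(window_reverse(w, 7, 14, 14), x)
print(w.shape)
```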
megatron/model/vision/inpainting.py
0 → 100644
View file @
01a82723
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer (ViT) inpainting models."""

import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b3
from megatron.model.vision.utils import resize, trunc_normal_


class VitInpaintingModel(MegatronModule):

    def __init__(self, pre_process=True, post_process=True):
        super(VitInpaintingModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.backbone = VitBackbone(
            pre_process=self.pre_process,
            post_process=self.post_process,
            class_token=False,
        )
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.seq_length = args.seq_length
        # full mask
        if self.post_process:
            self.linear_decoder = get_linear_layer(
                self.hidden_size,
                self.backbone.flatten_dim,
                torch.nn.init.zeros_
            )

    def set_input_tensor(self, input_tensor):
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):
        hidden_states = self.backbone(input)

        if not self.post_process:
            return hidden_states
        decoded_output = self.linear_decoder(hidden_states)
        output = einops.rearrange(
            decoded_output,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )
        return output


class MLP(torch.nn.Module):
    """Linear Embedding."""

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class MitInpaintingModel(MegatronModule):
    """Mix vision Transformer Model."""

    def __init__(self, pre_process=True, post_process=True):
        super(MitInpaintingModel, self).__init__()
        self.pre_process = pre_process
        self.post_process = post_process

        args = get_args()
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.flatten_dim = self.patch_dim * self.patch_dim * 3
        self.backbone = mit_b3()

        self.in_channels = [64, 128, 320, 512]
        self.embedding_dim = 768

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4, self.embedding_dim, 1, 1, bias=False)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(0.1)

        self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        c1, c2, c3, c4 = self.backbone(input)

        n, _, h, w = c4.shape
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
        _c = self.conv_fuse(_c)

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)

        x = self.linear_pred(x)

        output = einops.rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )
        return output
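A standalone sketch of the patch-to-image reassembly used by both decoders above (einops only; the sizes are made up for illustration):

import torch
import einops

patch, h, w = 16, 14, 14                            # 14x14 grid of 16x16 patches
decoded = torch.randn(2, 3 * patch * patch, h, w)   # layout "b (c p1 p2) h w"
image = einops.rearrange(
    decoded, "b (c p1 p2) h w -> b c (h p1) (w p2)",
    p1=patch, p2=patch, h=h, w=w)
assert image.shape == (2, 3, 224, 224)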
megatron/model/vision/knn_monitor.py (new file, mode 100644)
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder


def build_data_loader(dataset, drop_last=True, shuffle=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""
    # Sampler.
    args = get_args()
    micro_batch_size = 16
    num_workers = args.num_workers
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=not drop_last,
        pin_memory=True,
    )
    return data_loader


def compute_feature_bank(model):
    args = get_args()
    feature_bank = []
    feature_label = []

    train_ds = ImageFolder(
        root=args.data_path[0],
        transform=ClassificationTransform((args.img_h, args.img_w), train=False),
        data_per_class_fraction=1.0
    )
    classes = len(train_ds.classes)
    dataloader = build_data_loader(train_ds)

    for m in model:
        m.eval()

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            images = batch[0].cuda().contiguous()
            labels = batch[1].cuda().contiguous()
            student_feature, teacher_feature = model[0](images)
            feature = F.normalize(teacher_feature.float(), dim=1)
            feature_bank.append(feature)
            feature_label.append(labels)

    for m in model:
        m.train()

    # [N', D]
    feature_bank = torch.cat(feature_bank, dim=0).contiguous()
    feature_label = torch.cat(feature_label, dim=0).contiguous()

    feature_banks = [torch.zeros_like(feature_bank)
                     for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_banks,
                                 feature_bank,
                                 group=mpu.get_data_parallel_group())

    assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
                              feature_bank))

    feature_labels = [torch.zeros_like(feature_label)
                      for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_labels,
                                 feature_label,
                                 group=mpu.get_data_parallel_group())

    # [D, N]
    feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
    # [N]
    feature_labels = torch.cat(feature_labels, dim=0).contiguous()
    print_rank_0("feature_banks size is {}".format(feature_banks.size()))
    print_rank_0("feature labels size is {}".format(feature_labels.size()))

    return (feature_banks, feature_labels, classes)


# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes,
                                device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1,
                                          index=sim_labels.view(-1, 1),
                                          value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(feature.size(0), -1, classes)
        * sim_weight.unsqueeze(dim=-1),
        dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels
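A toy invocation of knn_predict to make the tensor layout concrete (random data; note the bank is stored transposed, [D, N], exactly as compute_feature_bank returns it):

import torch
import torch.nn.functional as F
from megatron.model.vision.knn_monitor import knn_predict

feature = F.normalize(torch.randn(8, 32), dim=1)             # [B, D] queries
feature_bank = F.normalize(torch.randn(100, 32), dim=1).t()  # [D, N] reference bank
feature_labels = torch.randint(0, 10, (100,))                # [N] class per reference
pred = knn_predict(feature, feature_bank, feature_labels,
                   classes=10, knn_k=20, knn_t=0.07)
top1 = pred[:, 0]  # highest-scoring class per query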
megatron/model/vision/mit_backbone.py (new file, mode 100644)
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from megatron.model.vision.utils import DropPath, trunc_normal_
from megatron.model import LayerNorm


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W


class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8],
                 mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=LayerNorm, depths=[3, 4, 6, 3],
                 sr_ratios=[8, 4, 2, 1], output_avg=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.output_avg = output_avg

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4,
                                              in_chans=in_chans, embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2,
                                              in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2,
                                              in_chans=embed_dims[1], embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2,
                                              in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0],
            qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i],
            norm_layer=norm_layer, sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1],
            qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i],
            norm_layer=norm_layer, sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2],
            qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i],
            norm_layer=norm_layer, sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3],
            qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i],
            norm_layer=norm_layer, sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        if not self.output_avg:
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)
        if self.output_avg:
            x = x[3].mean(dim=1)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class mit_b0(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b2(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b3(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b3_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b3_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=drop_path_rate,
            output_avg=True)


class mit_b4(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b5(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b5_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b5_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=drop_path_rate,
            output_avg=True)
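A rough shape check for the variants above (assumes a working Megatron-LM/apex install on a GPU machine, since LayerNorm comes from megatron.model; sizes are illustrative):

import torch
from megatron.model.vision.mit_backbone import mit_b3, mit_b3_avg

x = torch.randn(2, 3, 224, 224).cuda()
feats = mit_b3().cuda()(x)       # list of 4 pyramid maps at strides 4/8/16/32
pooled = mit_b3_avg().cuda()(x)  # [2, 512]; final-stage tokens mean-pooled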
megatron/model/vision/swin_backbone.py (new file, mode 100644)
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Swin Transformer
# --------------------------------------------------------
import numpy as np  # needed by create_attn_mask below
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from math import sqrt

from megatron import get_args
from functools import partial


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        self.H = input_resolution[0]
        self.W = input_resolution[1]

        self.attn_mask_dict = {}

    def create_attn_mask(self, H, W):
        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1))  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        B, L, C = x.shape
        H = int(sqrt(L))
        W = H

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # look up (or lazily build and cache) the SW-MSA mask for this resolution
            if H in self.attn_mask_dict:
                attn_mask = self.attn_mask_dict[H]
            else:
                attn_mask = self.create_attn_mask(H, W).to(x.device)
                self.attn_mask_dict[H] = attn_mask
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        x_b4_ds = x
        if self.downsample is not None:
            x = self.downsample(x)
        return x_b4_ds, x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class SwinTransformer(nn.Module):
    r""" Swin Transformer
    A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
    https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True,
                 use_checkpoint=False, output_avg=False, **kwargs):
        super().__init__()

        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.output_avg = output_avg

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        h = self.img_size[0] // self.patch_size[0]
        w = self.img_size[1] // self.patch_size[1]

        outs = []
        for i, layer in enumerate(self.layers):
            px, x = layer(x)
            b, n, c = px.shape

            if i != len(self.layers) - 1 or not self.output_avg:
                px = px.permute(0, 2, 1).contiguous()
                px = px.reshape(b, c, h, w)
            # is this a fair assumption ?? i think it's baked into the architecture
            h, w = h // 2, w // 2
            outs.append(px)

        if self.output_avg:
            return outs[-1].mean(dim=1)

        return outs

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


def get_swin(drop_path_rate=0.3, output_avg=False):
    args = get_args()

    window_size = 7
    embed_dim = 128
    depths = [2, 2, 18, 2]
    num_heads = [4, 8, 16, 32]
    swin = SwinTransformer(
        img_size=(args.img_h, args.img_w,),
        in_chans=3,
        patch_size=args.patch_dim,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        drop_path_rate=drop_path_rate,
        output_avg=output_avg,
    )

    return swin
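A quick round-trip check of the window helpers defined above (pure torch; shapes illustrative):

import torch
from megatron.model.vision.swin_backbone import window_partition, window_reverse

x = torch.randn(2, 14, 14, 96)             # B, H, W, C
wins = window_partition(x, window_size=7)  # [2 * (14//7)**2, 7, 7, 96] = [8, 7, 7, 96]
assert torch.equal(window_reverse(wins, 7, 14, 14), x)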
megatron/model/vision/utils.py (new file, mode 100644)
import warnings
import math
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
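For reference, a small example of the resize() guard above in the upsampling configuration the inpainting decoder uses (sizes illustrative):

import torch
from megatron.model.vision.utils import resize

feat = torch.randn(1, 256, 14, 14)
up = resize(feat, size=(56, 56), mode='bilinear', align_corners=False)
assert up.shape == (1, 256, 56, 56)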
megatron/training.py
@@ -51,7 +51,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory
+from megatron.model.vision.knn_monitor import compute_feature_bank


 def print_datetime(string):

@@ -465,11 +465,23 @@ def train_step(forward_step_func, data_iterator,
         torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
     timers('backward-embedding-all-reduce').stop()

+    if args.vision_pretraining_type == "dino":
+        unwrapped_model = unwrap_model(model[0],
+                                       (torchDDP, LocalDDP, Float16Module))
+        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
+
     # Update parameters.
     timers('optimizer').start()
     update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
     timers('optimizer').stop()

+    if args.vision_pretraining_type == "dino":
+        unwrapped_model = unwrap_model(model[0],
+                                       (torchDDP, LocalDDP, Float16Module))
+        unwrapped_model.update_momentum(args.curr_iteration)
+
     # Update learning rate.
     if update_successful:
         increment = get_num_microbatches() * \

@@ -702,6 +714,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)
+        args.curr_iteration = iteration
         loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func,
                        train_data_iterator,

@@ -791,6 +804,9 @@ def evaluate(forward_step_func,
     """Evaluation."""
     args = get_args()

+    if args.vision_pretraining_type == "contrast":
+        args.knn_features = compute_feature_bank(model)
+
     # Turn on evaluation mode which disables dropout.
     for model_module in model:
         model_module.eval()
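The two DINO hooks spliced into train_step above (cancel_gradients_last_layer before the optimizer step, update_momentum after it) are methods of the model added by this commit; their bodies are not shown in this hunk. As a generic illustration only (an assumption, not the committed implementation), a teacher momentum update typically amounts to an EMA over parameters:

import torch

@torch.no_grad()
def ema_update(student, teacher, m):
    # teacher <- m * teacher + (1 - m) * student, parameter by parameter
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(m).add_(p_s, alpha=1.0 - m)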
pretrain_vit.py → pretrain_vision_classify.py (renamed)
@@ -22,20 +22,31 @@ from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model import ModelType
 from megatron.model.vision.classification import VitClassificationModel
+from megatron.model.vision.classification import MitClassificationModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group


 def model_provider(pre_process=True, post_process=True):
     """Build the model."""

     print_rank_0("building VIT model ...")
     args = get_args()
-    model = VitClassificationModel(num_classes=args.num_classes,
-                                   pre_process=pre_process,
-                                   post_process=post_process)
+    if args.vision_backbone_type == 'vit':
+        model = VitClassificationModel(num_classes=args.num_classes,
+                                       pre_process=pre_process,
+                                       post_process=post_process)
+    elif args.vision_backbone_type == 'mit':
+        model = MitClassificationModel(num_classes=args.num_classes,
+                                       pre_process=pre_process,
+                                       post_process=post_process)
+    else:
+        raise Exception('{} vision backbone is not supported.'.format(
+                        args.vision_backbone_type))

     return model


 def get_batch(data_iterator):
     """Build the batch."""
     data = next(data_iterator)

@@ -46,6 +57,7 @@ def get_batch(data_iterator):
     return images, labels


 def loss_func(labels, output_tensor):
     logits = output_tensor.contiguous().float()
     loss = F.cross_entropy(logits, labels)

@@ -58,6 +70,7 @@ def loss_func(labels, output_tensor):
     return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}


 def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()
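The accuracy entry reported by loss_func above comes from a metric computed alongside the cross-entropy loss (the helper itself is outside this hunk). One common way to produce such a loss/top-1 pair, shown here with random stand-in tensors rather than the committed helper:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 1000)          # stand-in classification-head output
labels = torch.randint(0, 1000, (8,))
loss = F.cross_entropy(logits, labels)
top1 = (logits.argmax(dim=-1) == labels).float().mean()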
pretrain_vision_dino.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import
torch
import
torch.nn.functional
as
F
import
torch.nn
as
nn
import
numpy
as
np
import
torch.distributed
as
dist
from
functools
import
partial
from
megatron
import
get_args
,
get_timers
,
mpu
,
print_rank_0
from
megatron.data.vit_dataset
import
build_train_valid_datasets
from
megatron.model.vision.contrastive
import
DINOPretrainModel
from
megatron.model.vision.knn_monitor
import
knn_predict
from
megatron.training
import
pretrain
from
megatron.utils
import
average_losses_across_data_parallel_group
,
unwrap_model
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
Float16Module
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
print_rank_0
(
"building VIT model ..."
)
return
DINOPretrainModel
(
pre_process
=
pre_process
,
post_process
=
post_process
)
def
get_batch
(
data_iterator
):
"""Build the batch."""
data
=
next
(
data_iterator
)
# only data parallelism; no need for broadcast
if
isinstance
(
data
[
0
],
list
):
images
=
[
aug
.
cuda
()
for
aug
in
data
[
0
]]
else
:
images
=
data
[
0
].
cuda
()
labels
=
data
[
1
].
cuda
()
return
images
,
labels
def loss_func(model, labels, output_tensor, collect_data=False):
    args = get_args()

    model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))
    if model.training:
        student_output, teacher_output = output_tensor
        loss = model.dino_loss(student_output, teacher_output,
                               args.curr_iteration)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"loss": averaged_loss[0]}
    else:
        _, teacher_feature = output_tensor
        feature_bank, feature_labels, classes = args.knn_features

        feature = F.normalize(teacher_feature.float(), dim=1)
        knn_accs = []
        for k in [10, 20, 100, 200]:
            pred_labels = knn_predict(feature, feature_bank,
                                      feature_labels, classes, k, 0.07)
            knn_acc = (pred_labels[:, 0] == labels).float().mean()
            knn_accs.append(knn_acc)

        averaged_loss = average_losses_across_data_parallel_group(knn_accs)
        return 0, {"knn_acc_10": averaged_loss[0],
                   "knn_acc_20": averaged_loss[1],
                   "knn_acc_100": averaged_loss[2],
                   "knn_acc_200": averaged_loss[3]}
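knn_predict is defined in megatron/model/vision/knn_monitor.py, which this commit adds but this page does not display. Judging from the call above, it ranks a precomputed feature bank by similarity and returns class predictions sorted by vote weight. A minimal sketch of such a temperature-weighted k-NN monitor, with the signature inferred from the call site (the committed implementation may differ):

import torch

def knn_predict_sketch(feature, feature_bank, feature_labels, classes, k, t):
    # feature: [B, D] L2-normalized queries; feature_bank: [D, N] normalized bank.
    sim = torch.mm(feature, feature_bank)            # cosine similarity [B, N]
    sim_weight, sim_idx = sim.topk(k=k, dim=-1)      # k nearest neighbors per query
    sim_labels = torch.gather(
        feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_idx)
    sim_weight = (sim_weight / t).exp()              # temperature-scaled votes
    one_hot = torch.zeros(feature.size(0) * k, classes, device=feature.device)
    one_hot.scatter_(-1, sim_labels.view(-1, 1), 1.0)
    scores = (one_hot.view(feature.size(0), k, classes)
              * sim_weight.unsqueeze(-1)).sum(dim=1)
    # Ranked class indices; column 0 is the top-1 prediction used above.
    return scores.argsort(dim=-1, descending=True)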
def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers("batch-generator").start()
    (images, labels) = get_batch(data_iterator)
    timers("batch-generator").stop()

    return model(images), partial(loss_func, model, labels)
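model.dino_loss, called inside loss_func above, lives in megatron/model/vision/dino.py (added by this commit, not reproduced on this page). If it follows the standard DINO objective of Caron et al., it is a cross-entropy between a sharpened, centered teacher distribution and the student's log-probabilities; a condensed sketch of that formulation, offered as an assumption rather than the committed code:

import torch
import torch.nn.functional as F

def dino_loss_sketch(student_out, teacher_out, center,
                     student_temp=0.1, teacher_temp=0.04):
    # Teacher targets: centered, sharpened, and detached from the graph.
    teacher_probs = F.softmax((teacher_out - center) / teacher_temp,
                              dim=-1).detach()
    student_logp = F.log_softmax(student_out / student_temp, dim=-1)
    return -(teacher_probs * student_logp).sum(dim=-1).mean()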
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets "
        "for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w)
    )
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None


if __name__ == "__main__":
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic'}
    )
pretrain_vision_inpaint.py
0 → 100644
View file @
01a82723
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.inpainting import VitInpaintingModel
from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR


def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    args = get_args()
    if args.vision_backbone_type == 'vit':
        model = VitInpaintingModel(pre_process=pre_process,
                                   post_process=post_process)
    elif args.vision_backbone_type == 'mit':
        model = MitInpaintingModel(pre_process=pre_process,
                                   post_process=post_process)
    else:
        raise Exception('{} vision backbone is not supported.'.format(
            args.vision_backbone_type))
    return model
def get_batch(data_iterator):
    """Build the batch."""
    data = next(data_iterator)

    # only data parallelism; no need for broadcast
    images = data[0][0].cuda()
    masks = data[0][1].cuda()
    return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
    outputs = outputs.contiguous().float()
    masks_flip = 1 - masks
    flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
    flip_masked_images = images.masked_fill(masks_flip.bool(), 0)

    ssim_fun = SSIM()
    psnr_fun = PSNR()

    if not collect_data:
        mask_count = torch.count_nonzero(masks)
        loss = F.mse_loss(flip_masked_outputs,
                          flip_masked_images.float(),
                          reduction="sum")
        loss = loss / mask_count
        ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
        psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())

        averaged_loss = average_losses_across_data_parallel_group(
            [loss, psnr, ssim]
        )
        return loss, {"loss": averaged_loss[0],
                      "psnr": averaged_loss[1],
                      "ssim": averaged_loss[2]}
    else:
        synth_images = masked_images.float() + flip_masked_outputs
        ssim = ssim_fun(synth_images, images.float())
        psnr = psnr_fun(synth_images, images.float())
        return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
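Because the MSE above is sum-reduced and then divided by torch.count_nonzero(masks), the result is an average error over only the inpainted region; unmasked pixels are zeroed on both sides and contribute nothing. A toy check of that arithmetic (shapes chosen purely for illustration):

import torch
import torch.nn.functional as F

images = torch.rand(1, 3, 4, 4)
outputs = torch.rand(1, 3, 4, 4)
masks = torch.zeros(1, 3, 4, 4)
masks[..., :2, :2] = 1                    # inpaint the top-left 2x2 patch

masks_flip = 1 - masks                    # 1 where pixels are kept, not inpainted
flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
flip_masked_images = images.masked_fill(masks_flip.bool(), 0)

loss = F.mse_loss(flip_masked_outputs, flip_masked_images, reduction="sum")
loss = loss / torch.count_nonzero(masks)  # mean error over the 12 masked entries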
def forward_step(data_iterator, model):
    """Forward step."""
    timers = get_timers()

    # Get the batch.
    timers("batch-generator").start()
    (images, masks) = get_batch(data_iterator)
    timers("batch-generator").stop()

    masked_images = images.masked_fill(masks.bool(), 0)
    outputs = model(masked_images)

    # Forward mode
    return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
    psnr_sum = 0
    ssim_sum = 0
    for (output_tb, ssim, psnr) in data:
        output_tb[output_tb < 0] = 0
        output_tb[output_tb > 1] = 1
        writer.add_images("gt-input-output-vald", output_tb,
                          global_step=iteration, walltime=None,
                          dataformats='NCHW')
        psnr_sum = psnr_sum + psnr.item()
        ssim_sum = ssim_sum + ssim.item()
    psnr = psnr_sum / len(data)
    ssim = ssim_sum / len(data)
    writer.add_scalar('PSNR generate value-validation', psnr, iteration)
    writer.add_scalar('SSIM generate value-validation', ssim, iteration)
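process_non_loss_data expects a TensorBoard-style writer; in Megatron's training loop one is constructed when --tensorboard-dir is set. A hedged standalone usage sketch with torch's own SummaryWriter and a fabricated validation tuple matching the (image grid, ssim, psnr) layout returned by loss_func in collect_data mode:

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="/tmp/inpaint_tb")  # illustrative path
# One NCHW grid of gt/input/output stacked along height, plus toy metrics.
data = [(torch.rand(4, 3, 96, 32), torch.tensor(0.8), torch.tensor(24.0))]
process_non_loss_data(data, iteration=100, writer=writer)
writer.close()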
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets "
        "for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w)
    )
    print_rank_0("> finished creating VIT datasets ...")

    return train_ds, valid_ds, None


if __name__ == "__main__":
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step,
        process_non_loss_data,
        args_defaults={'dataloader_type': 'cyclic'}
    )
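SSIM and PSNR come from tasks/vision/metrics.py, which is outside this commit's diff. For reference, PSNR has a closed form over mean squared error; a minimal sketch assuming images normalized to [0, 1] (the committed metric class may differ):

import torch

def psnr_sketch(pred, target, max_val=1.0):
    # Peak signal-to-noise ratio in decibels: 10 * log10(MAX^2 / MSE).
    mse = torch.mean((pred.float() - target.float()) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)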
tasks/vision/classification.py → tasks/vision/classification/classification.py
View file @ 01a82723
File moved
tasks/vision/eval_utils.py → tasks/vision/classification/eval_utils.py
View file @ 01a82723
File moved