Commit 01a82723 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

vision third phase merge: pretraining methods + mit,swin backbones

parent 2b628f96
......@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
......@@ -849,9 +849,10 @@ def _add_biencoder_args(parser):
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
def _add_vision_args(parser):
group = parser.add_argument_group(title="vision")
# general vision arguements
group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classificaiton task')
group.add_argument('--img-h', type=int, default=224,
......@@ -861,7 +862,7 @@ def _add_vit_args(parser):
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
......@@ -869,5 +870,49 @@ def _add_vit_args(parser):
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection`
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'contrast'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone types types')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='pretraining objectives')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size for vision classification task')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottle neck dimension in dino head ')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warump teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperaure epochs')
return parser
......@@ -22,6 +22,43 @@ from megatron import get_args
from megatron.data.image_folder import ImageFolder
from megatron.data.autoaugment import ImageNetPolicy
from megatron.data.data_samplers import RandomSeedDataset
from PIL import Image, ImageFilter, ImageOps
class GaussianBlur(object):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
self.prob = p
self.radius_min = radius_min
self.radius_max = radius_max
def __call__(self, img):
do_it = random.random() <= self.prob
if not do_it:
return img
return img.filter(
ImageFilter.GaussianBlur(
radius=random.uniform(self.radius_min, self.radius_max)
)
)
class Solarization(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p):
self.p = p
def __call__(self, img):
if random.random() < self.p:
return ImageOps.solarize(img)
else:
return img
class ClassificationTransform():
def __init__(self, image_size, train=True):
......@@ -52,14 +89,169 @@ class ClassificationTransform():
return output
class InpaintingTransform():
def __init__(self, image_size, train=True):
args = get_args()
self.mask_factor = args.mask_factor
self.mask_type = args.mask_type
self.image_size = image_size
self.patch_size = args.patch_dim
self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size))
self.train = train
assert args.fp16 or args.bf16
self.data_type = torch.half if args.fp16 else torch.bfloat16
if self.train:
self.transform = T.Compose([
T.RandomResizedCrop(self.image_size),
T.RandomHorizontalFlip(),
T.ColorJitter(0.4, 0.4, 0.4, 0.1),
ImageNetPolicy(),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
else:
self.transform = T.Compose([
T.Resize(self.image_size, interpolation=2),
T.CenterCrop(self.image_size),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
def gen_mask(self, image_size, mask_size, mask_type, patch_size):
# output: mask as a list with indices for missing patches
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
assert image_size[0] == image_size[1]
img_size_patch = image_size[0] // patch_size
# drop masked patches
mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)
if mask_type == 'random':
x = torch.randint(0, img_size_patch, ())
y = torch.randint(0, img_size_patch, ())
for i in range(mask_size):
r = torch.randint(0, len(action_list), ())
x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
else:
assert mask_type == 'row'
count = 0
for x in reversed(range(img_size_patch)):
for y in reversed(range(img_size_patch)):
if (count < mask_size):
count += 1
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
return mask
def __call__(self, input):
trans_input = self.transform(input)
mask = self.gen_mask(self.image_size, self.mask_size,
self.mask_type, self.patch_size)
mask = mask.unsqueeze(dim=0)
return trans_input, mask
class DinoTransform(object):
def __init__(self, image_size, train=True):
args = get_args()
self.data_type = torch.half if args.fp16 else torch.bfloat16
flip_and_color_jitter = T.Compose([
T.RandomHorizontalFlip(p=0.5),
T.RandomApply(
[T.ColorJitter(brightness=0.4, contrast=0.4,
saturation=0.2, hue=0.1)],
p=0.8
),
T.RandomGrayscale(p=0.2),
])
if args.fp16 or args.bf16:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
else:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# first global crop
scale_const = 0.4
self.global_transform1 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(1.0),
normalize
])
# second global crop
self.global_transform2 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(0.1),
Solarization(0.2),
normalize
])
# transformation for the local small crops
self.local_crops_number = args.local_crops_number
self.local_transform = T.Compose([
T.RandomResizedCrop(args.local_img_size,
scale=(0.05, scale_const),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(p=0.5),
normalize
])
def __call__(self, image):
crops = []
args = get_args()
if args.street_data:
crop_transform = T.RandomCrop(300)
image = crop_transform(image)
crops.append(self.global_transform1(image))
crops.append(self.global_transform2(image))
for _ in range(self.local_crops_number):
crops.append(self.local_transform(image))
return crops
def build_train_valid_datasets(data_path, image_size=224):
args = get_args()
if args.vision_pretraining_type == 'classify':
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False)
elif args.vision_pretraining_type == 'inpaint':
train_transform = InpaintingTransform(image_size, train=False)
val_transform = InpaintingTransform(image_size, train=False)
elif args.vision_pretraining_type == 'dino':
train_transform = DinoTransform(image_size, train=True)
val_transform = ClassificationTransform(image_size, train=False)
else:
raise Exception('{} vit pretraining type is not supported.'.format(
args.vit_pretraining_type))
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False)
# training dataset
train_data_path = data_path[0]
train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2] #TODO VIJAY
train_data = ImageFolder(
root=train_data_path,
transform=train_transform,
......
......@@ -19,6 +19,8 @@ import torch
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3_avg
from megatron.model.vision.utils import trunc_normal_
from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule):
......@@ -61,3 +63,35 @@ class VitClassificationModel(MegatronModule):
hidden_states = self.head(hidden_states)
return hidden_states
class MitClassificationModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, num_classes
pre_process=True, post_process=True):
super(MitClassificationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.backbone = mit_b3_avg()
self.head = torch.nn.Linear(512, num_classes)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
hidden_states = self.backbone(input)
hidden_states = self.head(hidden_states)
return hidden_states
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.utils import print_tensor_min_max_norm as pt
from megatron.model.vision.utils import trunc_normal_
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin
from megatron.model.vision.av_cam_trunk import get_av_cam_trunk
class DINOLoss(torch.nn.Module):
def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
center_momentum=0.9):
super().__init__()
self.student_temp = student_temp
self.center_momentum = center_momentum
self.ncrops = ncrops
self.register_buffer("center", torch.zeros(1, out_dim))
# we apply a warm up for the teacher temperature because
# a too high temperature makes the training instable at the beginning
self.teacher_temp_schedule = np.concatenate((
np.linspace(warmup_teacher_temp,
teacher_temp, warmup_teacher_temp_epochs),
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
))
self.teacher_temp = teacher_temp
def forward(self, student_output, teacher_output, iteration):
"""
Cross-entropy between softmax outputs of the teacher
and student network.
"""
args = get_args()
student_out = student_output / self.student_temp
student_out = student_out.chunk(self.ncrops)
epoch = iteration // args.iter_per_epoch
# teacher centering and sharpening
temp = self.teacher_temp_schedule[epoch]
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
teacher_out = teacher_out.detach().chunk(2)
total_loss = 0
n_loss_terms = 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
# we skip cases where student and teacher operate on the same view
continue
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
total_loss += loss.mean()
n_loss_terms += 1
total_loss /= n_loss_terms
self.update_center(teacher_output)
return total_loss
@torch.no_grad()
def update_center(self, teacher_output):
"""
Update center used for teacher output.
"""
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
torch.distributed.all_reduce(batch_center)
batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
class DINOHead(torch.nn.Module):
def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
super().__init__()
args = get_args()
hidden_dim = args.dino_head_hidden_size
bottleneck_dim = args.dino_bottleneck_size
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
else:
layers = [torch.nn.Linear(in_dim, hidden_dim)]
layers.append(torch.nn.GELU())
for _ in range(nlayers - 2):
layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
layers.append(torch.nn.GELU())
layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = torch.nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = torch.nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x
class MultiCropWrapper(MegatronModule):
"""
Perform forward pass separately on each resolution input.
The inputs corresponding to a single resolution are clubbed and single
forward is run on the same resolution inputs. Hence we do several
forward passes = number of different resolutions used. We then
concatenate all the output features and run the head forward on these
concatenated features.
"""
def __init__(self, backbone, head):
super(MultiCropWrapper, self).__init__()
# disable layers dedicated to ImageNet labels classification
#backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
self.backbone = backbone
self.head = head
def forward(self, x):
# convert to list
if not isinstance(x, list):
x = [x]
idx_crops = torch.cumsum(torch.unique_consecutive(
torch.tensor([inp.shape[-1] for inp in x]),
return_counts=True,
)[1], 0)
start_idx = 0
for end_idx in idx_crops:
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
if start_idx == 0:
output = _out
else:
output = torch.cat((output, _out))
start_idx = end_idx
# Run the head forward on the concatenated features.
if self.training:
return self.head(output)
else:
return output
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
warmup_epochs=0, start_warmup_value=0):
warmup_schedule = np.array([])
warmup_iters = warmup_epochs * niter_per_ep
if warmup_epochs > 0:
warmup_schedule = \
np.linspace(start_warmup_value, base_value, warmup_iters)
iters = np.arange(epochs * niter_per_ep - warmup_iters)
schedule = final_value + 0.5 * (base_value - final_value) \
* (1 + np.cos(np.pi * iters / len(iters)))
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == epochs * niter_per_ep
return schedule
def get_student_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
student = VitBackbone(pre_process=pre_process,
post_process=post_process,
drop_path_rate=0.1,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
student = mit_b5_avg(drop_path_rate=0.1)
num_features = 512
elif args.vision_backbone_type == 'swin':
student = get_swin()
num_features = student.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return student, num_features
def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
teacher = VitBackbone(pre_process=pre_process,
post_process=post_process,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
teacher = mit_b5_avg(drop_path_rate=0.0)
num_features = 512
elif args.vision_backbone_type == 'swin':
teacher = get_swin(is_teacher=True)
num_features = teacher.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return teacher, num_features
class DINOPretrainModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(DINOPretrainModel, self).__init__()
args = get_args()
self.out_dim = 65536
self.dino_loss = DINOLoss(
self.out_dim,
args.dino_local_crops_number + 2,
args.dino_warmup_teacher_temp,
args.dino_teacher_temp,
args.dino_warmup_teacher_temp_epochs,
300,
)
self.pre_process = pre_process
self.post_process = post_process
self.momentum_teacher = 0.996
student_backbone, num_features = \
get_student_backbone_and_num_features(pre_process, post_process)
self.student = MultiCropWrapper(
student_backbone,
DINOHead(num_features, self.out_dim,
norm_last_layer=args.dino_norm_last_layer)
)
self.momentum_schedule = cosine_scheduler(
self.momentum_teacher, 1,
args.train_iters // args.iter_per_epoch,
args.iter_per_epoch
)
teacher_backbone, num_features = \
get_teacher_backbone_and_num_features(pre_process, post_process)
self.teacher = MultiCropWrapper(
teacher_backbone,
DINOHead(num_features, self.out_dim)
)
self.teacher.load_state_dict(self.student.state_dict())
for p in self.teacher.parameters():
if hasattr(p, "requires_grad") and p.requires_grad is not None:
p.requires_grad = False
def set_input_tensor(self, tensor):
pass
def forward(self, input):
student_output = None
if self.training:
student_output = self.student(input)
teacher_output = self.teacher(input[:2])
else:
teacher_output = self.teacher(input)
return student_output, teacher_output
def cancel_gradients_last_layer(self, iteration):
args = get_args()
epoch = iteration // args.iter_per_epoch
if epoch < args.dino_freeze_last_layer:
for n, p in self.student.named_parameters():
if "last_layer" in n:
p.grad = None
def update_momentum(self, iteration):
with torch.no_grad():
m = self.momentum_schedule[iteration]
for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Modified by Chunyuan Li (chunyl@microsoft.com)
# Swin Transformer
# --------------------------------------------------------
import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.distributed as dist
from megatron.model.vision.utils import DropPath, trunc_normal_
from megatron import get_args
from megatron.model import LayerNorm
import numpy as np
from math import sqrt
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super(Mlp, self).__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super(WindowAttention, self).__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2 Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).type(attn.type())
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn_out = attn
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn_out
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
@staticmethod
def compute_macs(module, input, output):
B, N, C = input[0].shape
module.__flops__ += module.flops(N) * B
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = input_resolution[0]
self.W = input_resolution[1]
self.attn_mask_dict = {}
def create_attn_mask(self, H, W):
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x):
B, L, C = x.shape
H = int(sqrt(L))
W = H
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
if H in self.attn_mask_dict.keys():
attn_mask = self.attn_mask_dict[H]
else:
self.attn_mask_dict[H] = self.create_attn_mask(self.H, self.W).to(x.device)
attn_mask = self.attn_mask_dict[H]
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows, attn = self.attn(x_windows, attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
H = int(sqrt(L))
W = H
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
x, _ = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def forward_with_features(self, x):
fea = []
for blk in self.blocks:
x, _ = blk(x)
fea.append(x)
if self.downsample is not None:
x = self.downsample(x)
return x, fea
def forward_with_attention(self, x):
attns = []
for blk in self.blocks:
x, attn = blk(x)
attns.append(attn)
if self.downsample is not None:
x = self.downsample(x)
return x, attns
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size.
patch_size (int | tuple(int)): Patch size.
in_chans (int): Number of input channels.
num_classes (int): Number of classes for classification head.
embed_dim (int): Embedding dimension.
depths (tuple(int)): Depth of Swin Transformer layers.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): normalization layer.
ape (bool): If True, add absolute position embedding to the patch embedding.
patch_norm (bool): If True, add normalization after patch embedding.
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
# todo: to be implemented
return {'relative_position_bias_table'}
def forward(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x_region = self.norm(x) # B L C
x = self.avgpool(x_region.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward_feature_maps(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x_grid = self.norm(x) # B L C
x = self.avgpool(x_grid.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x, x_grid
def forward_selfattention(self, x, n=1):
# n=1 return the last layer attn map; otherwise return attn maps in all layers
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
if n==1:
return self.forward_last_selfattention(x)
else:
return self.forward_all_selfattention(x)
def forward_last_selfattention(self, x):
for i, layer in enumerate(self.layers):
if i < len(self.layers) - 1:
x = layer(x)
else:
x, attns = layer.forward_with_attention(x)
return attns[-1]
def forward_all_selfattention(self, x):
attn_out = []
for layer in self.layers:
x, attns = layer.forward_with_attention(x)
attn_out += attns
return attn_out
def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False, depth=[]):
num_blks = sum(depth)
start_idx = num_blks - n
sum_cur = 0
for i, d in enumerate(depth):
sum_cur_new = sum_cur + d
if start_idx >= sum_cur and start_idx < sum_cur_new:
start_stage = i
start_blk = start_idx - sum_cur
sum_cur = sum_cur_new
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
# we will return the averaged token features from the `n` last blocks
# note: there is no [CLS] token in Swin Transformer
output = []
s = 0
for i, layer in enumerate(self.layers):
x, fea = layer.forward_with_features(x)
if i >= start_stage:
for x_ in fea[start_blk:]:
if i == len(self.layers)-1: # use the norm in the last stage
x_ = self.norm(x_)
x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1) # B C
# print(f'Stage {i}, x_avg {x_avg.shape}')
output.append(x_avg)
start_blk = 0
return torch.cat(output, dim=-1)
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
if dist.get_rank() == 0:
print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}")
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops
def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained, map_location='cpu')
logging.info(f'=> loading pretrained model {pretrained}')
model_dict = self.state_dict()
pretrained_dict = {
k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()
}
need_init_state_dict = {}
for k, v in pretrained_dict.items():
need_init = (
k.split('.')[0] in pretrained_layers
or pretrained_layers[0] is '*'
or 'relative_position_index' not in k
or 'attn_mask' not in k
)
if need_init:
if verbose:
logging.info(f'=> init {k} from {pretrained}')
if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
relative_position_bias_table_pretrained = v
relative_position_bias_table_current = model_dict[k]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if nH1 != nH2:
logging.info(f"Error in loading {k}, passing")
else:
if L1 != L2:
logging.info(
'=> load_pretrained: resized variant: {} to {}'
.format((L1, nH1), (L2, nH2))
)
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
absolute_pos_embed_pretrained = v
absolute_pos_embed_current = model_dict[k]
_, L1, C1 = absolute_pos_embed_pretrained.size()
_, L2, C2 = absolute_pos_embed_current.size()
if C1 != C1:
logging.info(f"Error in loading {k}, passing")
else:
if L1 != L2:
logging.info(
'=> load_pretrained: resized variant: {} to {}'
.format((1, L1, C1), (1, L2, C2))
)
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
need_init_state_dict[k] = v
self.load_state_dict(need_init_state_dict, strict=False)
def freeze_pretrained_layers(self, frozen_layers=[]):
for name, module in self.named_modules():
if (
name.split('.')[0] in frozen_layers
or '.'.join(name.split('.')[0:2]) in frozen_layers
or (len(frozen_layers) > 0 and frozen_layers[0] is '*')
):
for _name, param in module.named_parameters():
param.requires_grad = False
logging.info(
'=> set param {} requires grad to False'
.format(name)
)
for name, param in self.named_parameters():
if (
name.split('.')[0] in frozen_layers
or (len(frozen_layers) > 0 and frozen_layers[0] is '*')
and param.requires_grad is True
):
param.requires_grad = False
logging.info(
'=> set param {} requires grad to False'
.format(name)
)
return self
def get_swin(is_teacher=False):
args = get_args()
if args.swin_type == "tiny":
embed_dim = 96
depths = [2, 2, 6, 2]
num_heads = [3, 6, 12, 24]
drop_path_rate = 0.1
elif args.swin_type == 'h3':
embed_dim = 384
depths = [2, 2, 18, 2]
num_heads = [6, 12, 24, 48]
drop_path_rate = 0.2
else:
embed_dim = 128
depths = [2, 2, 18, 2]
num_heads = [4, 8, 16, 32]
drop_path_rate = 0.2
swin = SwinTransformer(
img_size=224,
in_chans=3,
num_classes=1000,
patch_size=4,
embed_dim=embed_dim,
depths=depths,
num_heads=num_heads,
window_size=7,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0,
attn_drop_rate=0,
drop_path_rate=(0.0 if is_teacher else drop_path_rate),
norm_layer=partial(LayerNorm, eps=1e-6),
ape=False,
patch_norm=True,
)
return swin
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b3
from megatron.model.vision.utils import resize, trunc_normal_
class VitInpaintingModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(VitInpaintingModel, self).__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.backbone = VitBackbone(
pre_process=self.pre_process,
post_process=self.post_process,
class_token=False,
)
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.seq_length = args.seq_length
# full mask
if self.post_process:
self.linear_decoder = get_linear_layer(
self.hidden_size,
self.backbone.flatten_dim,
torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
self.backbone.set_input_tensor(input_tensor)
def forward(self, input):
hidden_states = self.backbone(input)
if not self.post_process:
return hidden_states
decoded_output = self.linear_decoder(hidden_states)
output = einops.rearrange(
decoded_output,
"b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
class MLP(torch.nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = torch.nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class MitInpaintingModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, pre_process=True, post_process=True):
super(MitInpaintingModel, self).__init__()
self.pre_process = pre_process
self.post_process = post_process
args = get_args()
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.flatten_dim = self.patch_dim * self.patch_dim * 3
self.backbone = mit_b3()
self.in_channels = [64, 128, 320, 512]
self.embedding_dim = 768
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False)
self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
self.dropout = torch.nn.Dropout2d(0.1)
self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
c1, c2, c3, c4 = self.backbone(input)
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
_c = self.conv_fuse(_c)
x = self.norm(_c)
x = F.relu(x, inplace=True)
x = self.dropout(x)
x = self.linear_pred(x)
output = einops.rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder
def build_data_loader(dataset, drop_last=True, shuffle=False):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
args = get_args()
micro_batch_size = 16
num_workers = args.num_workers
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank,
drop_last=drop_last, shuffle=shuffle
)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=not drop_last,
pin_memory=True,
)
return data_loader
def compute_feature_bank(model):
args = get_args()
feature_bank = []
feature_label = []
train_ds = ImageFolder(
root=args.data_path[0],
transform=ClassificationTransform((args.img_h, args.img_w), train=False),
data_per_class_fraction=1.0
)
classes = len(train_ds.classes)
dataloader = build_data_loader(train_ds)
for m in model:
m.eval()
with torch.no_grad():
for i, batch in enumerate(dataloader):
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
student_feature, teacher_feature = model[0](images)
feature = F.normalize(teacher_feature.float(), dim=1)
feature_bank.append(feature)
feature_label.append(labels)
for m in model:
m.train()
# [N', D]
feature_bank = torch.cat(feature_bank, dim=0).contiguous()
feature_label = torch.cat(feature_label, dim=0).contiguous()
feature_banks = [torch.zeros_like(feature_bank)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_banks,
feature_bank,
group=mpu.get_data_parallel_group())
assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
feature_bank))
feature_labels = [torch.zeros_like(feature_label)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_labels,
feature_label,
group=mpu.get_data_parallel_group())
# [D, N]
feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
# [N]
feature_labels = torch.cat(feature_labels, dim=0).contiguous()
print_rank_0("feature_banks size is {}".format(feature_banks.size()))
print_rank_0("feature labels size is {}".format(feature_labels.size()))
return (feature_banks, feature_labels, classes)
# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
# compute cos similarity between each feature vector and feature bank ---> [B, N]
sim_matrix = torch.mm(feature, feature_bank)
# [B, K]
sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
# [B, K]
sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
dim=-1,
index=sim_indices)
sim_weight = (sim_weight / knn_t).exp()
# counts for each class
one_hot_label = torch.zeros(feature.size(0) * knn_k,
classes,
device=sim_labels.device)
# [B*K, C]
one_hot_label = one_hot_label.scatter(dim=-1,
index=sim_labels.view(-1, 1),
value=1.0)
# weighted score ---> [B, C]
pred_scores = torch.sum(
one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
dim=1)
pred_labels = pred_scores.argsort(dim=-1, descending=True)
return pred_labels
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from megatron.model.vision.utils import DropPath, trunc_normal_
from megatron.model import LayerNorm
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class MixVisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.output_avg = output_avg
# patch_embed
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
self.norm1 = norm_layer(embed_dims[0])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
self.norm2 = norm_layer(embed_dims[1])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
self.norm3 = norm_layer(embed_dims[2])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm4 = norm_layer(embed_dims[3])
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
def forward_features(self, x):
B = x.shape[0]
outs = []
# stage 1
x, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 2
x, H, W = self.patch_embed2(x)
for i, blk in enumerate(self.block2):
x = blk(x, H, W)
x = self.norm2(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 3
x, H, W = self.patch_embed3(x)
for i, blk in enumerate(self.block3):
x = blk(x, H, W)
x = self.norm3(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 4
x, H, W = self.patch_embed4(x)
for i, blk in enumerate(self.block4):
x = blk(x, H, W)
x = self.norm4(x)
if not self.output_avg:
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
def forward(self, x):
x = self.forward_features(x)
if self.output_avg:
x = x[3].mean(dim=1)
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class mit_b0(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b0, self).__init__(
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b1(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b1, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b2(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b2, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b3, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b3_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
class mit_b4(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b4, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b5, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b5_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Swin Transformer
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from math import sqrt
from megatron import get_args
from functools import partial
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = input_resolution[0]
self.W = input_resolution[1]
self.attn_mask_dict = {}
def create_attn_mask(self, H, W):
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x):
B, L, C = x.shape
H = int(sqrt(L))
W = H
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
x_b4_ds = x
if self.downsample is not None:
x = self.downsample(x)
return x_b4_ds, x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True,
use_checkpoint=False, output_avg=False, **kwargs):
super().__init__()
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
self.img_size = to_2tuple(img_size)
self.patch_size = to_2tuple(patch_size)
self.output_avg = output_avg
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
h = self.img_size[0] // self.patch_size[0]
w = self.img_size[1] // self.patch_size[1]
outs = []
for i, layer in enumerate(self.layers):
px, x = layer(x)
b, n, c = px.shape
if i != len(self.layers) - 1 or not self.output_avg:
px = px.permute(0, 2, 1).contiguous()
px = px.reshape(b, c, h, w)
# is this a fair assumption ?? i think it's baked into the architecture
h, w = h//2, w//2
outs.append(px)
if self.output_avg:
return outs[-1].mean(dim=1)
return outs
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops
def get_swin(drop_path_rate=0.3, output_avg=False):
args = get_args()
window_size = 7
embed_dim = 128
depths = [2, 2, 18, 2]
num_heads = [4, 8, 16, 32]
swin = SwinTransformer(
img_size=(args.img_h, args.img_w,),
in_chans=3,
patch_size=args.patch_dim,
embed_dim=embed_dim,
depths=depths,
num_heads=num_heads,
window_size=window_size,
drop_path_rate=drop_path_rate,
output_avg=output_avg,
)
return swin
import warnings
import math
from itertools import repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
def resize(input,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
warning=True):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
if output_h > input_h or output_w > output_h:
if ((output_h > 1 and output_w > 1 and input_h > 1
and input_w > 1) and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)):
warnings.warn(
f'When align_corners={align_corners}, '
'the output would more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')
if isinstance(size, torch.Size):
size = tuple(int(x) for x in size)
return F.interpolate(input, size, scale_factor, mode, align_corners)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
......@@ -51,7 +51,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
def print_datetime(string):
......@@ -465,11 +465,23 @@ def train_step(forward_step_func, data_iterator,
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
if args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
if args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
......@@ -702,6 +714,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
......@@ -791,6 +804,9 @@ def evaluate(forward_step_func,
"""Evaluation."""
args = get_args()
if args.vision_pretraining_type == "contrast":
args.knn_features = compute_feature_bank(model)
# Turn on evaluation mode which disables dropout.
for model_module in model:
model_module.eval()
......
......@@ -22,20 +22,31 @@ from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building VIT model ...")
args = get_args()
model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
if args.vision_backbone_type == 'vit':
model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
......@@ -46,6 +57,7 @@ def get_batch(data_iterator):
return images, labels
def loss_func(labels, output_tensor):
logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, labels)
......@@ -58,6 +70,7 @@ def loss_func(labels, output_tensor):
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.contrastive import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building VIT model ...")
return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
if isinstance(data[0], list):
images = [aug.cuda() for aug in data[0]]
else:
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(model, labels, output_tensor, collect_data=False):
args = get_args()
model = unwrap_model(
model,
(torchDDP, LocalDDP, Float16Module)
)
if model.training:
student_output, teacher_output = output_tensor
loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"loss": averaged_loss[0]}
else:
_, teacher_feature = output_tensor
feature_bank, feature_labels, classes = args.knn_features
feature = F.normalize(teacher_feature.float(), dim=1)
knn_accs = []
for k in [10, 20, 100, 200]:
pred_labels = knn_predict(feature, feature_bank,
feature_labels, classes, k, 0.07)
knn_acc = (pred_labels[:, 0] == labels).float().mean()
knn_accs.append(knn_acc)
averaged_loss = average_losses_across_data_parallel_group(knn_accs)
return 0, {"knn_acc_10": averaged_loss[0],
"knn_acc_20": averaged_loss[1],
"knn_acc_100": averaged_loss[2],
"knn_acc_200": averaged_loss[3]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
return model(images), partial(loss_func, model, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
args_defaults={'dataloader_type': 'cyclic'}
)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.inpainting import VitInpaintingModel
from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
model = VitInpaintingModel(pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitInpaintingModel(pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0][0].cuda()
masks = data[0][1].cuda()
return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
outputs = outputs.contiguous().float()
masks_flip = 1-masks
flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
flip_masked_images = images.masked_fill(masks_flip.bool(), 0)
ssim_fun = SSIM()
psnr_fun = PSNR()
if not collect_data:
mask_count = torch.count_nonzero(masks)
loss = F.mse_loss(
flip_masked_outputs,
flip_masked_images.float(),
reduction="sum"
)
loss = loss/mask_count
ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())
averaged_loss = average_losses_across_data_parallel_group(
[loss, psnr, ssim]
)
return loss, {"loss": averaged_loss[0],
"psnr": averaged_loss[1],
'ssim': averaged_loss[2]}
else:
synth_images = masked_images.float() + flip_masked_outputs
ssim = ssim_fun(synth_images, images.float())
psnr = psnr_fun(synth_images, images.float())
return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
masks,
) = get_batch(data_iterator)
timers("batch-generator").stop()
masked_images = images.masked_fill(masks.bool(), 0)
outputs = model(masked_images)
# Forward mode
return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
psnr_sum = 0
ssim_sum = 0
for (output_tb, ssim, psnr) in data:
output_tb[output_tb < 0] = 0
output_tb[output_tb > 1] = 1
writer.add_images("gt-input-output-vald", output_tb,
global_step=iteration, walltime=None,
dataformats='NCHW')
psnr_sum = psnr_sum + psnr.item()
ssim_sum = ssim_sum + ssim.item()
psnr = psnr_sum/len(data)
ssim = ssim_sum/len(data)
writer.add_scalar('PSNR generate value-validation', psnr, iteration)
writer.add_scalar('SSIM generate value-validation', ssim, iteration)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
process_non_loss_data,
args_defaults={'dataloader_type': 'cyclic'}
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment