"vscode:/vscode.git/clone" did not exist on "8088cc94f2155403f6b09cd54edadafa68daa977"
Commit 48c2a144 authored by Vijay Korthikanti's avatar Vijay Korthikanti Committed by Sangkug Lym
Browse files

vision third phase merge: pretraining methods + mit,swin backbones

parent f00d0a3f
......@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
......@@ -856,9 +856,10 @@ def _add_biencoder_args(parser):
return parser
def _add_vit_args(parser):
group = parser.add_argument_group(title="vit")
def _add_vision_args(parser):
group = parser.add_argument_group(title="vision")
# general vision arguements
group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classificaiton task')
group.add_argument('--img-h', type=int, default=224,
......@@ -868,7 +869,7 @@ def _add_vit_args(parser):
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit')
help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
......@@ -876,5 +877,49 @@ def _add_vit_args(parser):
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection`
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'contrast'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone types types')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='pretraining objectives')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size for vision classification task')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottle neck dimension in dino head ')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warump teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperaure epochs')
return parser
......@@ -22,6 +22,43 @@ from megatron import get_args
from megatron.data.image_folder import ImageFolder
from megatron.data.autoaugment import ImageNetPolicy
from megatron.data.data_samplers import RandomSeedDataset
from PIL import Image, ImageFilter, ImageOps
class GaussianBlur(object):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
self.prob = p
self.radius_min = radius_min
self.radius_max = radius_max
def __call__(self, img):
do_it = random.random() <= self.prob
if not do_it:
return img
return img.filter(
ImageFilter.GaussianBlur(
radius=random.uniform(self.radius_min, self.radius_max)
)
)
class Solarization(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p):
self.p = p
def __call__(self, img):
if random.random() < self.p:
return ImageOps.solarize(img)
else:
return img
class ClassificationTransform():
def __init__(self, image_size, train=True):
......@@ -52,14 +89,169 @@ class ClassificationTransform():
return output
class InpaintingTransform():
def __init__(self, image_size, train=True):
args = get_args()
self.mask_factor = args.mask_factor
self.mask_type = args.mask_type
self.image_size = image_size
self.patch_size = args.patch_dim
self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size))
self.train = train
assert args.fp16 or args.bf16
self.data_type = torch.half if args.fp16 else torch.bfloat16
if self.train:
self.transform = T.Compose([
T.RandomResizedCrop(self.image_size),
T.RandomHorizontalFlip(),
T.ColorJitter(0.4, 0.4, 0.4, 0.1),
ImageNetPolicy(),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
else:
self.transform = T.Compose([
T.Resize(self.image_size, interpolation=2),
T.CenterCrop(self.image_size),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
def gen_mask(self, image_size, mask_size, mask_type, patch_size):
# output: mask as a list with indices for missing patches
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
assert image_size[0] == image_size[1]
img_size_patch = image_size[0] // patch_size
# drop masked patches
mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)
if mask_type == 'random':
x = torch.randint(0, img_size_patch, ())
y = torch.randint(0, img_size_patch, ())
for i in range(mask_size):
r = torch.randint(0, len(action_list), ())
x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
else:
assert mask_type == 'row'
count = 0
for x in reversed(range(img_size_patch)):
for y in reversed(range(img_size_patch)):
if (count < mask_size):
count += 1
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
return mask
def __call__(self, input):
trans_input = self.transform(input)
mask = self.gen_mask(self.image_size, self.mask_size,
self.mask_type, self.patch_size)
mask = mask.unsqueeze(dim=0)
return trans_input, mask
class DinoTransform(object):
def __init__(self, image_size, train=True):
args = get_args()
self.data_type = torch.half if args.fp16 else torch.bfloat16
flip_and_color_jitter = T.Compose([
T.RandomHorizontalFlip(p=0.5),
T.RandomApply(
[T.ColorJitter(brightness=0.4, contrast=0.4,
saturation=0.2, hue=0.1)],
p=0.8
),
T.RandomGrayscale(p=0.2),
])
if args.fp16 or args.bf16:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
else:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# first global crop
scale_const = 0.4
self.global_transform1 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(1.0),
normalize
])
# second global crop
self.global_transform2 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(0.1),
Solarization(0.2),
normalize
])
# transformation for the local small crops
self.local_crops_number = args.local_crops_number
self.local_transform = T.Compose([
T.RandomResizedCrop(args.local_img_size,
scale=(0.05, scale_const),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(p=0.5),
normalize
])
def __call__(self, image):
crops = []
args = get_args()
if args.street_data:
crop_transform = T.RandomCrop(300)
image = crop_transform(image)
crops.append(self.global_transform1(image))
crops.append(self.global_transform2(image))
for _ in range(self.local_crops_number):
crops.append(self.local_transform(image))
return crops
def build_train_valid_datasets(data_path, image_size=224):
args = get_args()
if args.vision_pretraining_type == 'classify':
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False)
elif args.vision_pretraining_type == 'inpaint':
train_transform = InpaintingTransform(image_size, train=False)
val_transform = InpaintingTransform(image_size, train=False)
elif args.vision_pretraining_type == 'dino':
train_transform = DinoTransform(image_size, train=True)
val_transform = ClassificationTransform(image_size, train=False)
else:
raise Exception('{} vit pretraining type is not supported.'.format(
args.vit_pretraining_type))
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False)
# training dataset
train_data_path = data_path[0]
train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2] #TODO VIJAY
train_data = ImageFolder(
root=train_data_path,
transform=train_transform,
......
......@@ -19,6 +19,8 @@ import torch
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3_avg
from megatron.model.vision.utils import trunc_normal_
from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule):
......@@ -61,3 +63,35 @@ class VitClassificationModel(MegatronModule):
hidden_states = self.head(hidden_states)
return hidden_states
class MitClassificationModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, num_classes
pre_process=True, post_process=True):
super(MitClassificationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.backbone = mit_b3_avg()
self.head = torch.nn.Linear(512, num_classes)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
hidden_states = self.backbone(input)
hidden_states = self.head(hidden_states)
return hidden_states
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.utils import print_tensor_min_max_norm as pt
from megatron.model.vision.utils import trunc_normal_
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin
from megatron.model.vision.av_cam_trunk import get_av_cam_trunk
class DINOLoss(torch.nn.Module):
def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
center_momentum=0.9):
super().__init__()
self.student_temp = student_temp
self.center_momentum = center_momentum
self.ncrops = ncrops
self.register_buffer("center", torch.zeros(1, out_dim))
# we apply a warm up for the teacher temperature because
# a too high temperature makes the training instable at the beginning
self.teacher_temp_schedule = np.concatenate((
np.linspace(warmup_teacher_temp,
teacher_temp, warmup_teacher_temp_epochs),
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
))
self.teacher_temp = teacher_temp
def forward(self, student_output, teacher_output, iteration):
"""
Cross-entropy between softmax outputs of the teacher
and student network.
"""
args = get_args()
student_out = student_output / self.student_temp
student_out = student_out.chunk(self.ncrops)
epoch = iteration // args.iter_per_epoch
# teacher centering and sharpening
temp = self.teacher_temp_schedule[epoch]
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
teacher_out = teacher_out.detach().chunk(2)
total_loss = 0
n_loss_terms = 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
# we skip cases where student and teacher operate on the same view
continue
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
total_loss += loss.mean()
n_loss_terms += 1
total_loss /= n_loss_terms
self.update_center(teacher_output)
return total_loss
@torch.no_grad()
def update_center(self, teacher_output):
"""
Update center used for teacher output.
"""
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
torch.distributed.all_reduce(batch_center)
batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
class DINOHead(torch.nn.Module):
def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
super().__init__()
args = get_args()
hidden_dim = args.dino_head_hidden_size
bottleneck_dim = args.dino_bottleneck_size
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
else:
layers = [torch.nn.Linear(in_dim, hidden_dim)]
layers.append(torch.nn.GELU())
for _ in range(nlayers - 2):
layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
layers.append(torch.nn.GELU())
layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = torch.nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = torch.nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x
class MultiCropWrapper(MegatronModule):
"""
Perform forward pass separately on each resolution input.
The inputs corresponding to a single resolution are clubbed and single
forward is run on the same resolution inputs. Hence we do several
forward passes = number of different resolutions used. We then
concatenate all the output features and run the head forward on these
concatenated features.
"""
def __init__(self, backbone, head):
super(MultiCropWrapper, self).__init__()
# disable layers dedicated to ImageNet labels classification
#backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
self.backbone = backbone
self.head = head
def forward(self, x):
# convert to list
if not isinstance(x, list):
x = [x]
idx_crops = torch.cumsum(torch.unique_consecutive(
torch.tensor([inp.shape[-1] for inp in x]),
return_counts=True,
)[1], 0)
start_idx = 0
for end_idx in idx_crops:
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
if start_idx == 0:
output = _out
else:
output = torch.cat((output, _out))
start_idx = end_idx
# Run the head forward on the concatenated features.
if self.training:
return self.head(output)
else:
return output
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
warmup_epochs=0, start_warmup_value=0):
warmup_schedule = np.array([])
warmup_iters = warmup_epochs * niter_per_ep
if warmup_epochs > 0:
warmup_schedule = \
np.linspace(start_warmup_value, base_value, warmup_iters)
iters = np.arange(epochs * niter_per_ep - warmup_iters)
schedule = final_value + 0.5 * (base_value - final_value) \
* (1 + np.cos(np.pi * iters / len(iters)))
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == epochs * niter_per_ep
return schedule
def get_student_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
student = VitBackbone(pre_process=pre_process,
post_process=post_process,
drop_path_rate=0.1,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
student = mit_b5_avg(drop_path_rate=0.1)
num_features = 512
elif args.vision_backbone_type == 'swin':
student = get_swin()
num_features = student.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return student, num_features
def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
teacher = VitBackbone(pre_process=pre_process,
post_process=post_process,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
teacher = mit_b5_avg(drop_path_rate=0.0)
num_features = 512
elif args.vision_backbone_type == 'swin':
teacher = get_swin(is_teacher=True)
num_features = teacher.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return teacher, num_features
class DINOPretrainModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(DINOPretrainModel, self).__init__()
args = get_args()
self.out_dim = 65536
self.dino_loss = DINOLoss(
self.out_dim,
args.dino_local_crops_number + 2,
args.dino_warmup_teacher_temp,
args.dino_teacher_temp,
args.dino_warmup_teacher_temp_epochs,
300,
)
self.pre_process = pre_process
self.post_process = post_process
self.momentum_teacher = 0.996
student_backbone, num_features = \
get_student_backbone_and_num_features(pre_process, post_process)
self.student = MultiCropWrapper(
student_backbone,
DINOHead(num_features, self.out_dim,
norm_last_layer=args.dino_norm_last_layer)
)
self.momentum_schedule = cosine_scheduler(
self.momentum_teacher, 1,
args.train_iters // args.iter_per_epoch,
args.iter_per_epoch
)
teacher_backbone, num_features = \
get_teacher_backbone_and_num_features(pre_process, post_process)
self.teacher = MultiCropWrapper(
teacher_backbone,
DINOHead(num_features, self.out_dim)
)
self.teacher.load_state_dict(self.student.state_dict())
for p in self.teacher.parameters():
if hasattr(p, "requires_grad") and p.requires_grad is not None:
p.requires_grad = False
def set_input_tensor(self, tensor):
pass
def forward(self, input):
student_output = None
if self.training:
student_output = self.student(input)
teacher_output = self.teacher(input[:2])
else:
teacher_output = self.teacher(input)
return student_output, teacher_output
def cancel_gradients_last_layer(self, iteration):
args = get_args()
epoch = iteration // args.iter_per_epoch
if epoch < args.dino_freeze_last_layer:
for n, p in self.student.named_parameters():
if "last_layer" in n:
p.grad = None
def update_momentum(self, iteration):
with torch.no_grad():
m = self.momentum_schedule[iteration]
for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
This diff is collapsed.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b3
from megatron.model.vision.utils import resize, trunc_normal_
class VitInpaintingModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(VitInpaintingModel, self).__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.backbone = VitBackbone(
pre_process=self.pre_process,
post_process=self.post_process,
class_token=False,
)
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.seq_length = args.seq_length
# full mask
if self.post_process:
self.linear_decoder = get_linear_layer(
self.hidden_size,
self.backbone.flatten_dim,
torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
self.backbone.set_input_tensor(input_tensor)
def forward(self, input):
hidden_states = self.backbone(input)
if not self.post_process:
return hidden_states
decoded_output = self.linear_decoder(hidden_states)
output = einops.rearrange(
decoded_output,
"b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
class MLP(torch.nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = torch.nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class MitInpaintingModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, pre_process=True, post_process=True):
super(MitInpaintingModel, self).__init__()
self.pre_process = pre_process
self.post_process = post_process
args = get_args()
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.flatten_dim = self.patch_dim * self.patch_dim * 3
self.backbone = mit_b3()
self.in_channels = [64, 128, 320, 512]
self.embedding_dim = 768
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False)
self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
self.dropout = torch.nn.Dropout2d(0.1)
self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
c1, c2, c3, c4 = self.backbone(input)
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
_c = self.conv_fuse(_c)
x = self.norm(_c)
x = F.relu(x, inplace=True)
x = self.dropout(x)
x = self.linear_pred(x)
output = einops.rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder
def build_data_loader(dataset, drop_last=True, shuffle=False):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
args = get_args()
micro_batch_size = 16
num_workers = args.num_workers
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank,
drop_last=drop_last, shuffle=shuffle
)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=not drop_last,
pin_memory=True,
)
return data_loader
def compute_feature_bank(model):
args = get_args()
feature_bank = []
feature_label = []
train_ds = ImageFolder(
root=args.data_path[0],
transform=ClassificationTransform((args.img_h, args.img_w), train=False),
data_per_class_fraction=1.0
)
classes = len(train_ds.classes)
dataloader = build_data_loader(train_ds)
for m in model:
m.eval()
with torch.no_grad():
for i, batch in enumerate(dataloader):
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
student_feature, teacher_feature = model[0](images)
feature = F.normalize(teacher_feature.float(), dim=1)
feature_bank.append(feature)
feature_label.append(labels)
for m in model:
m.train()
# [N', D]
feature_bank = torch.cat(feature_bank, dim=0).contiguous()
feature_label = torch.cat(feature_label, dim=0).contiguous()
feature_banks = [torch.zeros_like(feature_bank)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_banks,
feature_bank,
group=mpu.get_data_parallel_group())
assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
feature_bank))
feature_labels = [torch.zeros_like(feature_label)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_labels,
feature_label,
group=mpu.get_data_parallel_group())
# [D, N]
feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
# [N]
feature_labels = torch.cat(feature_labels, dim=0).contiguous()
print_rank_0("feature_banks size is {}".format(feature_banks.size()))
print_rank_0("feature labels size is {}".format(feature_labels.size()))
return (feature_banks, feature_labels, classes)
# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
# compute cos similarity between each feature vector and feature bank ---> [B, N]
sim_matrix = torch.mm(feature, feature_bank)
# [B, K]
sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
# [B, K]
sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
dim=-1,
index=sim_indices)
sim_weight = (sim_weight / knn_t).exp()
# counts for each class
one_hot_label = torch.zeros(feature.size(0) * knn_k,
classes,
device=sim_labels.device)
# [B*K, C]
one_hot_label = one_hot_label.scatter(dim=-1,
index=sim_labels.view(-1, 1),
value=1.0)
# weighted score ---> [B, C]
pred_scores = torch.sum(
one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
dim=1)
pred_labels = pred_scores.argsort(dim=-1, descending=True)
return pred_labels
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from megatron.model.vision.utils import DropPath, trunc_normal_
from megatron.model import LayerNorm
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class MixVisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.output_avg = output_avg
# patch_embed
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
self.norm1 = norm_layer(embed_dims[0])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
self.norm2 = norm_layer(embed_dims[1])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
self.norm3 = norm_layer(embed_dims[2])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm4 = norm_layer(embed_dims[3])
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
def forward_features(self, x):
B = x.shape[0]
outs = []
# stage 1
x, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 2
x, H, W = self.patch_embed2(x)
for i, blk in enumerate(self.block2):
x = blk(x, H, W)
x = self.norm2(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 3
x, H, W = self.patch_embed3(x)
for i, blk in enumerate(self.block3):
x = blk(x, H, W)
x = self.norm3(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 4
x, H, W = self.patch_embed4(x)
for i, blk in enumerate(self.block4):
x = blk(x, H, W)
x = self.norm4(x)
if not self.output_avg:
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
def forward(self, x):
x = self.forward_features(x)
if self.output_avg:
x = x[3].mean(dim=1)
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class mit_b0(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b0, self).__init__(
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b1(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b1, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b2(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b2, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b3, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b3_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
class mit_b4(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b4, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b5, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b5_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
This diff is collapsed.
import warnings
import math
from itertools import repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
def resize(input,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
warning=True):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
if output_h > input_h or output_w > output_h:
if ((output_h > 1 and output_w > 1 and input_h > 1
and input_w > 1) and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)):
warnings.warn(
f'When align_corners={align_corners}, '
'the output would more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')
if isinstance(size, torch.Size):
size = tuple(int(x) for x in size)
return F.interpolate(input, size, scale_factor, mode, align_corners)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
......@@ -51,7 +51,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
def print_datetime(string):
......@@ -465,11 +465,23 @@ def train_step(forward_step_func, data_iterator,
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
if args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
if args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
......@@ -702,6 +714,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
......@@ -791,6 +804,9 @@ def evaluate(forward_step_func,
"""Evaluation."""
args = get_args()
if args.vision_pretraining_type == "contrast":
args.knn_features = compute_feature_bank(model)
# Turn on evaluation mode which disables dropout.
for model_module in model:
model_module.eval()
......
......@@ -22,20 +22,31 @@ from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building VIT model ...")
args = get_args()
model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
if args.vision_backbone_type == 'vit':
model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
......@@ -46,6 +57,7 @@ def get_batch(data_iterator):
return images, labels
def loss_func(labels, output_tensor):
logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, labels)
......@@ -58,6 +70,7 @@ def loss_func(labels, output_tensor):
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.contrastive import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building VIT model ...")
return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
if isinstance(data[0], list):
images = [aug.cuda() for aug in data[0]]
else:
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(model, labels, output_tensor, collect_data=False):
args = get_args()
model = unwrap_model(
model,
(torchDDP, LocalDDP, Float16Module)
)
if model.training:
student_output, teacher_output = output_tensor
loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"loss": averaged_loss[0]}
else:
_, teacher_feature = output_tensor
feature_bank, feature_labels, classes = args.knn_features
feature = F.normalize(teacher_feature.float(), dim=1)
knn_accs = []
for k in [10, 20, 100, 200]:
pred_labels = knn_predict(feature, feature_bank,
feature_labels, classes, k, 0.07)
knn_acc = (pred_labels[:, 0] == labels).float().mean()
knn_accs.append(knn_acc)
averaged_loss = average_losses_across_data_parallel_group(knn_accs)
return 0, {"knn_acc_10": averaged_loss[0],
"knn_acc_20": averaged_loss[1],
"knn_acc_100": averaged_loss[2],
"knn_acc_200": averaged_loss[3]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
return model(images), partial(loss_func, model, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
args_defaults={'dataloader_type': 'cyclic'}
)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.inpainting import VitInpaintingModel
from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
model = VitInpaintingModel(pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitInpaintingModel(pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0][0].cuda()
masks = data[0][1].cuda()
return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
outputs = outputs.contiguous().float()
masks_flip = 1-masks
flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
flip_masked_images = images.masked_fill(masks_flip.bool(), 0)
ssim_fun = SSIM()
psnr_fun = PSNR()
if not collect_data:
mask_count = torch.count_nonzero(masks)
loss = F.mse_loss(
flip_masked_outputs,
flip_masked_images.float(),
reduction="sum"
)
loss = loss/mask_count
ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())
averaged_loss = average_losses_across_data_parallel_group(
[loss, psnr, ssim]
)
return loss, {"loss": averaged_loss[0],
"psnr": averaged_loss[1],
'ssim': averaged_loss[2]}
else:
synth_images = masked_images.float() + flip_masked_outputs
ssim = ssim_fun(synth_images, images.float())
psnr = psnr_fun(synth_images, images.float())
return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
masks,
) = get_batch(data_iterator)
timers("batch-generator").stop()
masked_images = images.masked_fill(masks.bool(), 0)
outputs = model(masked_images)
# Forward mode
return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
psnr_sum = 0
ssim_sum = 0
for (output_tb, ssim, psnr) in data:
output_tb[output_tb < 0] = 0
output_tb[output_tb > 1] = 1
writer.add_images("gt-input-output-vald", output_tb,
global_step=iteration, walltime=None,
dataformats='NCHW')
psnr_sum = psnr_sum + psnr.item()
ssim_sum = ssim_sum + ssim.item()
psnr = psnr_sum/len(data)
ssim = ssim_sum/len(data)
writer.add_scalar('PSNR generate value-validation', psnr, iteration)
writer.add_scalar('SSIM generate value-validation', ssim, iteration)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
forward_step,
process_non_loss_data,
args_defaults={'dataloader_type': 'cyclic'}
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment