Commit 8c8063eb authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'vision-merge' into 'main'

vision third phase merge: pretraining methods + mit,swin backbones

See merge request ADLR/megatron-lm!384
parents f00d0a3f 4554c3fe
...@@ -29,12 +29,13 @@ The following applies to all files unless otherwise noted: ...@@ -29,12 +29,13 @@ The following applies to all files unless otherwise noted:
-- --
This repository also contains code from Hugging Face Inc., Google Research, This repository also contains code from Hugging Face Inc., Google Research,
Facebook (from their Fairseq project), and Philip Popien. Files from these Facebook (from their Fairseq and Dino projects), Microsoft(from their
organizations have notices at the top of each file. Below are licenses Swin-Transformer project)and Philip Popien. Files from these
used in those files, as indicated. organizations have notices at the top of each file. Below are
licenses used in those files, as indicated.
------------- LICENSE FOR huggingface and Google Research code -------------- ------------- LICENSE FOR Facebook, huggingface and Google Research code --------------
Apache License Apache License
...@@ -263,3 +264,113 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, ...@@ -263,3 +264,113 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
------------- LICENSE FOR Mircrosoft Swin transformer code --------------
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
--------------- NVIDIA Source Code License for SegFormer -----------------
1. Definitions
“Licensor” means any person or entity that distributes its Work.
“Software” means the original work of authorship made available under this
License.
“Work” means the Software and any additions to or derivative works of the
Software that are made available under this License.
The terms “reproduce,” “reproduction,” “derivative works,” and
“distribution” have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative works
shall not include works that remain separable from, or merely link
(or bind by name) to the interfaces of, the Work.
Works, including the Software, are “made available” under this License by
including in or with the Work either (a) a copyright notice referencing
the applicability of this License to the Work, or (b) a copy of this License.
2. License Grant
2.1 Copyright Grant. Subject to the terms and conditions of this License,
each Licensor grants to you a perpetual, worldwide, non-exclusive,
royalty-free, copyright license to reproduce, prepare derivative works of,
publicly display, publicly perform, sublicense and distribute its Work
and any resulting derivative works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only if
(a) you do so under this License, (b) you include a complete copy of this
License with your distribution, and (c) you retain without modification any
copyright, patent, trademark, or attribution notices that are present
in the Work.
3.2 Derivative Works. You may specify that additional or different terms
apply to the use, reproduction, and distribution of your derivative works
of the Work (“Your Terms”) only if (a) Your Terms provide that the use
limitation in Section 3.3 applies to your derivative works, and (b) you
identify the specific derivative works that are subject to Your Terms.
Notwithstanding Your Terms, this License (including the redistribution
requirements in Section 3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only may
be used or intended for use non-commercially. Notwithstanding the
foregoing, NVIDIA and its affiliates may use the Work and any derivative
works commercially. As used herein, “non-commercially” means for research
or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim against
any Licensor (including any claim, cross-claim or counterclaim in a lawsuit)
to enforce any patents that you allege are infringed by any Work, then
your rights under this License from such Licensor (including the grant
in Section 2.1) will terminate immediately.
3.5 Trademarks. This License does not grant any rights to use any Licensor’s
or its affiliates’ names, logos, or trademarks, except as necessary to
reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your rights
under this License (including the grant in Section 2.1) will terminate
immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT
OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
...@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser) parser = _add_data_args(parser)
parser = _add_autoresume_args(parser) parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser) parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser) parser = _add_vision_args(parser)
parser = _add_logging_args(parser) parser = _add_logging_args(parser)
parser = _add_inference_args(parser) parser = _add_inference_args(parser)
...@@ -856,9 +856,10 @@ def _add_biencoder_args(parser): ...@@ -856,9 +856,10 @@ def _add_biencoder_args(parser):
return parser return parser
def _add_vit_args(parser): def _add_vision_args(parser):
group = parser.add_argument_group(title="vit") group = parser.add_argument_group(title="vision")
# general vision arguements
group.add_argument('--num-classes', type=int, default=1000, group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classificaiton task') help='num of classes in vision classificaiton task')
group.add_argument('--img-h', type=int, default=224, group.add_argument('--img-h', type=int, default=224,
...@@ -868,7 +869,7 @@ def _add_vit_args(parser): ...@@ -868,7 +869,7 @@ def _add_vit_args(parser):
group.add_argument('--num-channels', type=int, default=3, group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data') help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16, group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension used in vit') help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0, group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.') help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0, group.add_argument('--data-per-class-fraction', type=float, default=1.0,
...@@ -876,5 +877,49 @@ def _add_vit_args(parser): ...@@ -876,5 +877,49 @@ def _add_vit_args(parser):
group.add_argument('--no-data-sharding', action='store_false', group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.', help='Disable data sharding.',
dest='data_sharding') dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection`
group.add_argument('--vision-pretraining', action='store_true',
help='flag to indicate vision pretraining')
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'dino'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone types types')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='pretraining objectives')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size for vision classification task')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottle neck dimension in dino head ')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warump teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperaure epochs')
return parser return parser
...@@ -22,6 +22,43 @@ from megatron import get_args ...@@ -22,6 +22,43 @@ from megatron import get_args
from megatron.data.image_folder import ImageFolder from megatron.data.image_folder import ImageFolder
from megatron.data.autoaugment import ImageNetPolicy from megatron.data.autoaugment import ImageNetPolicy
from megatron.data.data_samplers import RandomSeedDataset from megatron.data.data_samplers import RandomSeedDataset
from PIL import Image, ImageFilter, ImageOps
class GaussianBlur(object):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
self.prob = p
self.radius_min = radius_min
self.radius_max = radius_max
def __call__(self, img):
do_it = random.random() <= self.prob
if not do_it:
return img
return img.filter(
ImageFilter.GaussianBlur(
radius=random.uniform(self.radius_min, self.radius_max)
)
)
class Solarization(object):
"""
Apply Solarization to the PIL image.
"""
def __init__(self, p):
self.p = p
def __call__(self, img):
if random.random() < self.p:
return ImageOps.solarize(img)
else:
return img
class ClassificationTransform(): class ClassificationTransform():
def __init__(self, image_size, train=True): def __init__(self, image_size, train=True):
...@@ -52,14 +89,160 @@ class ClassificationTransform(): ...@@ -52,14 +89,160 @@ class ClassificationTransform():
return output return output
class InpaintingTransform():
def __init__(self, image_size, train=True):
args = get_args()
self.mask_factor = args.mask_factor
self.mask_type = args.mask_type
self.image_size = image_size
self.patch_size = args.patch_dim
self.mask_size = int(self.mask_factor*(image_size[0]/self.patch_size)*(image_size[1]/self.patch_size))
self.train = train
assert args.fp16 or args.bf16
self.data_type = torch.half if args.fp16 else torch.bfloat16
if self.train:
self.transform = T.Compose([
T.RandomResizedCrop(self.image_size),
T.RandomHorizontalFlip(),
T.ColorJitter(0.4, 0.4, 0.4, 0.1),
ImageNetPolicy(),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
else:
self.transform = T.Compose([
T.Resize(self.image_size, interpolation=2),
T.CenterCrop(self.image_size),
T.ToTensor(),
T.ConvertImageDtype(self.data_type)
])
def gen_mask(self, image_size, mask_size, mask_type, patch_size):
# output: mask as a list with indices for missing patches
action_list = [[0, 1], [0, -1], [1, 0], [-1, 0]]
assert image_size[0] == image_size[1]
img_size_patch = image_size[0] // patch_size
# drop masked patches
mask = torch.zeros((image_size[0], image_size[1]), dtype=torch.float)
if mask_type == 'random':
x = torch.randint(0, img_size_patch, ())
y = torch.randint(0, img_size_patch, ())
for i in range(mask_size):
r = torch.randint(0, len(action_list), ())
x = torch.clamp(x + action_list[r][0], min=0, max=img_size_patch - 1)
y = torch.clamp(y + action_list[r][1], min=0, max=img_size_patch - 1)
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
else:
assert mask_type == 'row'
count = 0
for x in reversed(range(img_size_patch)):
for y in reversed(range(img_size_patch)):
if (count < mask_size):
count += 1
x_offset = x * patch_size
y_offset = y * patch_size
mask[x_offset:x_offset+patch_size, y_offset:y_offset+patch_size] = 1
return mask
def __call__(self, input):
trans_input = self.transform(input)
mask = self.gen_mask(self.image_size, self.mask_size,
self.mask_type, self.patch_size)
mask = mask.unsqueeze(dim=0)
return trans_input, mask
class DinoTransform(object):
def __init__(self, image_size, train=True):
args = get_args()
self.data_type = torch.half if args.fp16 else torch.bfloat16
flip_and_color_jitter = T.Compose([
T.RandomHorizontalFlip(p=0.5),
T.RandomApply(
[T.ColorJitter(brightness=0.4, contrast=0.4,
saturation=0.2, hue=0.1)],
p=0.8
),
T.RandomGrayscale(p=0.2),
])
if args.fp16 or args.bf16:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
else:
normalize = T.Compose([
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# first global crop
scale_const = 0.4
self.global_transform1 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(1.0),
normalize
])
# second global crop
self.global_transform2 = T.Compose([
T.RandomResizedCrop(image_size,
scale=(scale_const, 1),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(0.1),
Solarization(0.2),
normalize
])
# transformation for the local small crops
self.local_crops_number = args.dino_local_crops_number
self.local_transform = T.Compose([
T.RandomResizedCrop(args.dino_local_img_size,
scale=(0.05, scale_const),
interpolation=Image.BICUBIC),
flip_and_color_jitter,
GaussianBlur(p=0.5),
normalize
])
def __call__(self, image):
crops = []
crops.append(self.global_transform1(image))
crops.append(self.global_transform2(image))
for _ in range(self.local_crops_number):
crops.append(self.local_transform(image))
return crops
def build_train_valid_datasets(data_path, image_size=224): def build_train_valid_datasets(data_path, image_size=224):
args = get_args() args = get_args()
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False) if args.vision_pretraining_type == 'classify':
train_transform = ClassificationTransform(image_size)
val_transform = ClassificationTransform(image_size, train=False)
elif args.vision_pretraining_type == 'inpaint':
train_transform = InpaintingTransform(image_size, train=False)
val_transform = InpaintingTransform(image_size, train=False)
elif args.vision_pretraining_type == 'dino':
train_transform = DinoTransform(image_size, train=True)
val_transform = ClassificationTransform(image_size, train=False)
else:
raise Exception('{} vit pretraining type is not supported.'.format(
args.vit_pretraining_type))
# training dataset # training dataset
train_data_path = data_path[0] train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
train_data = ImageFolder( train_data = ImageFolder(
root=train_data_path, root=train_data_path,
transform=train_transform, transform=train_transform,
......
...@@ -16,9 +16,11 @@ ...@@ -16,9 +16,11 @@
"""Vision Transformer(VIT) model.""" """Vision Transformer(VIT) model."""
import torch import torch
from torch.nn.init import trunc_normal_
from megatron import get_args from megatron import get_args
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3_avg
from megatron.model.module import MegatronModule from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule): class VitClassificationModel(MegatronModule):
...@@ -61,3 +63,35 @@ class VitClassificationModel(MegatronModule): ...@@ -61,3 +63,35 @@ class VitClassificationModel(MegatronModule):
hidden_states = self.head(hidden_states) hidden_states = self.head(hidden_states)
return hidden_states return hidden_states
class MitClassificationModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, num_classes,
pre_process=True, post_process=True):
super(MitClassificationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.backbone = mit_b3_avg()
self.head = torch.nn.Linear(512, num_classes)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
hidden_states = self.backbone(input)
hidden_states = self.head(hidden_states)
return hidden_states
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin
class DINOLoss(torch.nn.Module):
def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
center_momentum=0.9):
super().__init__()
self.student_temp = student_temp
self.center_momentum = center_momentum
self.ncrops = ncrops
self.register_buffer("center", torch.zeros(1, out_dim))
# we apply a warm up for the teacher temperature because
# a too high temperature makes the training instable at the beginning
self.teacher_temp_schedule = np.concatenate((
np.linspace(warmup_teacher_temp,
teacher_temp, warmup_teacher_temp_epochs),
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
))
self.teacher_temp = teacher_temp
def forward(self, student_output, teacher_output, iteration):
"""
Cross-entropy between softmax outputs of the teacher
and student network.
"""
args = get_args()
student_out = student_output / self.student_temp
student_out = student_out.chunk(self.ncrops)
epoch = iteration // args.iter_per_epoch
# teacher centering and sharpening
temp = self.teacher_temp_schedule[epoch]
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
teacher_out = teacher_out.detach().chunk(2)
total_loss = 0
n_loss_terms = 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
# we skip cases where student and teacher operate on the same view
continue
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
total_loss += loss.mean()
n_loss_terms += 1
total_loss /= n_loss_terms
self.update_center(teacher_output)
return total_loss
@torch.no_grad()
def update_center(self, teacher_output):
"""
Update center used for teacher output.
"""
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
torch.distributed.all_reduce(batch_center)
batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
class DINOHead(torch.nn.Module):
def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
super().__init__()
args = get_args()
hidden_dim = args.dino_head_hidden_size
bottleneck_dim = args.dino_bottleneck_size
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
else:
layers = [torch.nn.Linear(in_dim, hidden_dim)]
layers.append(torch.nn.GELU())
for _ in range(nlayers - 2):
layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
layers.append(torch.nn.GELU())
layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = torch.nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = torch.nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x
class MultiCropWrapper(MegatronModule):
"""
Perform forward pass separately on each resolution input.
The inputs corresponding to a single resolution are clubbed and single
forward is run on the same resolution inputs. Hence we do several
forward passes = number of different resolutions used. We then
concatenate all the output features and run the head forward on these
concatenated features.
"""
def __init__(self, backbone, head):
super(MultiCropWrapper, self).__init__()
# disable layers dedicated to ImageNet labels classification
#backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
self.backbone = backbone
self.head = head
def forward(self, x):
# convert to list
if not isinstance(x, list):
x = [x]
idx_crops = torch.cumsum(torch.unique_consecutive(
torch.tensor([inp.shape[-1] for inp in x]),
return_counts=True,
)[1], 0)
start_idx = 0
for end_idx in idx_crops:
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
if start_idx == 0:
output = _out
else:
output = torch.cat((output, _out))
start_idx = end_idx
# Run the head forward on the concatenated features.
if self.training:
return self.head(output)
else:
return output
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
warmup_epochs=0, start_warmup_value=0):
warmup_schedule = np.array([])
warmup_iters = warmup_epochs * niter_per_ep
if warmup_epochs > 0:
warmup_schedule = \
np.linspace(start_warmup_value, base_value, warmup_iters)
iters = np.arange(epochs * niter_per_ep - warmup_iters)
schedule = final_value + 0.5 * (base_value - final_value) \
* (1 + np.cos(np.pi * iters / len(iters)))
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == epochs * niter_per_ep
return schedule
def get_student_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
student = VitBackbone(pre_process=pre_process,
post_process=post_process,
drop_path_rate=0.1,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
student = mit_b5_avg(drop_path_rate=0.1)
num_features = 512
elif args.vision_backbone_type == 'swin':
student = get_swin()
num_features = student.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return student, num_features
def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
teacher = VitBackbone(pre_process=pre_process,
post_process=post_process,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
teacher = mit_b5_avg(drop_path_rate=0.0)
num_features = 512
elif args.vision_backbone_type == 'swin':
teacher = get_swin(is_teacher=True)
num_features = teacher.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return teacher, num_features
class DINOPretrainModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(DINOPretrainModel, self).__init__()
args = get_args()
self.out_dim = 65536
self.dino_loss = DINOLoss(
self.out_dim,
args.dino_local_crops_number + 2,
args.dino_warmup_teacher_temp,
args.dino_teacher_temp,
args.dino_warmup_teacher_temp_epochs,
300,
)
self.pre_process = pre_process
self.post_process = post_process
self.momentum_teacher = 0.996
student_backbone, num_features = \
get_student_backbone_and_num_features(pre_process, post_process)
self.student = MultiCropWrapper(
student_backbone,
DINOHead(num_features, self.out_dim,
norm_last_layer=args.dino_norm_last_layer)
)
self.momentum_schedule = cosine_scheduler(
self.momentum_teacher, 1,
args.train_iters // args.iter_per_epoch,
args.iter_per_epoch
)
teacher_backbone, num_features = \
get_teacher_backbone_and_num_features(pre_process, post_process)
self.teacher = MultiCropWrapper(
teacher_backbone,
DINOHead(num_features, self.out_dim)
)
self.teacher.load_state_dict(self.student.state_dict())
for p in self.teacher.parameters():
if hasattr(p, "requires_grad") and p.requires_grad is not None:
p.requires_grad = False
def set_input_tensor(self, tensor):
pass
def forward(self, input):
student_output = None
if self.training:
student_output = self.student(input)
teacher_output = self.teacher(input[:2])
else:
teacher_output = self.teacher(input)
return student_output, teacher_output
def cancel_gradients_last_layer(self, iteration):
args = get_args()
epoch = iteration // args.iter_per_epoch
if epoch < args.dino_freeze_last_layer:
for n, p in self.student.named_parameters():
if "last_layer" in n:
p.grad = None
def update_momentum(self, iteration):
with torch.no_grad():
m = self.momentum_schedule[iteration]
for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
This diff is collapsed.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
i
import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b3
from megatron.model.vision.utils import resize_
class VitInpaintingModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(VitInpaintingModel, self).__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.backbone = VitBackbone(
pre_process=self.pre_process,
post_process=self.post_process,
class_token=False,
)
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.seq_length = args.seq_length
# full mask
if self.post_process:
self.linear_decoder = get_linear_layer(
self.hidden_size,
self.backbone.flatten_dim,
torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
self.backbone.set_input_tensor(input_tensor)
def forward(self, input):
hidden_states = self.backbone(input)
if not self.post_process:
return hidden_states
decoded_output = self.linear_decoder(hidden_states)
output = einops.rearrange(
decoded_output,
"b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
class MLP(torch.nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = torch.nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class MitInpaintingModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, pre_process=True, post_process=True):
super(MitInpaintingModel, self).__init__()
self.pre_process = pre_process
self.post_process = post_process
args = get_args()
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.flatten_dim = self.patch_dim * self.patch_dim * 3
self.backbone = mit_b3()
self.in_channels = [64, 128, 320, 512]
self.embedding_dim = 768
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False)
self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
self.dropout = torch.nn.Dropout2d(0.1)
self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
c1, c2, c3, c4 = self.backbone(input)
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
_c = self.conv_fuse(_c)
x = self.norm(_c)
x = F.relu(x, inplace=True)
x = self.dropout(x)
x = self.linear_pred(x)
output = einops.rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder
_FEATURE_BANK = None
def build_data_loader(dataset, drop_last=True, shuffle=False):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
args = get_args()
micro_batch_size = 16
num_workers = args.num_workers
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank,
drop_last=drop_last, shuffle=shuffle
)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=not drop_last,
pin_memory=True,
)
return data_loader
def compute_feature_bank(model):
args = get_args()
global _FEATURE_BANK
feature_bank = []
feature_label = []
train_ds = ImageFolder(
root=args.data_path[0],
transform=ClassificationTransform((args.img_h, args.img_w), train=False),
data_per_class_fraction=1.0
)
classes = len(train_ds.classes)
dataloader = build_data_loader(train_ds)
for m in model:
m.eval()
with torch.no_grad():
for i, batch in enumerate(dataloader):
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
student_feature, teacher_feature = model[0](images)
feature = F.normalize(teacher_feature.float(), dim=1)
feature_bank.append(feature)
feature_label.append(labels)
for m in model:
m.train()
# [N', D]
feature_bank = torch.cat(feature_bank, dim=0).contiguous()
feature_label = torch.cat(feature_label, dim=0).contiguous()
feature_banks = [torch.zeros_like(feature_bank)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_banks,
feature_bank,
group=mpu.get_data_parallel_group())
assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
feature_bank))
feature_labels = [torch.zeros_like(feature_label)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_labels,
feature_label,
group=mpu.get_data_parallel_group())
# [D, N]
feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
# [N]
feature_labels = torch.cat(feature_labels, dim=0).contiguous()
print_rank_0("feature_banks size is {}".format(feature_banks.size()))
print_rank_0("feature labels size is {}".format(feature_labels.size()))
_FEATURE_BANK = (feature_banks, feature_labels, classes)
def get_feature_bank():
global _FEATURE_BANK
assert _FEATURE_BANK is not None
return _FEATURE_BANK
# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
# compute cos similarity between each feature vector and feature bank ---> [B, N]
sim_matrix = torch.mm(feature, feature_bank)
# [B, K]
sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
# [B, K]
sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
dim=-1,
index=sim_indices)
sim_weight = (sim_weight / knn_t).exp()
# counts for each class
one_hot_label = torch.zeros(feature.size(0) * knn_k,
classes,
device=sim_labels.device)
# [B*K, C]
one_hot_label = one_hot_label.scatter(dim=-1,
index=sim_labels.view(-1, 1),
value=1.0)
# weighted score ---> [B, C]
pred_scores = torch.sum(
one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
dim=1)
pred_labels = pred_scores.argsort(dim=-1, descending=True)
return pred_labels
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# found in the LICENSE file in the root directory of this
# source tree.
# ---------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.nn.init import trunc_normal_
from megatron.model.transformer import DropPath
from megatron.model import LayerNorm
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class MixVisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.output_avg = output_avg
# patch_embed
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
self.norm1 = norm_layer(embed_dims[0])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
self.norm2 = norm_layer(embed_dims[1])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
self.norm3 = norm_layer(embed_dims[2])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm4 = norm_layer(embed_dims[3])
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
def forward_features(self, x):
B = x.shape[0]
outs = []
# stage 1
x, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 2
x, H, W = self.patch_embed2(x)
for i, blk in enumerate(self.block2):
x = blk(x, H, W)
x = self.norm2(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 3
x, H, W = self.patch_embed3(x)
for i, blk in enumerate(self.block3):
x = blk(x, H, W)
x = self.norm3(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 4
x, H, W = self.patch_embed4(x)
for i, blk in enumerate(self.block4):
x = blk(x, H, W)
x = self.norm4(x)
if not self.output_avg:
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
def forward(self, x):
x = self.forward_features(x)
if self.output_avg:
x = x[3].mean(dim=1)
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class mit_b0(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b0, self).__init__(
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b1(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b1, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b2(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b2, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b3, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b3_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
class mit_b4(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b4, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b5, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b5_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
This diff is collapsed.
import warnings
import torch
import torch.nn.functional as F
def resize(input,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
warning=True):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
if output_h > input_h or output_w > output_h:
if ((output_h > 1 and output_w > 1 and input_h > 1
and input_w > 1) and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)):
warnings.warn(
f'When align_corners={align_corners}, '
'the output would more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')
if isinstance(size, torch.Size):
size = tuple(int(x) for x in size)
return F.interpolate(input, size, scale_factor, mode, align_corners)
...@@ -147,7 +147,8 @@ class VitBackbone(MegatronModule): ...@@ -147,7 +147,8 @@ class VitBackbone(MegatronModule):
pre_process=True, pre_process=True,
post_process=True, post_process=True,
class_token=True, class_token=True,
single_token_output=False): single_token_output=False,
drop_path_rate=0.0):
super(VitBackbone, self).__init__(share_word_embeddings=False) super(VitBackbone, self).__init__(share_word_embeddings=False)
args = get_args() args = get_args()
...@@ -170,6 +171,7 @@ class VitBackbone(MegatronModule): ...@@ -170,6 +171,7 @@ class VitBackbone(MegatronModule):
self.img_w = args.img_w self.img_w = args.img_w
self.micro_batch_size = args.micro_batch_size self.micro_batch_size = args.micro_batch_size
self.single_token_output = single_token_output self.single_token_output = single_token_output
self.drop_path_rate = drop_path_rate
assert self.img_h % self.patch_dim == 0 assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0 assert self.img_w % self.patch_dim == 0
...@@ -216,6 +218,7 @@ class VitBackbone(MegatronModule): ...@@ -216,6 +218,7 @@ class VitBackbone(MegatronModule):
self.scaled_init_method, self.scaled_init_method,
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process, post_process=self.post_process,
drop_path_rate=self.drop_path_rate
) )
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
......
...@@ -21,7 +21,6 @@ import sys ...@@ -21,7 +21,6 @@ import sys
import time import time
# The earliest we can measure the start time. # The earliest we can measure the start time.
_TRAIN_START_TIME = time.time() _TRAIN_START_TIME = time.time()
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
...@@ -51,7 +50,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader ...@@ -51,7 +50,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm from megatron.utils import calc_params_l2_norm
from megatron.schedules import get_forward_backward_func from megatron.schedules import get_forward_backward_func
from megatron.utils import report_memory from megatron.utils import report_memory
from megatron.model.vision.knn_monitor import compute_feature_bank
def print_datetime(string): def print_datetime(string):
...@@ -465,11 +464,23 @@ def train_step(forward_step_func, data_iterator, ...@@ -465,11 +464,23 @@ def train_step(forward_step_func, data_iterator,
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group()) torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop() timers('backward-embedding-all-reduce').stop()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters. # Update parameters.
timers('optimizer').start() timers('optimizer').start()
update_successful, grad_norm, num_zeros_in_grad = optimizer.step() update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop() timers('optimizer').stop()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0],
(torchDDP, LocalDDP, Float16Module))
unwrapped_model.update_momentum(args.curr_iteration)
# Update learning rate. # Update learning rate.
if update_successful: if update_successful:
increment = get_num_microbatches() * \ increment = get_num_microbatches() * \
...@@ -702,6 +713,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -702,6 +713,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
report_memory_flag = True report_memory_flag = True
while iteration < args.train_iters: while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples) update_num_microbatches(args.consumed_train_samples)
args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func, train_step(forward_step_func,
train_data_iterator, train_data_iterator,
...@@ -791,6 +803,9 @@ def evaluate(forward_step_func, ...@@ -791,6 +803,9 @@ def evaluate(forward_step_func,
"""Evaluation.""" """Evaluation."""
args = get_args() args = get_args()
if args.vision_pretraining and args.vision_pretraining_type == "dino":
compute_feature_bank(model)
# Turn on evaluation mode which disables dropout. # Turn on evaluation mode which disables dropout.
for model_module in model: for model_module in model:
model_module.eval() model_module.eval()
......
...@@ -22,20 +22,32 @@ from megatron import get_args, get_timers, mpu, print_rank_0 ...@@ -22,20 +22,32 @@ from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model import ModelType from megatron.model import ModelType
from megatron.model.vision.classification import VitClassificationModel from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True): def model_provider(pre_process=True, post_process=True):
"""Build the model.""" """Build the model."""
print_rank_0("building VIT model ...")
args = get_args() args = get_args()
model = VitClassificationModel(num_classes=args.num_classes, if args.vision_backbone_type == 'vit':
pre_process=pre_process, print_rank_0("building VIT model ...")
post_process=post_process) model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
print_rank_0("building MIT model ...")
model = MitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model return model
def get_batch(data_iterator): def get_batch(data_iterator):
"""Build the batch.""" """Build the batch."""
data = next(data_iterator) data = next(data_iterator)
...@@ -46,6 +58,7 @@ def get_batch(data_iterator): ...@@ -46,6 +58,7 @@ def get_batch(data_iterator):
return images, labels return images, labels
def loss_func(labels, output_tensor): def loss_func(labels, output_tensor):
logits = output_tensor.contiguous().float() logits = output_tensor.contiguous().float()
loss = F.cross_entropy(logits, labels) loss = F.cross_entropy(logits, labels)
...@@ -58,6 +71,7 @@ def loss_func(labels, output_tensor): ...@@ -58,6 +71,7 @@ def loss_func(labels, output_tensor):
return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]} return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
def forward_step(data_iterator, model): def forward_step(data_iterator, model):
"""Forward step.""" """Forward step."""
timers = get_timers() timers = get_timers()
...@@ -98,5 +112,5 @@ if __name__ == "__main__": ...@@ -98,5 +112,5 @@ if __name__ == "__main__":
model_provider, model_provider,
ModelType.encoder_or_decoder, ModelType.encoder_or_decoder,
forward_step, forward_step,
args_defaults={'dataloader_type': 'cyclic'} args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
) )
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
if isinstance(data[0], list):
images = [aug.cuda() for aug in data[0]]
else:
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(model, labels, output_tensor, collect_data=False):
args = get_args()
model = unwrap_model(
model,
(torchDDP, LocalDDP, Float16Module)
)
if model.training:
student_output, teacher_output = output_tensor
loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"loss": averaged_loss[0]}
else:
_, teacher_feature = output_tensor
feature_bank, feature_labels, classes = get_feature_bank()
feature = F.normalize(teacher_feature.float(), dim=1)
knn_accs = []
for k in [10, 20, 100, 200]:
pred_labels = knn_predict(feature, feature_bank,
feature_labels, classes, k, 0.07)
knn_acc = (pred_labels[:, 0] == labels).float().mean()
knn_accs.append(knn_acc)
averaged_loss = average_losses_across_data_parallel_group(knn_accs)
return 0, {"knn_acc_10": averaged_loss[0],
"knn_acc_20": averaged_loss[1],
"knn_acc_100": averaged_loss[2],
"knn_acc_200": averaged_loss[3]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
return model(images), partial(loss_func, model, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, mpu, print_rank_0, print_rank_last
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.inpainting import VitInpaintingModel
from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
from megatron.model import ModelType
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
model = VitInpaintingModel(pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitInpaintingModel(pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0][0].cuda()
masks = data[0][1].cuda()
return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
outputs = outputs.contiguous().float()
masks_flip = 1-masks
flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
flip_masked_images = images.masked_fill(masks_flip.bool(), 0)
ssim_fun = SSIM()
psnr_fun = PSNR()
if not collect_data:
mask_count = torch.count_nonzero(masks)
loss = F.mse_loss(
flip_masked_outputs,
flip_masked_images.float(),
reduction="sum"
)
loss = loss/mask_count
ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())
averaged_loss = average_losses_across_data_parallel_group(
[loss, psnr, ssim]
)
return loss, {"loss": averaged_loss[0],
"psnr": averaged_loss[1],
'ssim': averaged_loss[2]}
else:
synth_images = masked_images.float() + flip_masked_outputs
ssim = ssim_fun(synth_images, images.float())
psnr = psnr_fun(synth_images, images.float())
return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator").start()
(
images,
masks,
) = get_batch(data_iterator)
timers("batch-generator").stop()
masked_images = images.masked_fill(masks.bool(), 0)
outputs = model(masked_images)
# Forward mode
return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
psnr_sum = 0
ssim_sum = 0
for (output_tb, ssim, psnr) in data:
output_tb[output_tb < 0] = 0
output_tb[output_tb > 1] = 1
writer.add_images("gt-input-output-vald", output_tb,
global_step=iteration, walltime=None,
dataformats='NCHW')
psnr_sum = psnr_sum + psnr.item()
ssim_sum = ssim_sum + ssim.item()
psnr = psnr_sum/len(data)
ssim = ssim_sum/len(data)
writer.add_scalar('PSNR generate value-validation', psnr, iteration)
writer.add_scalar('SSIM generate value-validation', ssim, iteration)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
process_non_loss_data,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment