Commit 1cd3650d authored by Vijay Korthikanti, committed by Sangkug Lym

more minor fixes

parent 6f3bf9c0
@@ -206,9 +206,9 @@ class DinoTransform(object):
             normalize
         ])
         # transformation for the local small crops
-        self.local_crops_number = args.local_crops_number
+        self.local_crops_number = args.dino_local_crops_number
         self.local_transform = T.Compose([
-            T.RandomResizedCrop(args.local_img_size,
+            T.RandomResizedCrop(args.dino_local_img_size,
                                 scale=(0.05, scale_const),
                                 interpolation=Image.BICUBIC),
             flip_and_color_jitter,
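Aside: the hunk above renames the local-crop arguments into a dino_* namespace. A hypothetical sketch of how the renamed flags might be registered with argparse; the option names follow the diff, but the defaults (96, 10) are assumptions, not taken from this commit:

    import argparse

    # Hypothetical registration of the renamed DINO flags.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='vision')
    group.add_argument('--dino-local-img-size', type=int, default=96,
                       help='Image size for DINO local crops.')
    group.add_argument('--dino-local-crops-number', type=int, default=10,
                       help='Number of local crops per image.')

    args = parser.parse_args([])
    assert args.dino_local_img_size == 96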
@@ -218,12 +218,6 @@ class DinoTransform(object):

     def __call__(self, image):
         crops = []
-        args = get_args()
-
-        if args.street_data:
-            crop_transform = T.RandomCrop(300)
-            image = crop_transform(image)
-
         crops.append(self.global_transform1(image))
         crops.append(self.global_transform2(image))
         for _ in range(self.local_crops_number):
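For context, after the street_data branch is deleted, __call__ always yields two global crops plus dino_local_crops_number local crops. A minimal self-contained sketch of that multi-crop shape; sizes, scale ranges, and the class name here are illustrative, not the repo's values:

    import torchvision.transforms as T
    from PIL import Image

    # Toy DINO-style multi-crop mirroring DinoTransform.__call__'s structure.
    class MultiCrop:
        def __init__(self, global_size=224, local_size=96, local_crops_number=10):
            self.global_transform1 = T.Compose([
                T.RandomResizedCrop(global_size, scale=(0.4, 1.0)), T.ToTensor()])
            self.global_transform2 = T.Compose([
                T.RandomResizedCrop(global_size, scale=(0.4, 1.0)), T.ToTensor()])
            self.local_crops_number = local_crops_number
            self.local_transform = T.Compose([
                T.RandomResizedCrop(local_size, scale=(0.05, 0.4)), T.ToTensor()])

        def __call__(self, image):
            # Two large "global" views plus N small "local" views per image.
            crops = [self.global_transform1(image), self.global_transform2(image)]
            for _ in range(self.local_crops_number):
                crops.append(self.local_transform(image))
            return crops

    img = Image.new("RGB", (480, 480))
    assert len(MultiCrop()(img)) == 12  # 2 global + 10 local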
@@ -247,9 +241,6 @@ def build_train_valid_datasets(data_path, image_size=224):
         raise Exception('{} vit pretraining type is not supported.'.format(
               args.vit_pretraining_type))

-    train_transform = ClassificationTransform(image_size)
-    val_transform = ClassificationTransform(image_size, train=False)
-
     # training dataset
     train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
     train_data = ImageFolder(
......
@@ -15,11 +15,9 @@ from megatron import get_args, print_rank_0
 from megatron.model.utils import get_linear_layer
 from megatron.model.vision.vit_backbone import VitBackbone
 from megatron.model.module import MegatronModule
-from megatron.utils import print_tensor_min_max_norm as pt
 from megatron.model.vision.utils import trunc_normal_
 from megatron.model.vision.mit_backbone import mit_b5_avg
 from megatron.model.vision.esvit_swin_backbone import get_swin
-from megatron.model.vision.av_cam_trunk import get_av_cam_trunk


 class DINOLoss(torch.nn.Module):
......
@@ -14,7 +14,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from functools import partial
 import torch.distributed as dist
-from megatron.model.vision.utils import DropPath, trunc_normal_
+from megatron.model.vision.utils import trunc_normal_
+from megatron.model.transformer import DropPath
 from megatron import get_args
 from megatron.model import LayerNorm
 import numpy as np
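The import change above sources DropPath from megatron.model.transformer instead of the vision utils. For reference, the standard stochastic-depth operation a DropPath module implements looks roughly like this; a sketch in the style of timm's DropPath, not necessarily identical to Megatron's module:

    import torch
    import torch.nn as nn

    class DropPath(nn.Module):
        """Stochastic depth: randomly zero whole residual branches per sample."""
        def __init__(self, drop_prob=0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # One Bernoulli draw per sample in dim 0, broadcast over the rest.
            # (Megatron's transformer keeps batch in dim 1, so its version
            # would mask dim 1 instead -- an assumption, noted not copied.)
            shape = (x.shape[0],) + (1,) * (x.dim() - 1)
            mask = torch.rand(shape, dtype=x.dtype, device=x.device) < keep_prob
            # Rescale kept paths so the expected activation is unchanged.
            return x * mask.to(x.dtype) / keep_prob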
@@ -809,12 +810,12 @@ class SwinTransformer(nn.Module):

 def get_swin(is_teacher=False):
     args = get_args()
-    if args.swin_type == "tiny":
+    if args.swin_backbone_type == "tiny":
         embed_dim = 96
         depths = [2, 2, 6, 2]
         num_heads = [3, 6, 12, 24]
         drop_path_rate = 0.1
-    elif args.swin_type == 'h3':
+    elif args.swin_backbone_type == 'h3':
         embed_dim = 384
         depths = [2, 2, 18, 2]
         num_heads = [6, 12, 24, 48]
......
@@ -147,7 +147,8 @@ class VitBackbone(MegatronModule):
                  pre_process=True,
                  post_process=True,
                  class_token=True,
-                 single_token_output=False):
+                 single_token_output=False,
+                 drop_path_rate=0.0):
         super(VitBackbone, self).__init__(share_word_embeddings=False)
         args = get_args()
@@ -170,6 +171,7 @@ class VitBackbone(MegatronModule):
         self.img_w = args.img_w
         self.micro_batch_size = args.micro_batch_size
         self.single_token_output = single_token_output
+        self.drop_path_rate = drop_path_rate

         assert self.img_h % self.patch_dim == 0
         assert self.img_w % self.patch_dim == 0
@@ -216,6 +218,7 @@ class VitBackbone(MegatronModule):
             self.scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process,
+            drop_path_rate=self.drop_path_rate
         )

     def set_input_tensor(self, input_tensor):
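With drop_path_rate now plumbed from VitBackbone into the transformer, a common convention (used by Swin and timm, and assumed here rather than verified against Megatron's ParallelTransformer) is to spread the single rate linearly across layers:

    import torch

    # Sketch: turn one drop_path_rate into a per-layer stochastic-depth schedule.
    num_layers = 12
    drop_path_rate = 0.1
    per_layer_rates = torch.linspace(0, drop_path_rate, num_layers).tolist()
    # Early layers get ~0.0; the final layer gets the full 0.1.
    print([round(r, 3) for r in per_layer_rates])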
......
@@ -23,13 +23,14 @@ import torch.distributed as dist
 from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
-from megatron.model.vision.contrastive import DINOPretrainModel
+from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
+from megatron.model import ModelType

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -116,6 +117,7 @@ if __name__ == "__main__":
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
+        ModelType.encoder_or_decoder,
         forward_step,
         args_defaults={'dataloader_type': 'cyclic'}
     )
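The new third positional argument reflects a pretrain() signature that now takes a ModelType. A sketch of what that enum plausibly looks like; the member names come from this diff and megatron.model, the values are assumed:

    import enum

    class ModelType(enum.Enum):
        encoder_or_decoder = 1   # single-stack models like ViT/GPT (used here)
        encoder_and_decoder = 2  # two-stack models like T5 (assumed member)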
......
@@ -25,7 +25,7 @@ from megatron.model.vision.inpainting import MitInpaintingModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.vision.metrics import SSIM, PSNR
+from megatron.model import ModelType

 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -143,6 +143,7 @@ if __name__ == "__main__":
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
+        ModelType.encoder_or_decoder,
         forward_step,
         process_non_loss_data,
         args_defaults={'dataloader_type': 'cyclic'}
......