Commit 1cd3650d authored by Vijay Korthikanti, committed by Sangkug Lym

more minor fixes

parent 6f3bf9c0
@@ -206,9 +206,9 @@ class DinoTransform(object):
             normalize
         ])
         # transformation for the local small crops
-        self.local_crops_number = args.local_crops_number
+        self.local_crops_number = args.dino_local_crops_number
         self.local_transform = T.Compose([
-            T.RandomResizedCrop(args.local_img_size,
+            T.RandomResizedCrop(args.dino_local_img_size,
                                 scale=(0.05, scale_const),
                                 interpolation=Image.BICUBIC),
             flip_and_color_jitter,
@@ -218,12 +218,6 @@ class DinoTransform(object):
     def __call__(self, image):
         crops = []
-        args = get_args()
-        if args.street_data:
-            crop_transform = T.RandomCrop(300)
-            image = crop_transform(image)
         crops.append(self.global_transform1(image))
         crops.append(self.global_transform2(image))
         for _ in range(self.local_crops_number):
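For context, DinoTransform.__call__ implements the multi-crop augmentation from the DINO paper: two large "global" crops plus several small "local" crops of the same image, all returned as one list. A minimal self-contained sketch of the idea, with hypothetical sizes and a single shared global transform rather than the two distinct ones used above:

    import torchvision.transforms as T

    class MultiCropSketch:
        """Two global crops plus N small local crops per image (DINO-style)."""
        def __init__(self, global_size=224, local_size=96, local_crops_number=8):
            self.local_crops_number = local_crops_number
            self.global_transform = T.Compose([
                T.RandomResizedCrop(global_size, scale=(0.4, 1.0)),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
            ])
            self.local_transform = T.Compose([
                T.RandomResizedCrop(local_size, scale=(0.05, 0.4)),
                T.RandomHorizontalFlip(),
                T.ToTensor(),
            ])

        def __call__(self, image):
            # Global crops see most of the image; local crops see small regions.
            crops = [self.global_transform(image), self.global_transform(image)]
            crops += [self.local_transform(image)
                      for _ in range(self.local_crops_number)]
            return crops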
@@ -247,9 +241,6 @@ def build_train_valid_datasets(data_path, image_size=224):
         raise Exception('{} vit pretraining type is not supported.'.format(
             args.vit_pretraining_type))
-    train_transform = ClassificationTransform(image_size)
-    val_transform = ClassificationTransform(image_size, train=False)
     # training dataset
     train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
     train_data = ImageFolder(
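The train-path selection above is purely positional, which is easy to misread; concretely (directory names hypothetical):

    # Two or fewer entries: the first one is the training directory.
    data_path = ['/data/train', '/data/val']
    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
    # -> '/data/train'

    # More than two entries: the third one is used instead.
    data_path = ['/data/extra1', '/data/extra2', '/data/train', '/data/val']
    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
    # -> '/data/train'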
...
@@ -15,11 +15,9 @@ from megatron import get_args, print_rank_0
 from megatron.model.utils import get_linear_layer
 from megatron.model.vision.vit_backbone import VitBackbone
 from megatron.model.module import MegatronModule
-from megatron.utils import print_tensor_min_max_norm as pt
 from megatron.model.vision.utils import trunc_normal_
 from megatron.model.vision.mit_backbone import mit_b5_avg
 from megatron.model.vision.esvit_swin_backbone import get_swin
-from megatron.model.vision.av_cam_trunk import get_av_cam_trunk
 
 class DINOLoss(torch.nn.Module):
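The DINOLoss class whose imports are being trimmed here follows the loss of Caron et al. (2021): cross-entropy between a centered, temperature-sharpened teacher softmax and the student log-softmax, with an EMA-updated center to prevent collapse. A minimal single-pair sketch with the paper's typical temperatures, omitting the multi-crop pairing and the distributed all-reduce of the center, and not the repo's exact code:

    import torch
    import torch.nn.functional as F

    class DinoLossSketch(torch.nn.Module):
        def __init__(self, out_dim, student_temp=0.1, teacher_temp=0.04,
                     center_momentum=0.9):
            super().__init__()
            self.student_temp = student_temp
            self.teacher_temp = teacher_temp
            self.center_momentum = center_momentum
            self.register_buffer("center", torch.zeros(1, out_dim))

        def forward(self, student_out, teacher_out):
            # Student: log-softmax at the (higher) student temperature.
            student_logp = F.log_softmax(student_out / self.student_temp, dim=-1)
            # Teacher: center, sharpen, and block gradients.
            teacher_p = F.softmax(
                (teacher_out - self.center) / self.teacher_temp, dim=-1).detach()
            loss = -(teacher_p * student_logp).sum(dim=-1).mean()
            # EMA update of the center discourages one dimension dominating.
            batch_center = teacher_out.detach().mean(dim=0, keepdim=True)
            self.center = (self.center * self.center_momentum
                           + batch_center * (1 - self.center_momentum))
            return loss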
...
@@ -14,7 +14,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from functools import partial
 import torch.distributed as dist
-from megatron.model.vision.utils import DropPath, trunc_normal_
+from megatron.model.vision.utils import trunc_normal_
+from megatron.model.transformer import DropPath
 from megatron import get_args
 from megatron.model import LayerNorm
 import numpy as np
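DropPath now comes from megatron.model.transformer rather than the vision utils. The operator itself is stochastic depth (Huang et al., 2016): during training, each sample's residual-branch output is zeroed with some probability and the survivors are rescaled. A conventional sketch, assuming batch-first tensors, not necessarily Megatron's exact implementation:

    import torch.nn as nn

    class DropPathSketch(nn.Module):
        """Randomly drop an entire residual branch per sample."""
        def __init__(self, drop_prob=0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # One Bernoulli draw per sample, broadcast over remaining dims.
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = x.new_empty(shape).bernoulli_(keep_prob)
            # Rescale so the expected value matches evaluation mode.
            return x * mask / keep_prob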
@@ -809,12 +810,12 @@ class SwinTransformer(nn.Module):
 def get_swin(is_teacher=False):
     args = get_args()
-    if args.swin_type == "tiny":
+    if args.swin_backbone_type == "tiny":
         embed_dim = 96
         depths = [2, 2, 6, 2]
         num_heads = [3, 6, 12, 24]
         drop_path_rate = 0.1
-    elif args.swin_type == 'h3':
+    elif args.swin_backbone_type == 'h3':
         embed_dim = 384
         depths = [2, 2, 18, 2]
         num_heads = [6, 12, 24, 48]
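The renamed swin_backbone_type flag selects among hard-coded architecture presets. Restated as a lookup table using only the values visible in this hunk (the 'h3' branch's drop_path_rate falls outside the excerpt, so it is omitted here):

    # Hypothetical restatement of the if/elif chain above.
    SWIN_PRESETS = {
        "tiny": dict(embed_dim=96, depths=[2, 2, 6, 2],
                     num_heads=[3, 6, 12, 24], drop_path_rate=0.1),
        "h3": dict(embed_dim=384, depths=[2, 2, 18, 2],
                   num_heads=[6, 12, 24, 48]),
    }
    preset = SWIN_PRESETS[args.swin_backbone_type]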
...
@@ -147,7 +147,8 @@ class VitBackbone(MegatronModule):
                  pre_process=True,
                  post_process=True,
                  class_token=True,
-                 single_token_output=False):
+                 single_token_output=False,
+                 drop_path_rate=0.0):
         super(VitBackbone, self).__init__(share_word_embeddings=False)
         args = get_args()
@@ -170,6 +171,7 @@ class VitBackbone(MegatronModule):
         self.img_w = args.img_w
         self.micro_batch_size = args.micro_batch_size
         self.single_token_output = single_token_output
+        self.drop_path_rate = drop_path_rate
 
         assert self.img_h % self.patch_dim == 0
         assert self.img_w % self.patch_dim == 0
@@ -216,6 +218,7 @@ class VitBackbone(MegatronModule):
             self.scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process,
+            drop_path_rate=self.drop_path_rate
         )
 
     def set_input_tensor(self, input_tensor):
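Taken together, the three VitBackbone hunks thread a new drop_path_rate constructor argument through to the underlying transformer stack, defaulting to 0.0 so existing callers are unaffected. A hypothetical call site:

    # Enable 10% stochastic depth in the ViT trunk (names as in the diff above).
    backbone = VitBackbone(
        pre_process=True,
        post_process=True,
        class_token=True,
        single_token_output=False,
        drop_path_rate=0.1,  # new; 0.0 reproduces the previous behavior
    )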
...
@@ -23,13 +23,14 @@ import torch.distributed as dist
 from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
-from megatron.model.vision.contrastive import DINOPretrainModel
+from megatron.model.vision.dino import DINOPretrainModel
 from megatron.model.vision.knn_monitor import knn_predict
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
+from megatron.model import ModelType
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -116,6 +117,7 @@ if __name__ == "__main__":
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
+        ModelType.encoder_or_decoder,
         forward_step,
         args_defaults={'dataloader_type': 'cyclic'}
     )
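The two-line change here (the ModelType import plus the extra positional argument) indicates that megatron.training.pretrain now expects the model type between model_provider and forward_step in this version of the codebase; the inpainting script below receives the identical treatment, and any other pretraining entry point would need the same update to keep working.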
...
@@ -25,7 +25,7 @@ from megatron.model.vision.inpainting import MitInpaintingModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 from tasks.vision.metrics import SSIM, PSNR
+from megatron.model import ModelType
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
@@ -143,6 +143,7 @@ if __name__ == "__main__":
     pretrain(
         train_valid_test_datasets_provider,
         model_provider,
+        ModelType.encoder_or_decoder,
         forward_step,
         process_non_loss_data,
         args_defaults={'dataloader_type': 'cyclic'}
...