Commit e1f9c3a5 authored by Vijay Korthikanti

minor fixes

parent 01a82723
@@ -873,10 +873,9 @@ def _add_vision_args(parser):
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
    # pretraining type and backbone selection
group.add_argument('--vision-pretraining-type', type=str, default='classify',
-                       choices=['classify', 'inpaint', 'contrast'],
+                       choices=['classify', 'inpaint', 'dino'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
@@ -891,7 +890,6 @@ def _add_vision_args(parser):
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
......
@@ -251,7 +251,7 @@ def build_train_valid_datasets(data_path, image_size=224):
val_transform = ClassificationTransform(image_size, train=False)
# training dataset
-    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2] #TODO VIJAY
+    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
train_data = ImageFolder(
root=train_data_path,
transform=train_transform,
......
@@ -68,7 +68,7 @@ class VitClassificationModel(MegatronModule):
class MitClassificationModel(MegatronModule):
"""Mix vision Transformer Model."""
-    def __init__(self, num_classes
+    def __init__(self, num_classes,
pre_process=True, post_process=True):
super(MitClassificationModel, self).__init__()
args = get_args()
......
@@ -8,7 +8,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
-from megatron.model.vision.utils import DropPath, trunc_normal_
+from megatron.model.vision.utils import trunc_normal_
+from megatron.model.transformer import DropPath
from megatron.model import LayerNorm
......
@@ -714,7 +714,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
report_memory_flag = True
while iteration < args.train_iters:
update_num_microbatches(args.consumed_train_samples)
-         args.curr_iteration = iteration
+        args.curr_iteration = iteration
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
@@ -804,7 +804,7 @@ def evaluate(forward_step_func,
"""Evaluation."""
args = get_args()
if args.vision_pretraining_type == "contrast":
if args.vision_pretraining_type == "dino":
args.knn_features = compute_feature_bank(model)
# Turn on evaluation mode which disables dropout.
......
@@ -30,14 +30,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0("building VIT model ...")
args = get_args()
if args.vision_backbone_type == 'vit':
print_rank_0("building VIT model ...")
model = VitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
print_rank_0("building MIT model ...")
model = MitClassificationModel(num_classes=args.num_classes,
pre_process=pre_process,
post_process=post_process)
......