Commit 6f3bf9c0 authored by Vijay Korthikanti, committed by Sangkug Lym

minor fixes

parent 48c2a144
@@ -880,10 +880,9 @@ def _add_vision_args(parser):
     group.add_argument('--head-lr-mult', type=float, default=1.0,
                        help='learning rate multiplier for head during finetuning')
     # pretraining type and backbone selection
     group.add_argument('--vision-pretraining-type', type=str, default='classify',
-                       choices=['classify', 'inpaint', 'contrast'],
+                       choices=['classify', 'inpaint', 'dino'],
                        help='pretraining objectives')
     group.add_argument('--vision-backbone-type', type=str, default='vit',
                        choices=['vit', 'mit', 'swin'],
@@ -898,7 +897,6 @@ def _add_vision_args(parser):
                        help='mask types')
     group.add_argument('--mask-factor', type=float, default=1.0,
                        help='mask size scaling parameter')
     # dino arguments
     group.add_argument('--iter-per-epoch', type=int, default=1250,
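The 'contrast' choice is renamed to 'dino' here, so the flag now matches the DINO-specific arguments grouped below it. A minimal, self-contained sketch of how the renamed choice is consumed (the parser wiring is illustrative; only the argument definition comes from the hunk above):

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('vision')
    group.add_argument('--vision-pretraining-type', type=str, default='classify',
                       choices=['classify', 'inpaint', 'dino'],
                       help='pretraining objectives')

    args = parser.parse_args(['--vision-pretraining-type', 'dino'])
    assert args.vision_pretraining_type == 'dino'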
@@ -251,7 +251,7 @@ def build_train_valid_datasets(data_path, image_size=224):
     val_transform = ClassificationTransform(image_size, train=False)
     # training dataset
-    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2] #TODO VIJAY
+    train_data_path = data_path[0] if len(data_path) <= 2 else data_path[2]
     train_data = ImageFolder(
         root=train_data_path,
         transform=train_transform,
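The removed TODO marked the train-path selection above. The convention is not spelled out in this hunk; assuming --data-path carries [train, val] by default, with an optional third entry overriding the train root, the selection reduces to this sketch (names hypothetical):

    def pick_train_root(data_path):
        # a third path, when present, overrides the default train root
        return data_path[0] if len(data_path) <= 2 else data_path[2]

    assert pick_train_root(['/data/train', '/data/val']) == '/data/train'
    assert pick_train_root(['/data/train', '/data/val', '/data/ft']) == '/data/ft'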
@@ -68,7 +68,7 @@ class VitClassificationModel(MegatronModule):
 class MitClassificationModel(MegatronModule):
     """Mix vision Transformer Model."""
-    def __init__(self, num_classes
+    def __init__(self, num_classes,
                  pre_process=True, post_process=True):
         super(MitClassificationModel, self).__init__()
         args = get_args()
@@ -8,7 +8,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from functools import partial
-from megatron.model.vision.utils import DropPath, trunc_normal_
+from megatron.model.vision.utils import trunc_normal_
+from megatron.model.transformer import DropPath
 from megatron.model import LayerNorm
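DropPath now comes from megatron.model.transformer instead of the vision utils, so a single definition is shared. For reference, a minimal sketch of the standard stochastic-depth formulation such a module implements (the common timm-style variant; the actual Megatron class may differ in tensor layout):

    import torch
    import torch.nn as nn

    class DropPath(nn.Module):
        """Drop entire residual branches per sample during training."""
        def __init__(self, drop_prob=0.0):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0.0 or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # one Bernoulli draw per sample, broadcast across remaining dims
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = (torch.rand(shape, device=x.device) < keep_prob).to(x.dtype)
            return x * mask / keep_prob  # rescale so the expected value is unchanged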
@@ -714,7 +714,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)
         args.curr_iteration = iteration
         loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func,
                        train_data_iterator,
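Publishing the loop counter as args.curr_iteration makes the current iteration visible to code that only receives the global args, such as a forward-step function or an iteration-dependent schedule. A hypothetical consumer, for illustration only:

    def progress_fraction(args):
        # fraction of training completed, readable anywhere args is available
        return args.curr_iteration / args.train_iters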
@@ -804,7 +804,7 @@ def evaluate(forward_step_func,
     """Evaluation."""
     args = get_args()
-    if args.vision_pretraining_type == "contrast":
+    if args.vision_pretraining_type == "dino":
         args.knn_features = compute_feature_bank(model)
     # Turn on evaluation mode which disables dropout.
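With the rename, DINO runs hit this branch: before evaluation, a feature bank is computed so validation can score a kNN classifier over frozen features, the usual proxy metric for self-supervised pretraining. The core of such an evaluation, sketched under the assumption of L2-normalized features (compute_feature_bank itself is Megatron's helper; this snippet is not its implementation):

    import torch

    def knn_predict(query, bank_feats, bank_labels, k=20):
        # cosine similarity of one normalized query against the bank [N, D]
        sims = bank_feats @ query
        neighbors = bank_labels[sims.topk(k).indices]
        return torch.mode(neighbors).values  # majority vote among the k nearest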
@@ -30,14 +30,15 @@ from megatron.utils import average_losses_across_data_parallel_group
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
-    print_rank_0("building VIT model ...")
     args = get_args()
     if args.vision_backbone_type == 'vit':
+        print_rank_0("building VIT model ...")
         model = VitClassificationModel(num_classes=args.num_classes,
                                        pre_process=pre_process,
                                        post_process=post_process)
     elif args.vision_backbone_type == 'mit':
+        print_rank_0("building MIT model ...")
         model = MitClassificationModel(num_classes=args.num_classes,
                                        pre_process=pre_process,
                                        post_process=post_process)
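Moving the print into each branch makes the log name the backbone actually built. The same dispatch can be written as a lookup table; this is shown only as an equivalent shape, not the repository's code:

    def model_provider(pre_process=True, post_process=True):
        args = get_args()
        builders = {'vit': VitClassificationModel, 'mit': MitClassificationModel}
        cls = builders[args.vision_backbone_type]
        print_rank_0("building %s model ..." % args.vision_backbone_type.upper())
        return cls(num_classes=args.num_classes,
                   pre_process=pre_process,
                   post_process=post_process)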