Commit 3b29f0e7 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

minor fixes

parent 7a77abd9
......@@ -164,7 +164,7 @@ class MegatronPretrainingRandomSampler:
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
if isinstance(dataset, RandomSeedDataset):
if isinstance(self.dataset, RandomSeedDataset):
self.dataset.set_epoch(self.epoch)
# data sharding and random sampling
......
......@@ -35,7 +35,7 @@ class ClassificationTransform():
T.ColorJitter(0.4, 0.4, 0.4, 0.1),
ImageNetPolicy(),
T.ToTensor(),
T.Normalize(*self.mean_std),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
T.ConvertImageDtype(self.data_type)
])
else:
......
......@@ -19,8 +19,6 @@ import torch
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3_avg
from megatron.model.vision.utils import trunc_normal_
from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment