Commit a7169297 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

Addressing review comments

parent 58edb19a
...@@ -362,7 +362,9 @@ def _add_training_args(parser): ...@@ -362,7 +362,9 @@ def _add_training_args(parser):
group.add_argument('--optimizer', type=str, default='adam', group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'], choices=['adam', 'sgd'],
help='Optimizer function') help='Optimizer function')
group.add_argument('--dataloader_type', type=str, default='single',
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
return parser return parser
......
...@@ -17,12 +17,12 @@ ...@@ -17,12 +17,12 @@
import torch import torch
import random
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False): def build_pretraining_data_loader(dataset, consumed_samples):
"""Buld dataloader given an input dataset.""" """Buld dataloader given an input dataset."""
if dataset is None: if dataset is None:
...@@ -30,13 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False ...@@ -30,13 +30,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False
args = get_args() args = get_args()
# Megatron sampler # Megatron sampler
batch_sampler = MegatronPretrainingSampler( if args.dataloader_type == 'single':
total_samples=len(dataset), batch_sampler = MegatronPretrainingSampler(
consumed_samples=consumed_samples, total_samples=len(dataset),
micro_batch_size=args.micro_batch_size, consumed_samples=consumed_samples,
data_parallel_rank=mpu.get_data_parallel_rank(), micro_batch_size=args.micro_batch_size,
data_parallel_size=mpu.get_data_parallel_world_size(), data_parallel_rank=mpu.get_data_parallel_rank(),
random_sample=random_sample) data_parallel_size=mpu.get_data_parallel_world_size())
elif args.dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
total_samples=len(dataset),
consumed_samples=consumed_samples,
micro_batch_size=args.micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size())
else:
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))
# Torch dataloader. # Torch dataloader.
return torch.utils.data.DataLoader(dataset, return torch.utils.data.DataLoader(dataset,
...@@ -44,11 +54,10 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False ...@@ -44,11 +54,10 @@ def build_pretraining_data_loader(dataset, consumed_samples, random_sample=False
num_workers=args.num_workers, num_workers=args.num_workers,
pin_memory=True) pin_memory=True)
class MegatronPretrainingSampler: class MegatronPretrainingSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size, def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size, random_sample=False): data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use. # Keep a copy of input params for later use.
self.total_samples = total_samples self.total_samples = total_samples
self.consumed_samples = consumed_samples self.consumed_samples = consumed_samples
...@@ -56,14 +65,13 @@ class MegatronPretrainingSampler: ...@@ -56,14 +65,13 @@ class MegatronPretrainingSampler:
self.data_parallel_rank = data_parallel_rank self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = \ self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size self.micro_batch_size * data_parallel_size
self.random_sample = random_sample
# Sanity checks. # Sanity checks.
assert self.total_samples > 0, \ assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples) 'no sample to consume: {}'.format(self.total_samples)
#assert self.consumed_samples < self.total_samples, \ assert self.consumed_samples < self.total_samples, \
# 'no samples left to consume: {}, {}'.format(self.consumed_samples, 'no samples left to consume: {}, {}'.format(self.consumed_samples,
# self.total_samples) self.total_samples)
assert self.micro_batch_size > 0 assert self.micro_batch_size > 0
assert data_parallel_size > 0 assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \ assert self.data_parallel_rank < data_parallel_size, \
...@@ -74,25 +82,64 @@ class MegatronPretrainingSampler: ...@@ -74,25 +82,64 @@ class MegatronPretrainingSampler:
return self.total_samples return self.total_samples
def __iter__(self): def __iter__(self):
self.epoch = self.consumed_samples // self.total_samples
current_epoch_samples = self.consumed_samples % self.total_samples
if self.random_sample:
g = torch.Generator()
g.manual_seed(self.epoch)
idx_range_total = \
torch.randperm(self.total_samples, generator=g).tolist()
idx_range = idx_range_total[current_epoch_samples:]
else:
idx_range = range(current_epoch_samples, self.total_samples)
batch = [] batch = []
# Last batch if not complete will be dropped. # Last batch if not complete will be dropped.
for idx in idx_range: for idx in range(self.consumed_samples, self.total_samples):
batch.append(idx) batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size: if len(batch) == self.micro_batch_times_data_parallel_size:
self.consumed_samples += len(batch)
start_idx = self.data_parallel_rank * self.micro_batch_size start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size end_idx = start_idx + self.micro_batch_size
yield batch[start_idx:end_idx] yield batch[start_idx:end_idx]
batch = [] batch = []
self.consumed_samples += len(batch)
class MegatronPretrainingRandomSampler:
def __init__(self, total_samples, consumed_samples, micro_batch_size,
data_parallel_rank, data_parallel_size):
# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
'no sample to consume: {}'.format(self.total_samples)
assert self.micro_batch_size > 0
assert data_parallel_size > 0
assert self.data_parallel_rank < data_parallel_size, \
'data_parallel_rank should be smaller than data size: {}, ' \
'{}'.format(self.data_parallel_rank, data_parallel_size)
def __len__(self):
return self.total_samples
def __iter__(self):
self.epoch = self.consumed_samples // self.total_samples
current_epoch_samples = self.consumed_samples % self.total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import torch
from torchvision import datasets, transforms from torchvision import datasets, transforms
from megatron.data.autoaugment import ImageNetPolicy from megatron.data.autoaugment import ImageNetPolicy
...@@ -32,7 +33,8 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True): ...@@ -32,7 +33,8 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1 brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
) )
] ]
process += [ImageNetPolicy(), transforms.ToTensor(), normalize] fp16_t = transforms.ConvertImageDtype(torch.half)
process += [ImageNetPolicy(), transforms.ToTensor(), normalize, fp16_t]
transform_train = transforms.Compose(process) transform_train = transforms.Compose(process)
train_data = datasets.ImageFolder( train_data = datasets.ImageFolder(
root=train_data_path, transform=transform_train root=train_data_path, transform=transform_train
...@@ -46,6 +48,7 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True): ...@@ -46,6 +48,7 @@ def build_train_valid_datasets(data_path, crop_size=224, color_jitter=True):
transforms.CenterCrop(crop_size), transforms.CenterCrop(crop_size),
transforms.ToTensor(), transforms.ToTensor(),
normalize, normalize,
fp16_t
] ]
) )
val_data = datasets.ImageFolder( val_data = datasets.ImageFolder(
......
...@@ -122,7 +122,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -122,7 +122,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
assert input.dim() == 4 assert input.dim() == 4
# invoke custom kernel # invoke custom kernel
if self.input_in_fp16 and key_seq_len <= 2048 and \ if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion: query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0 scale = self.scale if self.scale is not None else 1.0
...@@ -142,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -142,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
if self.scale is not None: if self.scale is not None:
input = input * self.scale input = input * self.scale
mask_output = self.mask_func(input, mask) if mask else input mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output) probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_fp16 and self.softmax_in_fp32: if self.input_in_fp16 and self.softmax_in_fp32:
......
...@@ -120,7 +120,7 @@ def twod_interpolate_position_embeddings_hook( ...@@ -120,7 +120,7 @@ def twod_interpolate_position_embeddings_hook(
class VitModel(MegatronModule): class VitModel(MegatronModule):
"""Bert Language model.""" """Vision Transformer Model."""
def __init__(self, num_classes, finetune=False): def __init__(self, num_classes, finetune=False):
super(VitModel, self).__init__() super(VitModel, self).__init__()
......
...@@ -59,12 +59,14 @@ def get_megatron_optimizer(model): ...@@ -59,12 +59,14 @@ def get_megatron_optimizer(model):
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2), betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps) eps=args.adam_eps)
else: elif args.optimizer == 'sgd':
assert args.optimizer == 'sgd'
optimizer = SGD(param_groups, optimizer = SGD(param_groups,
lr=args.lr, lr=args.lr,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
momentum=args.sgd_momentum) momentum=args.sgd_momentum)
else:
raise Exception('{} optimizer is not supported.'.format(
args.optimizer))
if args.fp16: if args.fp16:
# Constant loss scale. # Constant loss scale.
......
...@@ -46,7 +46,7 @@ from megatron.learning_rates import AnnealingLR ...@@ -46,7 +46,7 @@ from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import report_memory from megatron.utils import report_memory
...@@ -61,8 +61,7 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -61,8 +61,7 @@ def pretrain(train_valid_test_dataset_provider,
model_provider, model_provider,
forward_step_func, forward_step_func,
extra_args_provider=None, extra_args_provider=None,
args_defaults={}, args_defaults={}):
random_sample = False):
"""Main training program. """Main training program.
This function will run the followings in the order provided: This function will run the followings in the order provided:
...@@ -117,8 +116,7 @@ def pretrain(train_valid_test_dataset_provider, ...@@ -117,8 +116,7 @@ def pretrain(train_valid_test_dataset_provider,
timers('train/valid/test data iterators').start() timers('train/valid/test data iterators').start()
train_data_iterator, valid_data_iterator, test_data_iterator \ train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators( = build_train_valid_test_data_iterators(
train_valid_test_dataset_provider, train_valid_test_dataset_provider)
random_sample)
timers('train/valid/test data iterators').stop() timers('train/valid/test data iterators').stop()
print_datetime('after dataloaders are built') print_datetime('after dataloaders are built')
...@@ -955,13 +953,13 @@ def evaluate_and_print_results(prefix, forward_step_func, ...@@ -955,13 +953,13 @@ def evaluate_and_print_results(prefix, forward_step_func,
print_rank_last('-' * length) print_rank_last('-' * length)
def cyclic_iterable(iterable): def cyclic_iter(iter):
while True: while True:
for x in iterable: for x in iter:
yield x yield x
def build_train_valid_test_data_iterators( def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider, random_sample=False): build_train_valid_test_datasets_provider):
"""XXX""" """XXX"""
args = get_args() args = get_args()
...@@ -1005,10 +1003,10 @@ def build_train_valid_test_data_iterators( ...@@ -1005,10 +1003,10 @@ def build_train_valid_test_data_iterators(
# Build dataloders. # Build dataloders.
train_dataloader = build_pretraining_data_loader( train_dataloader = build_pretraining_data_loader(
train_ds, args.consumed_train_samples, random_sample) train_ds, args.consumed_train_samples)
valid_dataloader = build_pretraining_data_loader( valid_dataloader = build_pretraining_data_loader(
valid_ds, args.consumed_valid_samples, random_sample) valid_ds, args.consumed_valid_samples)
test_dataloader = build_pretraining_data_loader(test_ds, 0, random_sample) test_dataloader = build_pretraining_data_loader(test_ds, 0)
# Flags to know if we need to do training/validation/testing. # Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and args.train_iters > 0 do_train = train_dataloader is not None and args.train_iters > 0
...@@ -1028,19 +1026,26 @@ def build_train_valid_test_data_iterators( ...@@ -1028,19 +1026,26 @@ def build_train_valid_test_data_iterators(
args.do_valid = flags[1].item() args.do_valid = flags[1].item()
args.do_test = flags[2].item() args.do_test = flags[2].item()
# Build iterators. # Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
if train_dataloader is not None: if train_dataloader is not None:
train_data_iterator = iter(cyclic_iterable(train_dataloader)) train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(train_dataloader))
else: else:
train_data_iterator = None train_data_iterator = None
if valid_dataloader is not None: if valid_dataloader is not None:
valid_data_iterator = iter(cyclic_iterable(valid_dataloader)) valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(valid_dataloader))
else: else:
valid_data_iterator = None valid_data_iterator = None
if test_dataloader is not None: if test_dataloader is not None:
test_data_iterator = iter(cyclic_iterable(test_dataloader)) test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
else iter(cyclic_iter(test_dataloader))
else: else:
test_data_iterator = None test_data_iterator = None
......
...@@ -23,7 +23,6 @@ from megatron.model import VitModel ...@@ -23,7 +23,6 @@ from megatron.model import VitModel
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
def model_provider(): def model_provider():
"""Build the model.""" """Build the model."""
...@@ -33,43 +32,28 @@ def model_provider(): ...@@ -33,43 +32,28 @@ def model_provider():
model = VitModel(num_classes=args.num_classes) model = VitModel(num_classes=args.num_classes)
return model return model
def get_batch(data_iterator): def get_batch(data_iterator):
"""Build the batch.""" """Build the batch."""
data = next(data_iterator)
# Items and their type. # only data parallelism; no need for broadcast
keys = ["image", "label"] images = data[0].cuda()
datatype = torch.half labels = data[1].cuda()
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
dict_data = {}
dict_data["image"] = data[0].half()
dict_data["label"] = data[1].half()
data_b = mpu.broadcast_data(keys, dict_data, datatype)
# Unpack.
images = data_b["image"]
labels = data_b["label"].long()
return images, labels return images, labels
def forward_step(data_iterator, model, input_tensor): def forward_step(data_iterator, model, input_tensor):
"""Forward step.""" """Forward step."""
timers = get_timers() timers = get_timers()
assert input_tensor is None assert input_tensor is None
# Get the batch. # Get the batch.
timers("batch generator").start() timers("batch-generator").start()
( (
images, images,
labels, labels,
) = get_batch(data_iterator) ) = get_batch(data_iterator)
timers("batch generator").stop() timers("batch-generator").stop()
# Forward model. lm_labels # Forward model. lm_labels
logits = model(images).contiguous().float() logits = model(images).contiguous().float()
...@@ -103,5 +87,5 @@ if __name__ == "__main__": ...@@ -103,5 +87,5 @@ if __name__ == "__main__":
train_valid_test_datasets_provider, train_valid_test_datasets_provider,
model_provider, model_provider,
forward_step, forward_step,
random_sample=True args_defaults={'dataloader_type': 'cyclic'}
) )
...@@ -33,8 +33,8 @@ from megatron.utils import average_losses_across_data_parallel_group ...@@ -33,8 +33,8 @@ from megatron.utils import average_losses_across_data_parallel_group
def process_batch(batch): def process_batch(batch):
"""Process batch and produce inputs for the model.""" """Process batch and produce inputs for the model."""
images = batch[0].half().cuda().contiguous() images = batch[0].cuda().contiguous()
labels = batch[1].long().cuda().contiguous() labels = batch[1].cuda().contiguous()
return images, labels return images, labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment