Commit 85589322 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

Fix the dataloader-type argument (rename `--dataloader_type` to `--dataloader-type`, default to None and resolve it to 'single' in parse_args) and fix MegatronPretrainingRandomSampler epoch accounting so the dropped incomplete last batch is excluded from consumed-sample bookkeeping.

parent e6c7b05e
......@@ -124,6 +124,9 @@ def parse_args(extra_args_provider=None, defaults={},
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
if args.dataloader_type is None:
args.dataloader_type = 'single'
# Consumed tokens.
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
......@@ -365,7 +368,7 @@ def _add_training_args(parser):
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
group.add_argument('--dataloader_type', type=str, default='single',
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
return parser
......
......@@ -105,6 +105,8 @@ class MegatronPretrainingRandomSampler:
self.data_parallel_size = data_parallel_size
self.micro_batch_times_data_parallel_size = \
self.micro_batch_size * data_parallel_size
self.last_batch_size = \
self.total_samples % self.micro_batch_times_data_parallel_size
# Sanity checks.
assert self.total_samples > 0, \
......@@ -119,8 +121,9 @@ class MegatronPretrainingRandomSampler:
return self.total_samples
def __iter__(self):
self.epoch = self.consumed_samples // self.total_samples
current_epoch_samples = self.consumed_samples % self.total_samples
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0
# data sharding and random sampling
......@@ -132,7 +135,7 @@ class MegatronPretrainingRandomSampler:
g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]
batch = []
# Last batch if not complete will be dropped.
......@@ -142,4 +145,3 @@ class MegatronPretrainingRandomSampler:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []
self.consumed_samples += self.total_samples % self.micro_batch_times_data_parallel_size
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment