"docs/source/vscode:/vscode.git/clone" did not exist on "aad12df6a63b7c2269bc8ed68b10b9099b5df46d"
Commit 064a6881 authored by Neel Kant

Re-adjust dataloader properties

parent 46b2addd
@@ -24,7 +24,7 @@
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
-from megatron.data.samplers import DistributedBatchSampler, RandomSampler
+from megatron.data.samplers import DistributedBatchSampler
 from megatron.fp16 import FP16_Optimizer
@@ -102,16 +102,12 @@ def make_data_loader(dataset):
     num_workers = args.num_workers
     # Use a simple sampler with distributed batch sampler.
-    #sampler = torch.utils.data.SequentialSampler(dataset)
-    sampler = RandomSampler(dataset,
-                            replacement=True,
-                            num_samples=global_batch_size*args.train_iters)
+    sampler = torch.utils.data.SequentialSampler(dataset)
     batch_sampler = DistributedBatchSampler(sampler=sampler,
                                             batch_size=global_batch_size,
                                             drop_last=True,
                                             rank=rank,
-                                            world_size=world_size,
-                                            wrap_last=True)
+                                            world_size=world_size)

     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
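The change above swaps the with-replacement RandomSampler (sized to cover global_batch_size * args.train_iters draws, with wrap-around between epochs) for a plain SequentialSampler, so the loader now makes a single deterministic pass over the dataset; DistributedBatchSampler still splits each global batch across data-parallel ranks. As a rough sketch of what that splitting looks like (this is an illustration with a made-up ShardedBatchSampler and toy dataset, not Megatron's actual DistributedBatchSampler implementation):

import torch
from torch.utils.data import Dataset, Sampler, SequentialSampler


class ShardedBatchSampler(Sampler):
    """Illustrative stand-in for megatron.data.samplers.DistributedBatchSampler:
    accumulate global_batch_size indices, then keep this rank's contiguous slice."""

    def __init__(self, sampler, batch_size, drop_last, rank, world_size):
        assert batch_size % world_size == 0, "global batch must divide evenly"
        self.sampler = sampler
        self.batch_size = batch_size                 # global batch size
        self.drop_last = drop_last
        self.rank = rank
        self.world_size = world_size
        self.local_batch = batch_size // world_size  # per-rank batch size

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                start = self.rank * self.local_batch
                yield batch[start:start + self.local_batch]
                batch = []
        if batch and not self.drop_last:
            # Hand out the (possibly uneven) remainder, sliced the same way.
            start = self.rank * self.local_batch
            yield batch[start:start + self.local_batch]


class RangeDataset(Dataset):
    """Toy dataset: item i is just the integer i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return i


# Two data-parallel ranks, global batch of 4 -> each rank sees 2 samples per step.
ds = RangeDataset(8)
for rank in range(2):
    bs = ShardedBatchSampler(SequentialSampler(ds), batch_size=4,
                             drop_last=True, rank=rank, world_size=2)
    print(rank, list(bs))   # rank 0: [[0, 1], [4, 5]]; rank 1: [[2, 3], [6, 7]]

The second part of the commit adjusts get_train_val_test_data():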
@@ -102,7 +102,7 @@ def get_train_val_test_data():
     """Load the data on rank zero and boradcast number of tokens to all GPUS."""
     args = get_args()

-    (train_data, val_data, test_data) = (None, None, None)
+    (train_data, valid_data, test_data) = (None, None, None)

     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
@@ -115,7 +115,7 @@ def get_train_val_test_data():
     # Number of train/valid/test samples.
     train_iters = args.train_iters
-    eval_iters = args.eval_iters
+    eval_iters = (train_iters // args.eval_iters + 1) * args.eval_iters
     test_iters = args.eval_iters
     train_val_test_num_samples = [train_iters * global_batch_size,
                                   eval_iters * global_batch_size,
                                   test_iters * global_batch_size]
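To make the new eval_iters expression concrete: instead of budgeting a fixed args.eval_iters validation iterations, it scales the validation sample count with the length of training, so repeated evaluations over the course of the run cannot exhaust the validation loader. A quick worked example with assumed numbers (2000 train iterations, 100 eval iterations, and a global batch of 32 are illustrative values, not from the commit):

# Sample-count bookkeeping from the hunk above, with assumed values.
train_iters = 2000
args_eval_iters = 100                 # stands in for args.eval_iters
global_batch_size = 32

eval_iters = (train_iters // args_eval_iters + 1) * args_eval_iters   # 2100
test_iters = args_eval_iters                                          # 100

train_val_test_num_samples = [train_iters * global_batch_size,        # 64000
                              eval_iters * global_batch_size,         # 67200
                              test_iters * global_batch_size]         # 3200
print(train_val_test_num_samples)     # [64000, 67200, 3200]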