"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "76bf1e8a4400b90ee3f3dbaa7cf984f1dfc780c9"
Commit 064bdc46 authored by Neel Kant

Fix issue with validation dataloader

parent 72fb0d5c
@@ -35,7 +35,7 @@ class InverseClozeDataset(Dataset):
     def __getitem__(self, idx):
         # get rng state corresponding to index (allows deterministic random pair)
-        rng = random.Random(idx + self.seed)
+        rng = random.Random(idx + 20000 + self.seed)
         np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
         # get seq length. Save 2 tokens for beginning and end
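The seeding pattern above makes each dataset index map to a fixed random pairing, which stays stable across epochs and workers. A minimal standalone sketch of the same idiom (the helper name make_rngs and the values below are illustrative, not from the commit; the 20000 offset keeps the validation stream's seeds from colliding with training's):

import random
import numpy as np

def make_rngs(idx, seed, offset=20000):
    # Hypothetical helper mirroring the __getitem__ pattern above:
    # seeding from (idx + offset + seed) makes the "random" pair
    # for a given index fully deterministic.
    rng = random.Random(idx + offset + seed)
    # RandomState accepts a sequence of ints as its seed; 16 draws
    # give a well-mixed NumPy state derived from the Python RNG.
    np_rng = np.random.RandomState(
        seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
    return rng, np_rng

# Same index and seed -> identical streams, regardless of call order.
rng_a, np_a = make_rngs(7, seed=1234)
rng_b, np_b = make_rngs(7, seed=1234)
assert rng_a.random() == rng_b.random()
assert np_a.randint(0, 100) == np_b.randint(0, 100)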
@@ -98,10 +98,9 @@ def pretrain(train_val_test_data_provider, model_provider, forward_step_func,
     iteration = 0
     if args.do_train and args.train_iters > 0:
-        if args.do_train:
-            iteration, _ = train(forward_step_func,
-                                 model, optimizer, lr_scheduler,
-                                 train_data_iterator, val_data_iterator)
+        iteration, _ = train(forward_step_func,
+                             model, optimizer, lr_scheduler,
+                             train_data_iterator, val_data_iterator)

     if args.do_valid:
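The inner if args.do_train: guard was redundant given that the outer condition already checks args.do_train, so it is removed and the train(...) call dedented one level; behavior is unchanged when training is enabled.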
@@ -485,8 +484,8 @@ def get_train_val_test_data_iterators(train_data, val_data, test_data):
     if val_data is not None:
         start_iter_val = (args.iteration // args.eval_interval) * \
                          args.eval_iters
-        val_data.batch_sampler.start_iter = start_iter_val % \
-                                            len(val_data)
+        val_data.batch_sampler.start_iter = 0
         print_rank_0('setting validation data start iteration to {}'.
                      format(val_data.batch_sampler.start_iter))
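Pinning start_iter to 0 means the validation loader always restarts its stream from the beginning on resume, rather than fast-forwarding by a computed offset; with the replacement sampler introduced below, there is no fixed epoch position to fast-forward to. For context, a start_iter-style batch sampler generally works like this minimal sketch (an illustration of the mechanism, not Megatron's actual DistributedBatchSampler):

from torch.utils.data import Sampler

class SkippingBatchSampler(Sampler):
    # Illustrative only: yields fixed-size batches of indices and
    # skips the first `start_iter` batches so a resumed job does
    # not replay batches it already consumed.
    def __init__(self, sampler, batch_size, start_iter=0):
        self.sampler = sampler
        self.batch_size = batch_size
        self.start_iter = start_iter

    def __iter__(self):
        batch, produced = [], 0
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                if produced >= self.start_iter:
                    yield batch
                produced += 1
                batch = []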
@@ -24,7 +24,7 @@ from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
-from megatron.data.samplers import DistributedBatchSampler
+from megatron.data.samplers import DistributedBatchSampler, RandomSampler
 from megatron.fp16 import FP16_Optimizer
@@ -102,12 +102,16 @@ def make_data_loader(dataset):
     num_workers = args.num_workers

     # Use a simple sampler with distributed batch sampler.
-    sampler = torch.utils.data.SequentialSampler(dataset)
+    #sampler = torch.utils.data.SequentialSampler(dataset)
+    sampler = RandomSampler(dataset,
+                            replacement=True,
+                            num_samples=global_batch_size*args.train_iters)
     batch_sampler = DistributedBatchSampler(sampler=sampler,
                                             batch_size=global_batch_size,
                                             drop_last=True,
                                             rank=rank,
-                                            world_size=world_size)
+                                            world_size=world_size,
+                                            wrap_last=True)

     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
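The replacement sampler draws a fixed budget of num_samples indices with replacement, so the loader's length is set by the training schedule rather than the dataset size. torch's built-in RandomSampler exposes the same replacement/num_samples interface, which this standalone sketch uses (assuming the megatron.data.samplers.RandomSampler imported above mirrors those semantics; the toy sizes are illustrative):

import torch
from torch.utils.data import RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))    # tiny toy dataset
global_batch_size, train_iters = 4, 3        # assumed values

sampler = RandomSampler(dataset,
                        replacement=True,
                        num_samples=global_batch_size * train_iters)

# With replacement=True, exactly num_samples indices are drawn,
# independent of len(dataset), so the stream never runs dry mid-run.
assert len(list(sampler)) == 12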
@@ -115,7 +115,7 @@ def get_train_val_test_data():
     # Number of train/valid/test samples.
     train_iters = args.train_iters
-    eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
+    eval_iters = args.eval_iters
     test_iters = args.eval_iters
     train_val_test_num_samples = [train_iters * global_batch_size,
                                   eval_iters * global_batch_size,
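Since eval_iters feeds train_val_test_num_samples, this shrinks the requested validation sample count: the old formula budgeted enough batches for every evaluation pass over the whole run, while the new one budgets a single pass, which suffices now that the sampler wraps and validation restarts from iteration 0. A worked comparison with assumed values (not from the commit):

train_iters, eval_interval, eval_iters = 10000, 1000, 100
old_budget = (train_iters // eval_interval + 1) * eval_iters  # 1100 batches' worth
new_budget = eval_iters                                       # 100 batches' worth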
@@ -159,7 +159,7 @@ def get_train_val_test_data():
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()

-    return train_data, val_data, test_data
+    return train_data, valid_data, test_data


 if __name__ == "__main__":
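The return value changes from val_data to valid_data, which suggests the corresponding local variable was renamed in a collapsed part of this diff; only the return statement is visible here.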