Commit fcc500d6 authored by Neel Kant's avatar Neel Kant
Browse files

Ran and cleaned up

parent c044f59a
......@@ -25,7 +25,6 @@ from torch.utils.data import Dataset
from megatron import get_tokenizer, get_args
from megatron import mpu
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.dataset_utils import build_training_sample
from megatron import print_rank_0
......@@ -62,6 +61,8 @@ class BertDataset(Dataset):
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
from megatron.data.dataset_utils import build_training_sample
self.build_sample_fn = build_training_sample
def __len__(self):
return self.samples_mapping.shape[0]
......@@ -72,13 +73,13 @@ class BertDataset(Dataset):
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng)
return self.build_sample_fn(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
......
......@@ -23,7 +23,7 @@ import itertools
import numpy as np
from megatron import print_rank_0, get_args
from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_
DATASET_TYPES = ['standard_bert', 'ict']
......@@ -426,8 +426,9 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
data_impl,
skip_warmup)
if dataset_type in ['ict', 'realm']:
title_dataset = get_indexed_dataset_(data_prefix + '-titles',
if dataset_type in ['ict']:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl,
skip_warmup)
......@@ -455,6 +456,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
print_split_stats('test', 2)
def build_dataset(index, name):
from megatron.data.bert_dataset import BertDataset
from megatron.data.realm_dataset import ICTDataset
dataset = None
if splits[index + 1] > splits[index]:
......
......@@ -5,6 +5,7 @@ import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
class ICTDataset(Dataset):
......
......@@ -79,8 +79,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
#assert counts[0].item() == torch.distributed.get_world_size(
# group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
......
import os
import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.checkpointing import load_checkpoint, get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from megatron.model.utils import get_linear_layer, init_method_normal
from megatron.module import MegatronModule
from megatron.utils import report_memory
from megatron import mpu
......
Markdown is supported
0% — attach a file by dragging and dropping it here, or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.