"vscode:/vscode.git/clone" did not exist on "af7b1c3bf233346e9bf100d0dc6fc5f752bc5cca"
Commit fcc500d6 authored by Neel Kant

Ran and cleaned up

parent c044f59a
megatron/data/bert_dataset.py
@@ -25,7 +25,6 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer, get_args
 from megatron import mpu
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
-from megatron.data.dataset_utils import build_training_sample
 from megatron import print_rank_0
@@ -62,6 +61,8 @@ class BertDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
+        from megatron.data.dataset_utils import build_training_sample
+        self.build_sample_fn = build_training_sample

     def __len__(self):
         return self.samples_mapping.shape[0]
@@ -72,7 +73,7 @@ class BertDataset(Dataset):
         # Note that this rng state should be numpy and not python since
         # python randint is inclusive whereas the numpy one is exclusive.
         np_rng = np.random.RandomState(seed=(self.seed + idx))
-        return build_training_sample(sample, seq_length,
+        return self.build_sample_fn(sample, seq_length,
                                      self.max_seq_length,  # needed for padding
                                      self.vocab_id_list,
                                      self.vocab_id_to_token_dict,
...
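Note on the hunks above: the module-level import of `build_training_sample` is replaced by a deferred import inside `BertDataset.__init__`, and `__getitem__` now calls it through `self.build_sample_fn`. A plausible motivation (an assumption; the commit message does not say) is breaking the import cycle with `megatron.data.dataset_utils`, which itself imports from `bert_dataset` in the next file. A minimal sketch of the pattern, using hypothetical modules `mod_a.py` and `mod_b.py`:

```python
# mod_a.py (hypothetical) -- imports mod_b at module scope.
from mod_b import make_thing

def helper(x):
    return x + 1

# mod_b.py (hypothetical) -- a top-level `from mod_a import helper`
# here would complete a circular import and fail; deferring it into
# __init__ means mod_a is fully initialized by the time it runs.
class Thing:
    def __init__(self):
        from mod_a import helper
        self.helper_fn = helper  # cache the late-bound function, as BertDataset does

    def run(self, x):
        return self.helper_fn(x)

def make_thing():
    return Thing()
```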
megatron/data/dataset_utils.py
@@ -23,7 +23,7 @@ import itertools
 import numpy as np
 from megatron import print_rank_0, get_args
-from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
+from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_

 DATASET_TYPES = ['standard_bert', 'ict']
@@ -426,8 +426,9 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                              data_impl,
                                              skip_warmup)
-    if dataset_type in ['ict', 'realm']:
-        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
+    if dataset_type in ['ict']:
+        args = get_args()
+        title_dataset = get_indexed_dataset_(args.titles_data_path,
                                              data_impl,
                                              skip_warmup)
@@ -455,6 +456,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
+        from megatron.data.bert_dataset import BertDataset
         from megatron.data.realm_dataset import ICTDataset
         dataset = None
         if splits[index + 1] > splits[index]:
...
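The second hunk above stops deriving the title dataset path from `data_prefix + '-titles'` and instead reads it from `args.titles_data_path`, so the titles corpus can live anywhere on disk. A sketch of what the corresponding flag registration might look like (the parser group and file in Megatron's argument handling are assumptions here; only the attribute name `titles_data_path` comes from the diff):

```python
import argparse

# Hypothetical registration of the flag consumed via get_args().titles_data_path.
parser = argparse.ArgumentParser()
parser.add_argument('--titles-data-path', type=str, default=None,
                    help='Prefix of the indexed title dataset used by ICT.')

# argparse maps --titles-data-path to args.titles_data_path automatically.
args = parser.parse_args(['--titles-data-path', 'my-corpus-titles'])
assert args.titles_data_path == 'my-corpus-titles'
```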
megatron/data/realm_dataset.py
@@ -5,6 +5,7 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
+from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list


 class ICTDataset(Dataset):
...
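The new import pulls `get_block_samples_mapping` and `join_str_list` into the ICT dataset module. `join_str_list` is presumably a small helper for stitching WordPiece tokens back into readable text; a sketch of that behavior, assuming BERT-style `##` continuation prefixes (the body below is an illustration, not the confirmed implementation):

```python
def join_str_list(str_list):
    """Join WordPiece tokens, gluing '##' continuations onto the previous token."""
    # Assumed behavior, e.g. ['por', '##cu', '##pine'] -> ' porcupine'.
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result
```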
megatron/data/realm_dataset_utils.py
@@ -79,8 +79,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
         # parallel case
         counts = torch.cuda.LongTensor([1])
         torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-        #assert counts[0].item() == torch.distributed.get_world_size(
-        #    group=mpu.get_data_parallel_group())
+        assert counts[0].item() == torch.distributed.get_world_size(
+            group=mpu.get_data_parallel_group())

     # Load indexed dataset.
     print_rank_0(' > loading indexed mapping from {}'.format(
...
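The hunk above re-enables the sanity check that runs after the samples mapping is built: every rank contributes 1 to an all-reduce over the data-parallel group, so the sum must equal the group's world size before any rank proceeds to load the mapping file. The same pattern in isolation (a sketch assuming `torch.distributed` is already initialized with a CUDA-capable backend; the helper name is made up):

```python
import torch
import torch.distributed as dist

def barrier_with_count(group=None):
    # Each rank contributes 1; after the all-reduce the sum equals the
    # number of ranks that reached this point, which must be all of them.
    counts = torch.cuda.LongTensor([1])
    dist.all_reduce(counts, group=group)
    assert counts[0].item() == dist.get_world_size(group=group)
```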
 import os
-import numpy as np
 import torch
-import torch.nn.functional as F

 from megatron import get_args
-from megatron.checkpointing import load_checkpoint, get_checkpoint_tracker_filename, get_checkpoint_name
+from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
-from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.model import BertModel
-from megatron.model.utils import get_linear_layer, init_method_normal
 from megatron.module import MegatronModule
-from megatron.utils import report_memory
 from megatron import mpu
...