Commit 44860f8d authored by Neel Kant

Changes after running on draco

parent a00b3c79
@@ -349,6 +349,8 @@ def _add_data_args(parser):
                        help='Path to combined dataset to split.')
     group.add_argument('--titles-data-path', type=str, default=None,
                        help='Path to titles dataset used for ICT')
+    group.add_argument('--block-data-path', type=str, default=None,
+                       help='Path for loading and saving block data')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
...
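Note: the new flag follows the usual argparse convention, so it is read back as args.block_data_path. A minimal sketch of the behavior, outside of Megatron (the example path is illustrative only):

    import argparse

    # Reproduce just the new argument in isolation.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('data')
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Path for loading and saving block data')

    # Hyphens in the flag become underscores on the namespace.
    args = parser.parse_args(['--block-data-path', '/path/to/block_data.pkl'])
    assert args.block_data_path == '/path/to/block_data.pkl'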
@@ -6,14 +6,14 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
-from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
+from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list


 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
-                 num_epochs, max_num_samples, max_seq_length,
-                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
+                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
+                 short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
@@ -26,7 +26,7 @@ class ICTDataset(Dataset):
         self.samples_mapping = get_block_samples_mapping(
             block_dataset, title_dataset, data_prefix, num_epochs,
-            max_num_samples, max_seq_length, seed, name)
+            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
         self.tokenizer = get_tokenizer()
         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
         self.vocab_id_to_token_list = self.tokenizer.inv_vocab
@@ -50,7 +50,7 @@ class ICTDataset(Dataset):
             title = None
             title_pad_offset = 2
         block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
-        assert len(block) > 1
+        assert len(block) > 1 or self.query_in_block_prob == 1

         # randint() is inclusive for Python rng
         rand_sent_idx = self.rng.randint(0, len(block) - 1)
...
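A note on the relaxed assertion: with use_one_sent_docs enabled, a block may consist of a single sentence, in which case the query and the block coincide. That is only consistent when the query is always kept in the block, i.e. query_in_block_prob == 1. A small sketch of the invariant (function and variable names here are illustrative, not from the commit):

    import random

    rng = random.Random(1234)

    def pick_query_sentence(block, query_in_block_prob):
        # A one-sentence block only makes sense if the query is always
        # drawn from, and kept in, the block itself.
        assert len(block) > 1 or query_in_block_prob == 1
        # randint() is inclusive on both ends for Python's rng.
        return rng.randint(0, len(block) - 1)

    # Multi-sentence block: valid for any probability.
    pick_query_sentence(['sent A', 'sent B', 'sent C'], query_in_block_prob=0.1)
    # One-sentence doc: only allowed when query_in_block_prob == 1.
    pick_query_sentence(['only sentence'], query_in_block_prob=1)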
@@ -46,10 +46,11 @@ class BlockSamplesMapping(object):
         # make sure that the array is compatible with BlockSampleData
         assert mapping_array.shape[1] == 4
         self.mapping_array = mapping_array
+        self.shape = self.mapping_array.shape

     def __getitem__(self, idx):
         """Get the data associated with a particular sample."""
-        sample_data = BlockSamplesData(*self.mapping_array[idx])
+        sample_data = BlockSampleData(*self.mapping_array[idx])
         return sample_data
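The renamed call matters because __getitem__ star-unpacks one four-column row of the mapping array into a sample record; the old name BlockSamplesData does not resolve, so indexing would raise NameError. A minimal self-contained sketch of the relationship, assuming a four-field record (the field names are illustrative, not taken from this commit):

    import numpy as np

    class BlockSampleData(object):
        """One row of the samples mapping (field names assumed for illustration)."""
        def __init__(self, start_idx, end_idx, doc_idx, block_idx):
            self.start_idx = start_idx
            self.end_idx = end_idx
            self.doc_idx = doc_idx
            self.block_idx = block_idx

    class BlockSamplesMapping(object):
        def __init__(self, mapping_array):
            # make sure that the array is compatible with BlockSampleData
            assert mapping_array.shape[1] == 4
            self.mapping_array = mapping_array
            self.shape = self.mapping_array.shape

        def __getitem__(self, idx):
            # Star-unpacking a 4-column row requires the class name to
            # resolve -- hence the BlockSamplesData -> BlockSampleData fix.
            return BlockSampleData(*self.mapping_array[idx])

    mapping = BlockSamplesMapping(np.arange(8, dtype=np.int64).reshape(2, 4))
    sample = mapping[1]
    assert (sample.start_idx, sample.block_idx) == (4, 7)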
@@ -113,10 +114,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
                 seed,
                 verbose,
                 use_one_sent_docs)
-            samples_mapping = BlockSamplesMapping(mapping_array)

             print_rank_0(' > done building samples index mapping')
-            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+            np.save(indexmap_filename, mapping_array, allow_pickle=True)
             print_rank_0(' > saved the index mapping in {}'.format(
                 indexmap_filename))

     # Make sure all the ranks have built the mapping
@@ -136,7 +137,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     print_rank_0(' > loading indexed mapping from {}'.format(
         indexmap_filename))
     start_time = time.time()
-    samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+
+    mapping_array = np.load(indexmap_filename, allow_pickle=True)
+    samples_mapping = BlockSamplesMapping(mapping_array)
+
     print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
         time.time() - start_time))
     print_rank_0(' total number of samples: {}'.format(
...
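The save/load change makes the on-disk artifact a plain ndarray rather than a pickled BlockSamplesMapping object, so loading the index no longer depends on unpickling a custom class; the wrapper is rebuilt after np.load instead. A rough sketch of the round trip under that reading (file name illustrative, class reduced to the fields shown in the diff):

    import numpy as np

    class BlockSamplesMapping(object):
        """Thin wrapper over the mapping array (reduced for illustration)."""
        def __init__(self, mapping_array):
            assert mapping_array.shape[1] == 4
            self.mapping_array = mapping_array
            self.shape = self.mapping_array.shape

    indexmap_filename = 'samples_mapping.npy'  # illustrative name
    mapping_array = np.zeros((10, 4), dtype=np.int64)

    # Save the raw ndarray, not the wrapper object: the .npy file then has
    # no dependency on the BlockSamplesMapping class being importable.
    np.save(indexmap_filename, mapping_array, allow_pickle=True)

    # Rebuild the wrapper on the way back in.
    loaded_array = np.load(indexmap_filename, allow_pickle=True)
    samples_mapping = BlockSamplesMapping(loaded_array)
    print(' total number of samples: {}'.format(samples_mapping.shape[0]))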