".github/vscode:/vscode.git/clone" did not exist on "c825c01920b37e45fe38e964632f6f8b1de2081c"
Commit 44860f8d authored by Neel Kant's avatar Neel Kant
Browse files

Changes after running on draco

parent a00b3c79
......@@ -349,6 +349,8 @@ def _add_data_args(parser):
help='Path to combined dataset to split.')
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--block-data-path', type=str, default=None,
help='Path for loading and saving block data')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
......
......@@ -6,14 +6,14 @@ import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
query_in_block_prob, short_seq_prob, seed, use_titles=True):
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
......@@ -26,7 +26,7 @@ class ICTDataset(Dataset):
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name)
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
......@@ -50,7 +50,7 @@ class ICTDataset(Dataset):
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1
assert len(block) > 1 or self.query_in_block_prob == 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
......
......@@ -46,10 +46,11 @@ class BlockSamplesMapping(object):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
self.shape = self.mapping_array.shape
def __getitem__(self, idx):
"""Get the data associated with a particular sample."""
sample_data = BlockSamplesData(*self.mapping_array[idx])
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
......@@ -113,10 +114,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
seed,
verbose,
use_one_sent_docs)
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
......@@ -136,7 +137,10 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True)
mapping_array = np.load(indexmap_filename, allow_pickle=True)
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment