Commit 03feecbc authored by Neel Kant's avatar Neel Kant
Browse files

One sentence options

parent 76928caa
......@@ -386,6 +386,8 @@ def _add_data_args(parser):
help='Mask loss for the end of document tokens.')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--ict-one-sent', action='store_true',
help='Whether to use one sentence documents in ICT')
return parser
......
......@@ -427,6 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
block_dataset=indexed_dataset,
title_dataset=title_dataset,
query_in_block_prob=args.query_in_block_prob,
use_one_sent_docs=args.ict_one_sent,
**kwargs
)
else:
......
......@@ -5,14 +5,14 @@ import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
from megatron.data.realm_dataset_utils import get_block_samples_mapping
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
query_in_block_prob, short_seq_prob, seed, use_titles=True):
query_in_block_prob, short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
......@@ -22,10 +22,11 @@ class ICTDataset(Dataset):
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name)
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
......@@ -47,7 +48,7 @@ class ICTDataset(Dataset):
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1
assert len(block) > 1 or self.use_one_sent_docs
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
......
......@@ -91,7 +91,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True)
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment