"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "6c053a6cea38ad323047d92da99d38e740f19845"
Commit 03feecbc authored by Neel Kant

One sentence options

parent 76928caa
@@ -386,6 +386,8 @@ def _add_data_args(parser):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
+    group.add_argument('--ict-one-sent', action='store_true',
+                       help='Whether to use one sentence documents in ICT')
     return parser
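For reference, action='store_true' makes the new flag default to False and flip to True only when passed on the command line; argparse also converts the dashes into an underscore attribute name. A minimal self-contained sketch (plain argparse standing in for Megatron's argument group):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--ict-one-sent', action='store_true',
                        help='Whether to use one sentence documents in ICT')

    # Dashes become underscores: the value lands on args.ict_one_sent.
    assert parser.parse_args([]).ict_one_sent is False
    assert parser.parse_args(['--ict-one-sent']).ict_one_sent is True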
@@ -427,6 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             block_dataset=indexed_dataset,
             title_dataset=title_dataset,
             query_in_block_prob=args.query_in_block_prob,
+            use_one_sent_docs=args.ict_one_sent,
             **kwargs
         )
     else:
@@ -5,14 +5,14 @@ import numpy as np
 from torch.utils.data import Dataset
 
 from megatron import get_tokenizer
-from megatron.data.realm_dataset_utils import get_block_samples_mapping, join_str_list
+from megatron.data.realm_dataset_utils import get_block_samples_mapping
 
 
 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
-                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
+                 query_in_block_prob, short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
@@ -22,10 +22,11 @@ class ICTDataset(Dataset):
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)
         self.use_titles = use_titles
+        self.use_one_sent_docs = use_one_sent_docs
 
         self.samples_mapping = get_block_samples_mapping(
             block_dataset, title_dataset, data_prefix, num_epochs,
-            max_num_samples, max_seq_length, seed, name)
+            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
         self.tokenizer = get_tokenizer()
         self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
         self.vocab_id_to_token_list = self.tokenizer.inv_vocab
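The flag is forwarded to get_block_samples_mapping as a new trailing argument. A hedged sketch of how such a parameter stays backward compatible; the simplified signature body and the min_doc_sents filtering below are assumptions for illustration, not the actual Megatron implementation:

    def get_block_samples_mapping(block_dataset, title_dataset, data_prefix,
                                  num_epochs, max_num_samples, max_seq_length,
                                  seed, name, use_one_sent_docs=False):
        # Assumed behavior: admit single-sentence documents only when the
        # flag is set; callers that omit the argument keep the old
        # two-sentence minimum.
        min_doc_sents = 1 if use_one_sent_docs else 2
        return [doc for doc in block_dataset if len(doc) >= min_doc_sents]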
@@ -47,7 +48,7 @@ class ICTDataset(Dataset):
             title = None
             title_pad_offset = 2
         block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
-        assert len(block) > 1
+        assert len(block) > 1 or self.use_one_sent_docs
 
         # randint() is inclusive for Python rng
         rand_sent_idx = self.rng.randint(0, len(block) - 1)
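The relaxed assertion matters because a one-sentence document yields a block of length 1, so the query is necessarily the whole block. A standalone sketch of that edge case, with illustrative sentence strings in place of token arrays:

    import random

    rng = random.Random(1234)
    block = ['only sentence in this document']

    # randint() is inclusive, so for a one-sentence block this is randint(0, 0).
    rand_sent_idx = rng.randint(0, len(block) - 1)
    query = block[rand_sent_idx]

    # If the query is removed from the block, the remaining context is empty,
    # which is why such documents were rejected before use_one_sent_docs.
    context = block[:rand_sent_idx] + block[rand_sent_idx + 1:]
    print(query, context)  # 'only sentence in this document' []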
@@ -91,7 +91,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
         print_rank_0(' > loading indexed mapping from {}'.format(
             indexmap_filename))
         start_time = time.time()
-        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+        samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
         print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
             time.time() - start_time))
         print_rank_0(' total number of samples: {}'.format(
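Passing mmap_mode='r' makes np.load memory-map the .npy file read-only instead of copying the whole mapping into RAM, so a large index map opens almost instantly and only the rows actually touched are paged in. (Memory mapping applies to plain arrays; object arrays that genuinely require pickling cannot be mapped.) A throwaway sketch using a hypothetical 'mapping.npy':

    import numpy as np

    arr = np.arange(1_000_000, dtype=np.int64).reshape(-1, 4)
    np.save('mapping.npy', arr)

    samples_mapping = np.load('mapping.npy', mmap_mode='r')
    print(type(samples_mapping))   # <class 'numpy.memmap'>
    print(samples_mapping[123])    # reads only the touched page from disk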