"src/vscode:/vscode.git/clone" did not exist on "22e8b2ca2b38ff924ca1cbf82ae7e34b51d1d61a"
Commit 0f5e2809 authored by Neel Kant

Resolve internal merge conflict

parent f2094783
@@ -365,6 +365,8 @@ def _add_data_args(parser):
                        'end-of-document token.')
     group.add_argument('--eod-mask-loss', action='store_true',
                        help='Mask loss for the end of document tokens.')
+    group.add_argument('--query-in-block-prob', type=float, default=0.1,
+                       help='Probability of keeping query in block for ICT dataset')
     return parser
......
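The new flag flows through Megatron's global argument parser, so any code that runs after initialization can read it with get_args(). A minimal sketch of that flow; the initialize_megatron() call is illustrative context for how the training scripts set up the globals, not part of this diff:

from megatron import get_args
from megatron.initialize import initialize_megatron

# Hedged sketch: after initialize_megatron() parses the command line, the flag
# is available on the global args namespace. Without --query-in-block-prob on
# the command line, the value is the 0.1 default added above.
initialize_megatron()
args = get_args()
print(args.query_in_block_prob)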
@@ -22,13 +22,12 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
-from megatron import get_tokenizer
+from megatron import get_tokenizer, get_args
 from megatron import mpu
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
 class BertDataset(Dataset):
     def __init__(self, name, indexed_dataset, data_prefix,
......
@@ -22,7 +22,7 @@ import collections
 import itertools
 import numpy as np
-from megatron import print_rank_0
+from megatron import print_rank_0, get_args
 from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
 DATASET_TYPES = ['standard_bert', 'ict', 'realm']
@@ -478,9 +478,11 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         )
     if dataset_type == 'ict':
+        args = get_args()
         dataset = ICTDataset(
             block_dataset=indexed_dataset,
             title_dataset=title_dataset,
+            query_in_block_prob=args.query_in_block_prob,
             **kwargs
         )
     elif dataset_type == 'realm':
......
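With this change the ICT branch reads the probability from the global args rather than expecting it inside kwargs. A rough illustration of the resulting constructor call; the keyword values below are placeholders standing in for whatever kwargs actually carries at this point, not values taken from this commit:

args = get_args()
dataset = ICTDataset(
    block_dataset=indexed_dataset,      # sentence-level indexed dataset
    title_dataset=title_dataset,        # matching document titles
    query_in_block_prob=args.query_in_block_prob,  # from --query-in-block-prob
    # the remaining arguments normally arrive via **kwargs; example values only
    name='train',
    data_prefix='my-corpus_text_sentence',
    num_epochs=None,
    max_num_samples=10000,
    max_seq_length=288,
    short_seq_prob=0.1,
    seed=1234,
)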
import itertools
import random
import os
import time
import numpy as np
import torch
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import mpu
from megatron.data import helpers
class InverseClozeDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length,
                 query_in_block_prob, short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.samples_mapping = self.get_samples_mapping(
            data_prefix, num_epochs, max_num_samples)

        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]
    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
        title = list(self.title_dataset[int(doc_idx)])
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

        # avoid selecting the first or last sentence to be the query.
        if len(block) == 2:
            rand_sent_idx = int(self.rng.random() > 0.5)
        else:
            rand_sent_idx = self.rng.randint(1, len(block) - 2)

        # keep the query inside the block with probability query_in_block_prob
        # (default 0.1); otherwise remove it from the block.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because a block is only concluded once
        # its sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
            'query_tokens': np.array(query_tokens),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_pad_mask': np.array(block_pad_mask),
            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
        }
        return sample
    def encode_text(self, text):
        return self.tokenizer.tokenize(text)

    def decode_tokens(self, token_ids):
        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
        return ' '.join(token for token in tokens if token != '[PAD]')

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        title = list(self.title_dataset[int(doc_idx)])
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        return (block_tokens, block_pad_mask)
    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = [self.cls_id] + tokens + [self.sep_id]
        if title is not None:
            tokens += title + [self.sep_id]
        assert len(tokens) <= self.max_seq_length, len(tokens)

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return tokens, pad_mask
    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if not exist.
        if torch.distributed.get_rank() == 0 and \
           not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))

            # Make sure the types match the helpers input types.
            assert self.block_dataset.doc_idx.dtype == np.int64
            assert self.block_dataset.sizes.dtype == np.int32

            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(
                self.name))
            samples_mapping = helpers.build_blocks_mapping(
                self.block_dataset.doc_idx,
                self.block_dataset.sizes,
                self.title_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length - 3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(
                indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(
                             time.time() - start_time))

        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(
            indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
            time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(
            samples_mapping.shape[0]))

        return samples_mapping
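Each item returned by __getitem__ above is a dict of fixed-length numpy arrays, with the query and the block (plus title) padded independently to max_seq_length. A hedged sketch of unpacking one sample; `dataset` is assumed to be an already-constructed InverseClozeDataset:

# Illustrative only: `dataset` is an InverseClozeDataset built elsewhere.
sample = dataset[0]
query_tokens = sample['query_tokens']      # [max_seq_length] ids: [CLS] query [SEP] [PAD]...
query_pad_mask = sample['query_pad_mask']  # 1 for real tokens, 0 for padding
block_tokens = sample['block_tokens']      # [CLS] block [SEP] title [SEP] [PAD]...
block_pad_mask = sample['block_pad_mask']
start, end, doc, blk = sample['block_data']  # indices back into the block/title datasets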
@@ -4,7 +4,7 @@ import random
 import time
 import numpy as np
-import spacy
+# import spacy
 import torch
 from torch.utils.data import Dataset
@@ -38,7 +38,7 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
     return train_sample
-qa_nlp = spacy.load('en_core_web_lg')
+# qa_nlp = spacy.load('en_core_web_lg')
 def salient_span_mask(tokens, vocab_id_list, vocab_id_to_token_dict,
@@ -357,10 +357,11 @@ class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
-                 short_seq_prob, seed, use_titles=True):
+                 query_in_block_prob, short_seq_prob, seed, use_titles=True):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
+        self.query_in_block_prob = query_in_block_prob
         self.block_dataset = block_dataset
         self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob
@@ -394,7 +395,7 @@ class ICTDataset(Dataset):
         rand_sent_idx = self.rng.randint(0, len(block) - 1)
         # keep the query in the context 10% of the time.
-        if self.rng.random() < 1:
+        if self.rng.random() < self.query_in_block_prob:
             query = block[rand_sent_idx].copy()
         else:
             query = block.pop(rand_sent_idx)
......
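The last hunk above is the substantive fix: self.rng.random() < 1 is always true, so the query was previously never removed from its block; with the flag wired through, it is kept inside the block only with probability query_in_block_prob (0.1 by default) and popped out otherwise. A quick, purely illustrative check of that expectation:

import random

rng = random.Random(1234)
query_in_block_prob = 0.1  # the new default from --query-in-block-prob
kept = sum(rng.random() < query_in_block_prob for _ in range(100000))
print(kept / 100000)  # ~0.1: the query stays in its block about 10% of the time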
@@ -3,7 +3,6 @@ import os
 import pickle
 import shutil
-import faiss
 import numpy as np
 import torch
@@ -103,6 +102,7 @@ class FaissMIPSIndex(object):
         return new_index
     def get_block_index(self):
+        import faiss
         INDEX_TYPES = ['flat_l2', 'flat_ip']
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")
......
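Moving import faiss from module scope into get_block_index is the usual lazy-import pattern: the module stays importable on machines where faiss is not installed, and the dependency is only required when an index is actually built. A generic sketch of the pattern, not the project's actual class:

class LazyFaissIndex(object):
    """Illustrative only: defer an optional heavy dependency until first use."""

    def __init__(self, embed_size):
        self.embed_size = embed_size

    def build(self):
        # ImportError is raised here, and only here, if faiss is missing.
        import faiss
        return faiss.IndexFlatIP(self.embed_size)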