Commit 7504ef44 authored by Neel Kant

Add ICT qualitative test script

parent 064bdc46
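For orientation: the inverse cloze task pairs a single sentence with the passage it came from, so that a query encoder and a context encoder can be trained to embed matching pairs close together. A simplified sketch of the pairing (illustrative names only; the real logic lives in InverseClozeDataset.get_input_and_context in the diff below, which also handles tokenization, sequence limits, and retries):

import random

# Toy inverse-cloze pair: pick one sentence as the "query" and treat the rest
# of the document as its context. The usual ICT formulation drops the chosen
# sentence from the context so the model cannot rely on exact overlap.
doc = ["Sentence A.", "Sentence B.", "Sentence C.", "Sentence D."]
rng = random.Random(0)

sent_idx = rng.randrange(len(doc))
input_sentence = doc[sent_idx]                 # query side
context = doc[:sent_idx] + doc[sent_idx + 1:]  # context side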
@@ -36,14 +36,13 @@ class InverseClozeDataset(Dataset):
def __getitem__(self, idx):
# get rng state corresponding to index (allows deterministic random pair)
rng = random.Random(idx + 20000 + self.seed)
-        np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
# get seq length. Save 2 tokens for beginning and end
target_seq_length = self.max_seq_length - 2
if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(5, target_seq_length)
-        input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
+        input_data, context_data = self.get_input_and_context(idx, target_seq_length, rng)
input_tokens, input_token_types, input_pad_mask = input_data
context_tokens, context_token_types, context_pad_mask = context_data
@@ -79,16 +78,14 @@ class InverseClozeDataset(Dataset):
token_types = [0] * self.max_seq_length
return tokens, token_types, pad_mask
-    def get_input_and_context(self, target_seq_length, rng, np_rng):
+    def get_input_and_context(self, idx, target_seq_length, rng):
"""fetches a sentence and its surrounding context"""
num_tries = 0
while num_tries < 20:
num_tries += 1
doc = None
while doc is None:
-                doc_idx = np_rng.randint(len(self) - 1)
# doc is a list of sentences
-                doc = self.get_sentence_split_doc(doc_idx)
+                doc = self.get_sentence_split_doc(idx)
if not doc:
doc = None
@@ -140,5 +137,3 @@ class InverseClozeDataset(Dataset):
(context_tokens, context_token_types, context_pad_mask)
else:
raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
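A side note on the rng = random.Random(idx + 20000 + self.seed) line in __getitem__ above: seeding a fresh RNG from the sample index is what makes each pair reproducible. A minimal standalone sketch (hypothetical helper name):

import random

def make_rng(idx, seed=1, offset=20000):
    # same construction as in __getitem__: a fresh RNG per sample index
    return random.Random(idx + offset + seed)

# identical index -> identical draws -> the dataset returns the same
# (input, context) pair every time idx is requested
assert make_rng(5).random() == make_rng(5).random()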
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import model_provider
def main():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args()
model = load_checkpoint()
dataset = get_dataset()
    num_docs = 100
    all_doc_logits = np.zeros((num_docs, 128))  # one 128-d embedding per document
    for i in range(num_docs):
        doc_tokens = []
        doc_token_lists = dataset.get_sentence_split_doc(i)
        ptr = 0
        # accumulate whole sentences until the target sequence length is reached
        while len(doc_tokens) < args.seq_length and ptr < len(doc_token_lists):
            doc_tokens.extend(doc_token_lists[ptr])
            ptr += 1
        # leave room for the two special tokens added by concat_and_pad_tokens
        doc_tokens = doc_tokens[:args.seq_length - 2]
        doc_tokens, doc_token_types, doc_pad_mask = dataset.concat_and_pad_tokens(doc_tokens)
        # the context encoder expects batched torch tensors, not numpy arrays
        with torch.no_grad():
            doc_logits = model.embed_doc(torch.tensor([doc_tokens]).cuda(),
                                         torch.tensor([doc_pad_mask]).cuda(),
                                         torch.tensor([doc_token_types]).cuda())
        all_doc_logits[i] = doc_logits.cpu().numpy()[0]
    print(all_doc_logits, flush=True)
def load_checkpoint():
args = get_args()
model = get_model(model_provider)
if isinstance(model, torchDDP):
model = model.module
tracker_filename = get_checkpoint_tracker_filename(args.load)
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
checkpoint_name = get_checkpoint_name(args.load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
model.load_state_dict(state_dict['model'])
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
    model.eval()  # inference only: disable dropout so embeddings are deterministic
    return model
def load_doc_embeds(path):
    # load precomputed doc embeddings; assumes they were saved with np.save
    return np.load(path)
def get_dataset():
args = get_args()
indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
doc_idx_ptr = indexed_dataset.get_doc_idx()
total_num_documents = indexed_dataset.doc_idx.shape[0] - 1
    indexed_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents + 1])  # keep the final boundary pointer
kwargs = dict(
name='full',
indexed_dataset=indexed_dataset,
data_prefix=args.data_path,
num_epochs=None,
max_num_samples=total_num_documents,
max_seq_length=288, # doesn't matter
short_seq_prob=0.01, # doesn't matter
seed=1
)
dataset = InverseClozeDataset(**kwargs)
return dataset
if __name__ == "__main__":
main()
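The printed all_doc_logits array (and the load_doc_embeds helper) point at a retrieval step on top of these embeddings. A sketch of what that might look like, assuming dot-product scoring as in the model's retrieval scores; query_logits stands in for a hypothetical query embedding produced by the other encoder:

import numpy as np

def top_k_docs(query_logits, all_doc_logits, k=5):
    # score every document against the query by dot product and take the best k
    scores = all_doc_logits @ query_logits      # shape: [num_docs]
    return np.argsort(scores)[::-1][:k]         # indices of best-matching docs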
@@ -15,6 +15,7 @@
"""BERT model."""
+import numpy as np
import torch
from megatron import get_args
@@ -242,6 +243,11 @@ class ICTBertModel(MegatronModule):
return retrieval_scores
+    def embed_doc(self, doc_tokens, doc_attention_mask, doc_types):
+        # the pad mask is inverted to the attention convention context_model expects
+        doc_logits, _ = self.context_model.forward(doc_tokens, 1 - doc_attention_mask, doc_types)
+        return doc_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
state_dict_ = {}