Commit 360885ee authored by Neel Kant

Qualitative test prep works

parent 7504ef44
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider
def main():
@@ -17,22 +18,33 @@ def main():
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()

    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    # Collect tokens and ICT logits for a fixed number of query/document batches.
    all_input_tokens = []
    all_input_logits = []
    all_doc_tokens = []
    all_doc_logits = []
    for i in range(100):
        input_tokens, input_types, input_pad_mask, doc_tokens, doc_token_types, doc_pad_mask = get_batch(data_iter)
        input_logits, doc_logits, _ = model.module.module.forward(
            input_tokens, input_types, input_pad_mask, doc_tokens, doc_pad_mask, doc_token_types, return_logits=True)

        all_input_tokens.append(input_tokens.detach().cpu().numpy())
        all_input_logits.append(input_logits.detach().cpu().numpy())
        all_doc_tokens.append(doc_tokens.detach().cpu().numpy())
        all_doc_logits.append(doc_logits.detach().cpu().numpy())

    all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
    all_input_logits = np.array(all_input_logits).reshape(-1, 128)
    all_doc_tokens = np.array(all_doc_tokens).reshape(-1, args.seq_length)
    all_doc_logits = np.array(all_doc_logits).reshape(-1, 128)

    np.save('input_tokens.npy', all_input_tokens)
    np.save('input_logits.npy', all_input_logits)
    np.save('doc_tokens.npy', all_doc_tokens)
    np.save('doc_logits.npy', all_doc_logits)
def load_checkpoint():
@@ -61,10 +73,6 @@ def load_checkpoint():
    return model


def get_dataset():
    args = get_args()
    indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
@@ -79,12 +87,33 @@ def get_dataset():
        num_epochs=None,
        max_num_samples=total_num_documents,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1
    )
    dataset = InverseClozeDataset(**kwargs)
    return dataset


def get_dataloader(dataset):
    args = get_args()
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
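For intuition, here is a small sketch (not from this commit) of how the loader above follows Megatron's data-parallel pattern: the sequential sampler walks the dataset in order, DistributedBatchSampler assembles a global batch of `args.batch_size * world_size` indices, and each rank keeps only its own slice. It assumes only the DistributedBatchSampler signature used above.

import torch
from megatron.data.samplers import DistributedBatchSampler

# Toy dataset of 16 items; pretend batch_size=4 per rank and world_size=2,
# so the global batch assembled by the sampler holds 8 indices.
toy_data = list(range(16))
sampler = torch.utils.data.SequentialSampler(toy_data)
batch_sampler = DistributedBatchSampler(sampler,
                                        batch_size=8,   # global batch size
                                        drop_last=True,
                                        rank=0,         # this rank's slice
                                        world_size=2)

# Rank 0 should see the first half of every global batch,
# e.g. [0, 1, 2, 3], then [8, 9, 10, 11].
for batch in batch_sampler:
    print(batch)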


if __name__ == "__main__":
    main()
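Once the script has run, the saved arrays can be inspected offline to judge retrieval quality. A minimal illustrative sketch of that follow-up step, using only the files and shapes written above (the diagonal-pairing assumption reflects how ICT batches pair each query sentence with its source block, and is not stated in the commit itself):

import numpy as np

input_logits = np.load('input_logits.npy')  # [num_queries, 128]
doc_logits = np.load('doc_logits.npy')      # [num_docs, 128]

# Same scoring rule as ICTBertModel.forward: dot product of query and doc embeddings.
retrieval_scores = input_logits @ doc_logits.T

# Query i should come from the block at row i, so the diagonal is the "correct" match;
# top-1 accuracy gives a rough qualitative signal, and the top-ranked rows can be
# decoded from doc_tokens.npy for manual inspection.
top1 = retrieval_scores.argmax(axis=1)
print('top-1 accuracy:', (top1 == np.arange(len(top1))).mean())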
@@ -29,6 +29,7 @@ class InverseClozeDataset(Dataset):
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad
        self.offset = 0

    def __len__(self):
        return self.indexed_dataset.doc_idx.shape[0]
@@ -85,9 +86,10 @@
        num_tries += 1

        doc = None
        while doc is None:
            doc = self.get_sentence_split_doc(idx + self.offset)
            if not doc:
                doc = None
                self.offset += 1

        num_sentences = len(doc)
        padless_max_len = self.max_seq_length - 2
@@ -97,6 +99,7 @@
            input_sentence_idx = rng.randint(0, num_sentences - 1)
            input_tokens = doc[input_sentence_idx][:target_seq_length]
            if not len(input_tokens) > 0:
                self.offset += 1
                continue

            context_tokens = []
@@ -127,6 +130,7 @@
            # assemble the tokens and token types of the context
            context_tokens = context_tokens[:padless_max_len]
            if not len(context_tokens) > 0:
                self.offset += 1
                continue

            # concatenate 'CLS' and 'SEP' tokens and add extra token types
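The `self.offset` bookkeeping added above retries with a persistent shift whenever a document comes back empty or the sampled sentence/context is empty, so later lookups skip past the bad region instead of hitting it again. A stripped-down sketch of that pattern in isolation (illustrative; the names here are placeholders, not part of the dataset class):

# Standalone sketch of the offset-and-retry pattern used in InverseClozeDataset.
# 'docs' stands in for get_sentence_split_doc(); empty lists play the role of bad docs.
docs = [['sent a'], [], [], ['sent b'], ['sent c']]

class SkippingLookup:
    def __init__(self, docs):
        self.docs = docs
        self.offset = 0  # persists across calls, like InverseClozeDataset.offset

    def get(self, idx):
        doc = None
        while doc is None:
            doc = self.docs[(idx + self.offset) % len(self.docs)]
            if not doc:
                doc = None
                self.offset += 1
        return doc

lookup = SkippingLookup(docs)
print(lookup.get(1))   # skips the two empty docs and returns ['sent b']
print(lookup.offset)   # 2 -- remembered for subsequent lookups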
@@ -233,7 +233,7 @@ class ICTBertModel(MegatronModule):
        self._context_key = 'context_model'

    def forward(self, input_tokens, input_attention_mask, input_types,
                context_tokens, context_attention_mask, context_types, return_logits=False):
        question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
        context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
@@ -241,12 +241,11 @@
        # [batch x h] * [h x batch]
        retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
        if return_logits:
            return question_ict_logits, context_ict_logits, retrieval_scores
        return retrieval_scores

    def embed_doc(self, doc_tokens, doc_attention_mask, doc_types):
        doc_logits, _ = self.context_model.forward(doc_tokens, 1 - doc_attention_mask, doc_types)
        return doc_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
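A sketch of how the new `embed_doc` hook might be used to build a standalone index of document embeddings with the context tower alone (illustrative: `ict_model` stands for the unwrapped model, e.g. `model.module.module` as in the script above, and `padded_docs` for an iterable of `concat_and_pad_tokens` outputs; neither name is from the commit):

import numpy as np
import torch

doc_embeds = []
for doc_tokens, doc_token_types, doc_pad_mask in padded_docs:
    tokens = torch.cuda.LongTensor([doc_tokens])
    types = torch.cuda.LongTensor([doc_token_types])
    pad_mask = torch.cuda.LongTensor([doc_pad_mask])
    # embed_doc runs only the context model: (tokens, pad/attention mask, token types)
    logits = ict_model.embed_doc(tokens, pad_mask, types)
    doc_embeds.append(logits.detach().cpu().numpy())

doc_embeds = np.concatenate(doc_embeds, axis=0)  # [num_docs, 128]
np.save('doc_embeds.npy', doc_embeds)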