"scripts/deprecated/test_httpserver_decode_stream.py" did not exist on "30db99b3d98cbc4886dc3e35dce0f1658a44939c"
Commit 360885ee authored by Neel Kant

Qualitative test prep works

parent 7504ef44
 import numpy as np
 import torch
-import torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 
 from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
 from megatron.data.ict_dataset import InverseClozeDataset
+from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.training import get_model
-from pretrain_bert_ict import model_provider
+from pretrain_bert_ict import get_batch, model_provider
 
 
 def main():
@@ -17,22 +18,33 @@ def main():
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
     model = load_checkpoint()
+    model.eval()
     dataset = get_dataset()
+    data_iter = iter(get_dataloader(dataset))
 
-    num_docs = 100
-    all_doc_logits = np.zeros(num_docs, 128)
-    for i in range(num_docs):
-        doc_tokens = []
-        doc_token_lists = dataset.get_sentence_split_doc(i)
-        ptr = 0
-        while len(doc_tokens) < args.seq_length and ptr < len(doc_token_lists):
-            doc_tokens.extend(doc_token_lists[ptr])
-        doc_tokens, doc_token_types, doc_pad_mask = dataset.concat_and_pad_tokens(doc_tokens)
-        doc_logits = model.embed_doc(np.array(doc_tokens), np.array(doc_pad_mask), np.array(doc_token_types))
-        all_doc_logits[i] = doc_logits
-
-    print(all_doc_logits, flush=True)
+    all_input_tokens = []
+    all_input_logits = []
+    all_doc_tokens = []
+    all_doc_logits = []
+
+    for i in range(100):
+        input_tokens, input_types, input_pad_mask, doc_tokens, doc_token_types, doc_pad_mask = get_batch(data_iter)
+        input_logits, doc_logits, _ = model.module.module.forward(
+            input_tokens, input_pad_mask, input_types, doc_tokens, doc_pad_mask, doc_token_types, return_logits=True)
+        all_input_tokens.append(input_tokens.detach().cpu().numpy())
+        all_input_logits.append(input_logits.detach().cpu().numpy())
+        all_doc_tokens.append(doc_tokens.detach().cpu().numpy())
+        all_doc_logits.append(doc_logits.detach().cpu().numpy())
+
+    all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
+    all_input_logits = np.array(all_input_logits).reshape(-1, 128)
+    all_doc_tokens = np.array(all_doc_tokens).reshape(-1, args.seq_length)
+    all_doc_logits = np.array(all_doc_logits).reshape(-1, 128)
+    np.save('input_tokens.npy', all_input_tokens)
+    np.save('input_logits.npy', all_input_logits)
+    np.save('doc_tokens.npy', all_doc_tokens)
+    np.save('doc_logits.npy', all_doc_logits)
 def load_checkpoint():
@@ -61,10 +73,6 @@ def load_checkpoint():
     return model
 
-def load_doc_embeds(path):
-    pass
-
 def get_dataset():
     args = get_args()
     indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
@@ -79,12 +87,33 @@ def get_dataset():
         num_epochs=None,
         max_num_samples=total_num_documents,
         max_seq_length=288,  # doesn't matter
-        short_seq_prob=0.01,  # doesn't matter
+        short_seq_prob=0.0001,  # doesn't matter
         seed=1
     )
     dataset = InverseClozeDataset(**kwargs)
     return dataset
 
+
+def get_dataloader(dataset):
+    args = get_args()
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    global_batch_size = args.batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=True,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
 if __name__ == "__main__":
     main()
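
Once the script has run, the four .npy dumps can be inspected offline. Below is a minimal sketch of such a qualitative check; the filenames and the 128-wide logit shape come from the script above, while the ranking logic itself is only illustrative:

import numpy as np

# Arrays written by the script above: token ids are [num_samples x seq_length],
# logits are [num_samples x 128] (the width used in the reshape).
input_logits = np.load('input_logits.npy')
doc_logits = np.load('doc_logits.npy')

# Same dot-product scoring as ICTBertModel.forward: [num_samples x num_samples].
scores = input_logits @ doc_logits.T

# For each query sentence, find where the document it was drawn from
# (same row index) lands in the ranking over all dumped documents.
order = (-scores).argsort(axis=1)
gold_rank = (order == np.arange(len(scores))[:, None]).argmax(axis=1)
print('mean rank of gold document:', gold_rank.mean())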
@@ -29,6 +29,7 @@ class InverseClozeDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
+        self.offset = 0
 
     def __len__(self):
         return self.indexed_dataset.doc_idx.shape[0]
@@ -85,9 +86,10 @@ class InverseClozeDataset(Dataset):
            num_tries += 1
            doc = None
            while doc is None:
-                doc = self.get_sentence_split_doc(idx)
+                doc = self.get_sentence_split_doc(idx + self.offset)
                if not doc:
                    doc = None
+                    self.offset += 1
 
            num_sentences = len(doc)
            padless_max_len = self.max_seq_length - 2
@@ -97,6 +99,7 @@ class InverseClozeDataset(Dataset):
            input_sentence_idx = rng.randint(0, num_sentences - 1)
            input_tokens = doc[input_sentence_idx][:target_seq_length]
            if not len(input_tokens) > 0:
+                self.offset += 1
                continue
 
            context_tokens = []
@@ -127,6 +130,7 @@ class InverseClozeDataset(Dataset):
            # assemble the tokens and token types of the context
            context_tokens = context_tokens[:padless_max_len]
            if not len(context_tokens) > 0:
+                self.offset += 1
                continue
 
            # concatenate 'CLS' and 'SEP' tokens and add extra token types
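
The new self.offset bookkeeping makes the sampler skip forward past empty documents and degenerate samples instead of retrying the same index forever. A standalone sketch of that skip-ahead pattern, written as a hypothetical helper rather than the class itself:

def first_usable_doc(get_doc, idx, offset=0):
    # Walk forward from idx + offset until a non-empty document is found,
    # returning it together with the updated offset (mirrors self.offset above).
    doc = None
    while doc is None:
        doc = get_doc(idx + offset)
        if not doc:
            doc = None
            offset += 1
    return doc, offset

# Toy example: document 1 is empty, so index 1 resolves to document 2.
docs = {0: ['sentence a'], 1: [], 2: ['sentence b', 'sentence c']}
doc, offset = first_usable_doc(docs.get, 1)
assert doc == ['sentence b', 'sentence c'] and offset == 1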
@@ -233,7 +233,7 @@ class ICTBertModel(MegatronModule):
         self._context_key = 'context_model'
 
     def forward(self, input_tokens, input_attention_mask, input_types,
-                context_tokens, context_attention_mask, context_types):
+                context_tokens, context_attention_mask, context_types, return_logits=False):
         question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
         context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
@@ -241,12 +241,11 @@ class ICTBertModel(MegatronModule):
         # [batch x h] * [h x batch]
         retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
 
-        return retrieval_scores
+        if return_logits:
+            return question_ict_logits, context_ict_logits, retrieval_scores
 
-    def embed_doc(self, doc_tokens, doc_attention_mask, doc_types):
-        doc_logits, _ = self.context_model.forward(doc_tokens, 1 - doc_attention_mask, doc_types)
-        return doc_logits
+        return retrieval_scores
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
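
With return_logits=True the caller gets the per-side embeddings that the retrieval score is built from. The relationship between the two return modes can be seen with plain tensors; the shapes here are assumptions, with 128 matching the logit width used in the test script:

import torch

batch, hidden = 4, 128
question_ict_logits = torch.randn(batch, hidden)
context_ict_logits = torch.randn(batch, hidden)

# [batch x h] * [h x batch] -> [batch x batch], as in ICTBertModel.forward.
retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
assert retrieval_scores.shape == (batch, batch)

# Row i, column i is question i scored against the context taken from its own
# document, i.e. the positive pair during ICT training.
print(retrieval_scores.diag())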