Commit 2fd4ea6c authored by Neel Kant

Corrected REALM example building; misc improvements for async concurrency

parent 8e22824e
@@ -95,6 +95,7 @@ def setup_realm_groups_and_vars():
class IndexBuilder(object):
def __init__(self):
args = get_args()
self.debug = args.debug
self.rank = args.rank
self.model = None
self.dataloader = None
@@ -287,6 +288,6 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
if __name__ == "__main__":
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
index_builder = BasicIndexBuilder()
index_builder = IndexBuilder()
index_builder.build_and_save_index()
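Note on the `__main__` block above: it follows Megatron's standalone-tool pattern, where `initialize_megatron` must run before anything calls `get_args()`, which is why `IndexBuilder.__init__` can read `args.rank` (and the newly added `args.debug`) directly. A minimal sketch of that wiring, assuming the standard Megatron-LM import paths:

```python
from megatron import get_args
from megatron.initialize import initialize_megatron

if __name__ == "__main__":
    # initialize_megatron parses the command line and sets the global args
    # object, so get_args() is safe to call anywhere afterwards
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    print(args.rank)  # populated by the initializer
```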
@@ -5,7 +5,7 @@ import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping
from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping, join_str_list
class REALMDataset(Dataset):
@@ -136,7 +136,8 @@ class ICTDataset(Dataset):
def decode_tokens(self, token_ids):
tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
return ' '.join(token for token in tokens if token != '[PAD]')
non_pads = [t for t in tokens if t != '[PAD]']
return join_str_list(non_pads)
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
......
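The switch from `' '.join(...)` to `join_str_list` matters for BERT's WordPiece vocabulary, where sub-word pieces are prefixed with `##` and should be glued onto the preceding token rather than space-separated. A minimal sketch of what such a helper presumably does (the real implementation lives in `megatron.data.realm_dataset_utils`):

```python
def join_str_list(str_list):
    """Join WordPiece tokens into text, merging '##' continuation pieces."""
    parts = []
    for s in str_list:
        if s.startswith("##") and parts:
            parts[-1] += s[2:]   # glue the continuation onto the previous token
        else:
            parts.append(s)
    return " ".join(parts)
```

For example, `join_str_list(['the', 'embed', '##ding', 'layer'])` yields `'the embedding layer'`, whereas the old `' '.join` produced `'the embed ##ding layer'`.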
@@ -94,51 +94,96 @@ class REALMBertModel(MegatronModule):
self._retriever_key = 'retriever'
def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
# print("\nNEW FORWARD", '-' * 100, flush=True)
dset = self.retriever.ict_dataset
det_tokens = detach(tokens)[0].tolist()
det_attention = detach(attention_mask)[0].tolist()
# print("\nTokens: ", det_tokens, '\n', flush=True)
# print("\nAttention: ", det_attention, '\n', flush=True)
# print("pad id: ", dset.pad_id, flush=True)
assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
if 0 in det_attention:
idx_padid = det_tokens.index(dset.pad_id)
idx_attn = det_attention.index(0)
assert idx_padid == idx_attn, (idx_padid, idx_attn)
# text = dset.decode_tokens(det_tokens)
# print(text, flush=True)
# print("Token shape: ", tokens.shape, flush=True)
# [batch_size x k x seq_length]
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
# print("Top k block shape: ", topk_block_tokens.shape, flush=True)
batch_size = tokens.shape[0]
# create a copy in case it needs to be returned
ret_topk_block_tokens = np.array(topk_block_tokens)
seq_length = topk_block_tokens.shape[2]
topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)
topk_block_attention_mask = torch.cuda.LongTensor(topk_block_attention_mask).reshape(-1, seq_length)
long_tensor = torch.cuda.LongTensor
topk_block_tokens = long_tensor(topk_block_tokens).reshape(-1, seq_length)
topk_block_attention_mask = long_tensor(topk_block_attention_mask).reshape(-1, seq_length)
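# the reshapes collapse [batch_size x k x seq_length] into
# [batch_size * k x seq_length], so each (query, block) pair is one row
# for the block encoder below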
# print('Block token shape: ', topk_block_tokens.shape, flush=True)
# [batch_size x k x embed_size]
true_model = self.retriever.ict_model.module.module
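# (.module.module peels off the distributed/fp16 wrappers to reach the raw
# ICT model; mpu.checkpoint recomputes activations during backward instead
# of storing them, saving memory across the k block encodings)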
fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
# print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)
# [batch_size x embed_size x 1]
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
# print('Query logits shape: ', query_logits.shape, flush=True)
# [batch_size x k]
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
# print('Block score shape: ', fresh_block_scores.shape, flush=True)
block_probs = F.softmax(fresh_block_scores, dim=1)
# [batch_size * k x seq_length]
tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
#assert all(tokens[i] == tokens[0] for i in range(self.top_k))
#assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
#assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
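# stacking top_k copies of [batch_size, 1, seq_length] along dim=1 yields
# [batch_size, top_k, 1, seq_length]; the reshape flattens it to
# [batch_size * top_k, seq_length], i.e. repeat_interleave(top_k, dim=0)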
# [batch_size * k x 2 * seq_length]
all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
lm_input_batch_shape = (batch_size * self.top_k, 2 * seq_length)
all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
all_attention_mask = all_tokens.clone()
all_token_types = all_tokens.clone()
#all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
#all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
#all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
# re-align tokens to be contiguous
query_lengths = torch.sum(attention_mask, axis=1)
block_lengths = torch.sum(topk_block_attention_mask, axis=1)
for row_num in range(all_tokens.shape[0]):
qlen = query_lengths[row_num]
blen = block_lengths[row_num]
# disregard the CLS token from the block tokens
new_tokens_length = qlen + blen - 1
# all blocks (including null ones) will have two SEP tokens
block_sep_indices = (topk_block_tokens == dset.sep_id).nonzero().reshape(batch_size * self.top_k, 2, 2)
# block body starts after the first SEP
block_starts = block_sep_indices[:, 0, 1] + 1
# block body ends after the second SEP
block_ends = block_sep_indices[:, 1, 1] + 1
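# (nonzero() returns (row, col) index pairs in row-major order; with exactly
# two SEPs per row, the reshape groups them per row, so [:, i, 1] selects the
# column of the i-th SEP in every row)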
all_tokens[row_num, :qlen] = tokens[row_num, :qlen]
all_tokens[row_num, qlen:new_tokens_length] = tokens[row_num, 1:blen]
# block_lengths = torch.sum(topk_block_attention_mask, axis=1)
for row_num in range(all_tokens.shape[0]):
q_len = query_lengths[row_num]
b_start = block_starts[row_num]
b_end = block_ends[row_num]
# new tokens = CLS + query + SEP + block + SEP
new_tokens_length = q_len + b_end - b_start
# splice query and block tokens accordingly
all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
# print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
all_attention_mask[row_num, :new_tokens_length] = 1
all_attention_mask[row_num, new_tokens_length:] = 0
......
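The retrieval score in the hunk above reduces to a batched inner product between each query embedding and its k candidate block embeddings, normalized with a softmax over the k blocks. A small self-contained sketch with illustrative shapes; note that `squeeze(-1)` sidesteps the corner case where the bare `.squeeze()` used above would also drop a batch dimension of size 1:

```python
import torch
import torch.nn.functional as F

batch_size, top_k, embed_size = 2, 4, 8

# stand-ins for the embed_block / embed_query outputs
fresh_block_logits = torch.randn(batch_size, top_k, embed_size)  # [B, k, E]
query_logits = torch.randn(batch_size, embed_size, 1)            # [B, E, 1]

# [B, k, E] @ [B, E, 1] -> [B, k, 1]: score of each block for its query
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze(-1)

# P(block | query): softmax over the k retrieved blocks
block_probs = F.softmax(fresh_block_scores, dim=1)
assert torch.allclose(block_probs.sum(dim=1), torch.ones(batch_size))
```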
@@ -120,7 +120,6 @@ def set_data_parallel_group(group):
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group has already been initialized'
print(">>> setting data parallel group: ", group, flush=True)
_DATA_PARALLEL_GROUP = group
......
@@ -18,6 +18,7 @@
from datetime import datetime
import math
import sys
import time
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -381,8 +382,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
while iteration < args.train_iters:
if iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
time.sleep(5)
continue
# this only applies to REALM
if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
if args.max_training_rank is not None and recv_handle.is_completed():
# should also check that INDEX_READY == 1, though nothing else should complete this handle
true_model = model
@@ -393,7 +398,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
print("> Saving model and reloading index", flush=True)
save_checkpoint(iteration, model, optimizer, lr_scheduler)
if args.rank == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
true_model.retriever.reload_index()
if args.rank == 0:
......
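The training-loop changes above implement a polling handshake: the trainer arms an async gloo broadcast for `INDEX_READY`, keeps training while the handle is pending, stalls once it is 500 iterations past the last reload, and on completion checkpoints and reloads the index. A stripped-down sketch of the pattern; the group setup, source rank, and the re-arming of the broadcast are assumptions, not the exact code:

```python
import time
import torch
import torch.distributed as dist

def train_with_index_reload(train_step, reload_index, gloo_group, src_rank,
                            train_iters, reload_min_gap=500):
    """Illustrative polling loop mirroring the structure above."""
    index_ready = torch.zeros(1)
    handle = dist.broadcast(index_ready, src_rank,
                            group=gloo_group, async_op=True)
    last_reload = iteration = 0
    while iteration < train_iters:
        # past the gap, stall until the fresh index is announced
        if iteration >= last_reload + reload_min_gap and not handle.is_completed():
            time.sleep(5)
            continue
        if handle.is_completed():
            reload_index()
            last_reload = iteration
            # re-arm the broadcast for the next index refresh (assumed)
            index_ready.zero_()
            handle = dist.broadcast(index_ready, src_rank,
                                    group=gloo_group, async_op=True)
        train_step()
        iteration += 1
```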
@@ -49,7 +49,7 @@ def model_provider():
hashed_index.add_block_embed_data(all_block_data)
# top_k + 1 because we may need to exclude trivial candidate
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
model = REALMBertModel(retriever)
return model
......
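The `block_top_k + 1` over-retrieval existed so the retriever could drop the "trivial" candidate, i.e. the block the query itself was drawn from. After this commit that exclusion is handled inside the retriever via the `query_block_indices` argument (see the `retrieve_evidence_blocks` call in the forward hunk), so the caller passes the true k. A hypothetical sketch of the exclude-then-truncate idea; `index.search` and the field names are illustrative, not the repo's API:

```python
def retrieve_top_k(index, query_embed, query_block_idx, top_k):
    # over-retrieve by one so k results remain after dropping
    # the query's own source block
    candidates = index.search(query_embed, k=top_k + 1)  # illustrative API
    kept = [c for c in candidates if c.block_idx != query_block_idx]
    return kept[:top_k]
```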