Commit 2fd4ea6c authored by Neel Kant

Corrected realm example building, misc improvements for async concurrency

parent 8e22824e
@@ -95,6 +95,7 @@ def setup_realm_groups_and_vars():
 class IndexBuilder(object):
     def __init__(self):
         args = get_args()
+        self.debug = args.debug
         self.rank = args.rank
         self.model = None
         self.dataloader = None
@@ -287,6 +288,6 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
 if __name__ == "__main__":
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = BasicIndexBuilder()
+    index_builder = IndexBuilder()
     index_builder.build_and_save_index()
@@ -5,7 +5,7 @@ import numpy as np
 from torch.utils.data import Dataset
 from megatron import get_tokenizer
-from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping
+from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping, join_str_list

 class REALMDataset(Dataset):
@@ -136,7 +136,8 @@ class ICTDataset(Dataset):
     def decode_tokens(self, token_ids):
         tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
-        return ' '.join(token for token in tokens if token != '[PAD]')
+        non_pads = [t for t in tokens if t != '[PAD]']
+        return join_str_list(non_pads)

     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""
...
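
The old `' '.join(...)` reinserted spaces inside words that WordPiece had split (e.g. `employ ##ee`). A plausible sketch of what `join_str_list` does, assuming standard WordPiece `##` continuations (the real helper lives in `megatron.data.realm_dataset_utils` and is not shown in this diff):

```python
# Plausible sketch of join_str_list: glue '##' WordPiece continuations back
# onto their word instead of space-joining every token.
def join_str_list(str_list):
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result.strip()

print(join_str_list(["the", "employ", "##ee", "was", "hired"]))
# -> "the employee was hired"
```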
@@ -94,51 +94,96 @@ class REALMBertModel(MegatronModule):
         self._retriever_key = 'retriever'

     def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
+        # print("\nNEW FORWARD", '-' * 100, flush=True)
+        dset = self.retriever.ict_dataset
+        det_tokens = detach(tokens)[0].tolist()
+        det_attention = detach(attention_mask)[0].tolist()
+        # print("\nTokens: ", det_tokens, '\n', flush=True)
+        # print("\nAttention: ", det_attention, '\n', flush=True)
+        # print("pad id: ", dset.pad_id, flush=True)
+        assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
+        if 0 in det_attention:
+            idx_padid = det_tokens.index(dset.pad_id)
+            idx_attn = det_attention.index(0)
+            assert idx_padid == idx_attn, (idx_padid, idx_attn)
+        # text = dset.decode_tokens(det_tokens)
+        # print(text, flush=True)
+        # print("Token shape: ", tokens.shape, flush=True)

         # [batch_size x k x seq_length]
         topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
             tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
+        # print("Top k block shape: ", topk_block_tokens.shape, flush=True)
         batch_size = tokens.shape[0]

         # create a copy in case it needs to be returned
         ret_topk_block_tokens = np.array(topk_block_tokens)

         seq_length = topk_block_tokens.shape[2]
-        topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)
-        topk_block_attention_mask = torch.cuda.LongTensor(topk_block_attention_mask).reshape(-1, seq_length)
+        long_tensor = torch.cuda.LongTensor
+        topk_block_tokens = long_tensor(topk_block_tokens).reshape(-1, seq_length)
+        topk_block_attention_mask = long_tensor(topk_block_attention_mask).reshape(-1, seq_length)
+        # print('Block token shape: ', topk_block_tokens.shape, flush=True)

         # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
         fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
         fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
+        # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)

         # [batch_size x embed_size x 1]
         query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
+        # print('Query logits shape: ', query_logits.shape, flush=True)

         # [batch_size x k]
         fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
+        # print('Block score shape: ', fresh_block_scores.shape, flush=True)
         block_probs = F.softmax(fresh_block_scores, dim=1)

         # [batch_size * k x seq_length]
         tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
+        # assert all(tokens[i] == tokens[0] for i in range(self.top_k))
+        # assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
+        # assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
         attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)

         # [batch_size * k x 2 * seq_length]
-        all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
-        all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
-        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
+        lm_input_batch_shape = (batch_size * self.top_k, 2 * seq_length)
+        all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
+        all_attention_mask = all_tokens.clone()
+        all_token_types = all_tokens.clone()
+        # all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
+        # all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
+        # all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

+        # re-align tokens to be contiguous
         query_lengths = torch.sum(attention_mask, axis=1)
-        block_lengths = torch.sum(topk_block_attention_mask, axis=1)
+        # all blocks (including null ones) will have two SEP tokens
+        block_sep_indices = (topk_block_tokens == dset.sep_id).nonzero().reshape(batch_size * self.top_k, 2, 2)
+
+        # block body starts after the first SEP
+        block_starts = block_sep_indices[:, 0, 1] + 1
+        # block body ends after the second SEP
+        block_ends = block_sep_indices[:, 1, 1] + 1
+
+        # block_lengths = torch.sum(topk_block_attention_mask, axis=1)
         for row_num in range(all_tokens.shape[0]):
-            qlen = query_lengths[row_num]
-            blen = block_lengths[row_num]
-            # disregard the CLS token from the block tokens
-            new_tokens_length = qlen + blen - 1
-
-            all_tokens[row_num, :qlen] = tokens[row_num, :qlen]
-            all_tokens[row_num, qlen:new_tokens_length] = tokens[row_num, 1:blen]
+            q_len = query_lengths[row_num]
+            b_start = block_starts[row_num]
+            b_end = block_ends[row_num]
+            # new tokens = CLS + query + SEP + block + SEP
+            new_tokens_length = q_len + b_end - b_start
+            # splice query and block tokens accordingly
+            all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
+            all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
             all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
+            # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
             all_attention_mask[row_num, :new_tokens_length] = 1
             all_attention_mask[row_num, new_tokens_length:] = 0
...
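
The SEP-index trick above replaces the old attention-mask length arithmetic. A CPU-only toy reproduction (the real code runs on CUDA tensors; 101/102 are the usual BERT CLS/SEP ids, assumed here) showing how `nonzero()` plus `reshape` recovers per-row block spans:

```python
# CPU-only toy reproduction of the SEP-index trick above. Assumption: each
# retrieved block is laid out as [CLS] title [SEP] body [SEP] [PAD]..., so the
# two SEP positions bound the span that gets spliced after the query.
import torch

CLS, SEP, PAD = 101, 102, 0
batch_size, top_k = 2, 2  # 4 rows total, seq_length = 8

topk_block_tokens = torch.tensor([
    [CLS, 11, SEP, 21, 22, 23, SEP, PAD],
    [CLS, 12, SEP, 31, 32, SEP, PAD, PAD],
    [CLS, 13, SEP, 41, SEP, PAD, PAD, PAD],
    [CLS, 14, SEP, 51, 52, 53, SEP, PAD],
])

# nonzero() yields (row, col) pairs in row-major order; with exactly two SEPs
# per row the result reshapes cleanly to [rows, 2 SEPs, 2 coordinates]
block_sep_indices = (topk_block_tokens == SEP).nonzero().reshape(batch_size * top_k, 2, 2)
block_starts = block_sep_indices[:, 0, 1] + 1  # first token after the first SEP
block_ends = block_sep_indices[:, 1, 1] + 1    # one past the second SEP (span keeps it)

print(block_starts.tolist())  # [3, 3, 3, 3]
print(block_ends.tolist())    # [7, 6, 5, 7]
print(topk_block_tokens[1, block_starts[1]:block_ends[1]].tolist())  # [31, 32, 102]
```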
@@ -120,7 +120,6 @@ def set_data_parallel_group(group):
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group has already been initialized'
-    print(">>> setting data parallel group: ", group, flush=True)
     _DATA_PARALLEL_GROUP = group
...
@@ -18,6 +18,7 @@
 from datetime import datetime
 import math
 import sys
+import time

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -381,8 +382,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
     last_reload_iteration = iteration
     while iteration < args.train_iters:
+        if iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
+            time.sleep(5)
+            continue
+
         # this only applies for realm right here
-        if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
+        if args.max_training_rank is not None and recv_handle.is_completed():
             # should add check that INDEX_READY == 1 but what else could be happening
             true_model = model
@@ -393,7 +398,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             print("> Saving model and reloading index", flush=True)
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            if args.rank == 0:
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)
             true_model.retriever.reload_index()

             if args.rank == 0:
...
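
Guarding `save_checkpoint` behind `args.rank == 0` presumably keeps every data-parallel rank from racing to write the same checkpoint files, while `true_model.retriever.reload_index()` still runs on all ranks so each retriever picks up the freshly built index.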
@@ -49,7 +49,7 @@ def model_provider():
     hashed_index.add_block_embed_data(all_block_data)

     # top_k + 1 because we may need to exclude trivial candidate
-    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
     model = REALMBertModel(retriever)
     return model
...
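
For context on the comment above: when retrieving evidence for a training query, the block the query was drawn from tends to come back as the top hit, so fetching k + 1 candidates leaves k after dropping it. A toy sketch of that exclusion (hypothetical helper, not from this diff):

```python
# Toy sketch of the "trivial candidate" exclusion the comment refers to:
# fetch one extra neighbor so that dropping the query's own source block
# still leaves top_k results. Hypothetical helper, not from this diff.
def drop_trivial_candidate(candidate_block_ids, query_block_idx, top_k):
    kept = [b for b in candidate_block_ids if b != query_block_idx]
    return kept[:top_k]

# the query's own block (id 3) came back first; fetching k + 1 = 4 still leaves 3
print(drop_trivial_candidate([3, 7, 9, 4], query_block_idx=3, top_k=3))  # [7, 9, 4]
```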