Commit 8e22824e authored by Neel Kant

Fix token alignment, add mpu checkpointing, misc training code

parent 8573ab35
@@ -8,6 +8,8 @@ from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.model import BertModel
 from megatron.model.utils import get_linear_layer, init_method_normal
 from megatron.module import MegatronModule
+from megatron.utils import report_memory
+from megatron import mpu
 class REALMAnswerSpanModel(MegatronModule):
@@ -105,11 +107,11 @@ class REALMBertModel(MegatronModule):
         # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
-        fresh_block_logits = true_model.embed_block(topk_block_tokens, topk_block_attention_mask)
+        fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
         fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
         # [batch_size x embed_size x 1]
-        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
+        query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
         # [batch_size x k]
         fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
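A note on the change above: wrapping embed_block and embed_query in mpu.checkpoint applies activation checkpointing, so the activations of the two embedding towers are recomputed during the backward pass instead of being kept in memory for every one of the batch_size * k retrieved blocks. A minimal sketch of the same pattern using torch.utils.checkpoint; the Tower module and shapes here are made up for illustration, and Megatron's mpu.checkpoint additionally tracks the model-parallel RNG state, which this sketch does not:

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    # Hypothetical stand-in for a BERT-style embedding tower.
    class Tower(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(100, 16)
            self.proj = nn.Linear(16, 16)

        def forward(self, tokens, mask):
            h = self.proj(self.embed(tokens))       # [batch x seq x 16]
            mask = mask.unsqueeze(-1).float()
            return (h * mask).sum(1) / mask.sum(1)  # mean-pool over real tokens

    tower = Tower()
    tokens = torch.randint(0, 100, (4, 8))          # [batch x seq_length]
    mask = torch.ones(4, 8, dtype=torch.long)

    # Recompute the tower's activations in backward instead of storing them.
    # use_reentrant=False lets parameter gradients flow even though the
    # integer token inputs themselves carry no gradient.
    embeds = checkpoint(tower, tokens, mask, use_reentrant=False)
    embeds.sum().backward()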
@@ -124,6 +126,22 @@ class REALMBertModel(MegatronModule):
         all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
         all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
+        # re-align tokens to be contiguous
+        query_lengths = torch.sum(attention_mask, axis=1)
+        block_lengths = torch.sum(topk_block_attention_mask, axis=1)
+        for row_num in range(all_tokens.shape[0]):
+            qlen = query_lengths[row_num]
+            blen = block_lengths[row_num]
+            # disregard the CLS token from the block tokens
+            new_tokens_length = qlen + blen - 1
+            all_tokens[row_num, :qlen] = tokens[row_num, :qlen]
+            all_tokens[row_num, qlen:new_tokens_length] = topk_block_tokens[row_num, 1:blen]
+            all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
+            all_attention_mask[row_num, :new_tokens_length] = 1
+            all_attention_mask[row_num, new_tokens_length:] = 0
         # [batch_size x k x 2 * seq_length x vocab_size]
         lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
         lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
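The loop added in this hunk repacks each row as [query tokens][block tokens minus their leading CLS][padding], so the reader sees one contiguous sequence instead of pad tokens sitting between the query and the retrieved block. A small standalone illustration of the same repacking on plain tensors; the token ids and pad id below are made up for the example:

    import torch

    PAD_ID = 0  # assumed pad id for the illustration

    # one query row and one retrieved-block row, both already padded to length 8
    query      = torch.tensor([101, 7, 8, 9, 102, PAD_ID, PAD_ID, PAD_ID])
    query_mask = torch.tensor([1, 1, 1, 1, 1, 0, 0, 0])
    block      = torch.tensor([101, 21, 22, 102, PAD_ID, PAD_ID, PAD_ID, PAD_ID])
    block_mask = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0])

    qlen = int(query_mask.sum())          # 5 real query tokens
    blen = int(block_mask.sum())          # 4 real block tokens
    new_len = qlen + blen - 1             # drop the block's leading CLS token

    packed = torch.full((query.numel() + block.numel(),), PAD_ID, dtype=torch.long)
    packed[:qlen] = query[:qlen]
    packed[qlen:new_len] = block[1:blen]  # block tokens without their CLS
    packed_mask = torch.zeros_like(packed)
    packed_mask[:new_len] = 1

    print(packed.tolist())
    # [101, 7, 8, 9, 102, 21, 22, 102, 0, 0, 0, 0, 0, 0, 0, 0]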
@@ -163,11 +181,9 @@ class REALMRetriever(MegatronModule):
     def reload_index(self):
         args = get_args()
-        print("loading from file", flush=True)
         self.block_data = BlockData.load_from_file(args.block_data_path)
         print("resetting index", flush=True)
         self.hashed_index.reset_index()
-        print("adding block data", flush=True)
         self.hashed_index.add_block_embed_data(self.block_data)

     def prep_query_text_for_retrieval(self, query_text):
@@ -201,29 +217,29 @@ class REALMRetriever(MegatronModule):
         true_model = self.ict_model
         # print("true model: ", true_model, flush=True)
-        query_embeds = detach(self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True))
+        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
         _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
         all_topk_tokens, all_topk_pad_masks = [], []
         # this will result in no candidate exclusion
         if query_block_indices is None:
             query_block_indices = [-1] * len(block_indices)
+        top_k_offset = int(include_null_doc)
         for query_idx, indices in enumerate(block_indices):
             # [k x meta_dim]
             # exclude trivial candidate if it appears, else just trim the weakest in the top-k
             topk_metas = [self.block_data.meta_data[idx] for idx in indices if idx != query_block_indices[query_idx]]
             topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - top_k_offset]]
             if include_null_doc:
                 topk_block_data.append(self.ict_dataset.get_null_block())
             topk_tokens, topk_pad_masks = zip(*topk_block_data)
             all_topk_tokens.append(np.array(topk_tokens))
             all_topk_pad_masks.append(np.array(topk_pad_masks))
         # [batch_size x k x seq_length]
         return np.array(all_topk_tokens), np.array(all_topk_pad_masks)
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
...
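In the retrieval hunk above, each query's own block can be excluded from its top-k (so training does not retrieve the trivial evidence for itself), and an empty "null" block can optionally be appended, with top_k_offset reserving one slot for it. A stripped-down sketch of that selection logic over plain lists; the block table, null block, and helper name below are assumptions standing in for the project's BlockData and ICT dataset classes:

    import numpy as np

    def select_topk(indices, own_block_idx, blocks, null_block, top_k, include_null_doc):
        # Drop the query's own block if the index returned it, otherwise just
        # trim the weakest candidate so exactly top_k blocks come back.
        kept = [i for i in indices if i != own_block_idx]
        offset = int(include_null_doc)          # reserve a slot for the null doc
        chosen = [blocks[i] for i in kept[: top_k - offset]]
        if include_null_doc:
            chosen.append(null_block)
        return np.array(chosen)

    blocks = {0: "blk-0", 1: "blk-1", 2: "blk-2", 3: "blk-3", 4: "blk-4"}
    # MIPS search returned four candidates; candidate 2 is the query's own block.
    print(select_topk([2, 0, 4, 1], own_block_idx=2, blocks=blocks,
                      null_block="<null>", top_k=3, include_null_doc=True))
    # -> ['blk-0' 'blk-4' '<null>']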
@@ -374,12 +374,16 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     global INDEX_READY
+    print('>>> Starting train()', flush=True)
     # start off by posting a receive call which will be answered.
+    # synchronize for start
+    torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
     recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
+    last_reload_iteration = iteration
     while iteration < args.train_iters:
         # this only applies for realm right here
-        if args.max_training_rank is not None and recv_handle.is_completed():
+        if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
             # should add check that INDEX_READY == 1 but what else could be happening
             true_model = model
             if hasattr(true_model, 'module'):
@@ -388,20 +392,23 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                 true_model = true_model.module
-            print(">>>>>>> starting to reload index", flush=True)
-            true_model.retriever.reload_index()
+            print("> Saving model and reloading index", flush=True)
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            true_model.retriever.reload_index()
             if args.rank == 0:
                 INDEX_READY = 1 - INDEX_READY
-                print(">>> Switched index ready", flush=True)
             torch.cuda.synchronize()
-            send_handle = torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
+            # send handle
+            torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
             torch.distributed.barrier(get_data_parallel_group())
             recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
+            last_reload_iteration = iteration
-        elif iteration < 100:
+        elif iteration < 20:
             print("moving right along", flush=True)
+        # report_memory("iteration {}".format(iteration))
         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,
...
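The training-side changes implement a simple handshake with the index-builder ranks over a gloo group: the trainer posts a non-blocking broadcast (async_op=True) rooted at the first index rank and polls recv_handle.is_completed() each iteration; when the flag flips and at least 500 iterations have passed since the last reload, it saves a checkpoint, reloads the MIPS index, broadcasts the flag back from rank 0, and re-posts the receive. A self-contained sketch of that poll-and-repost pattern between two processes; the port, rank roles, and timing below are simplified assumptions rather than the file's actual configuration:

    import time
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def run(rank, world_size):
        dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29511",
                                rank=rank, world_size=world_size)
        flag = torch.zeros(1, dtype=torch.long)

        if rank == 0:
            # "trainer": post a non-blocking receive for the flag and poll it
            handle = dist.broadcast(flag, src=1, async_op=True)
            steps = 0
            while not handle.is_completed():
                steps += 1                      # keep training while the index builds
            print(f"rank 0 saw flag={int(flag)} after {steps} steps")
            # acknowledge back so the builder knows training picked up the new index
            dist.broadcast(flag, src=0)
        else:
            # "index builder": build for a while, announce, then wait for the ack
            time.sleep(0.5)
            flag.fill_(1)
            dist.broadcast(flag, src=1)
            dist.broadcast(flag, src=0)

        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)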