Commit 51204a4d authored by Neel Kant
Browse files

Misc changes

parent dfb907fe
...@@ -167,10 +167,6 @@ class AsyncIndexBuilder(IndexBuilder): ...@@ -167,10 +167,6 @@ class AsyncIndexBuilder(IndexBuilder):
print("Starting (again!)", flush=True) print("Starting (again!)", flush=True)
self.build_and_save_index() self.build_and_save_index()
self.send_index_ready_signal() self.send_index_ready_signal()
while INDEX_READY == 1:
print("Waiting for new model checkpoint.", flush=True)
time.sleep(5)
self.load_attributes() self.load_attributes()
def load_attributes(self): def load_attributes(self):
...@@ -195,7 +191,6 @@ class AsyncIndexBuilder(IndexBuilder): ...@@ -195,7 +191,6 @@ class AsyncIndexBuilder(IndexBuilder):
# recv handle # recv handle
dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group()) dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
torch.distributed.barrier(get_data_parallel_group())
def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False): def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
......
...@@ -96,7 +96,8 @@ def salient_span_mask(tokens, mask_id): ...@@ -96,7 +96,8 @@ def salient_span_mask(tokens, mask_id):
# need to get all named entities # need to get all named entities
entities = SPACY_NER(tokens_str).ents entities = SPACY_NER(tokens_str).ents
entities = [e for e in entities if e.text != "CLS"] undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
if len(entities) == 0: if len(entities) == 0:
return None return None
entity_idx = np.random.randint(0, len(entities)) entity_idx = np.random.randint(0, len(entities))
......
...@@ -29,7 +29,7 @@ class BlockData(object): ...@@ -29,7 +29,7 @@ class BlockData(object):
def clear(self): def clear(self):
"""Clear the data structures to save memory""" """Clear the data structures to save memory"""
self.embed_data = dict() self.embed_data = dict()
self.meta_data = dict() # self.meta_data = dict()
@classmethod @classmethod
def load_from_file(cls, fname): def load_from_file(cls, fname):
...@@ -100,7 +100,7 @@ class FaissMIPSIndex(object): ...@@ -100,7 +100,7 @@ class FaissMIPSIndex(object):
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT) self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
if not self.use_gpu: if not self.use_gpu:
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index) self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
print(">> Finished building index", flush=True) print(">> Finished building index\n", flush=True)
if self.use_gpu: if self.use_gpu:
res = faiss.StandardGpuResources() res = faiss.StandardGpuResources()
...@@ -109,9 +109,10 @@ class FaissMIPSIndex(object): ...@@ -109,9 +109,10 @@ class FaissMIPSIndex(object):
config.device = torch.cuda.current_device() config.device = torch.cuda.current_device()
config.useFloat16 = True config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config) self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
print(">>> Loaded Faiss index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True) print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
def reset_index(self): def reset_index(self):
del self.block_mips_index
self._set_block_index() self._set_block_index()
def add_block_embed_data(self, all_block_data, clear_block_data=False): def add_block_embed_data(self, all_block_data, clear_block_data=False):
...@@ -120,7 +121,7 @@ class FaissMIPSIndex(object): ...@@ -120,7 +121,7 @@ class FaissMIPSIndex(object):
if self.use_gpu: if self.use_gpu:
for i, idx in enumerate(block_indices): for i, idx in enumerate(block_indices):
self.id_map[i] = idx self.id_map[i] = idx
if clear_block_data: if True:
all_block_data.clear() all_block_data.clear()
if self.use_gpu: if self.use_gpu:
...@@ -134,8 +135,6 @@ class FaissMIPSIndex(object): ...@@ -134,8 +135,6 @@ class FaissMIPSIndex(object):
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
if False: return [num_queries x k] array of distances, and another for indices if False: return [num_queries x k] array of distances, and another for indices
""" """
if self.index_type == 'flat_l2':
query_embeds = self.alsh_query_preprocess_fn(query_embeds)
query_embeds = np.float32(detach(query_embeds)) query_embeds = np.float32(detach(query_embeds))
# query_embeds = query_embeds.float() # query_embeds = query_embeds.float()
......
...@@ -164,14 +164,14 @@ class _Timer: ...@@ -164,14 +164,14 @@ class _Timer:
def start(self): def start(self):
"""Start the timer.""" """Start the timer."""
assert not self.started_, 'timer has already been started' assert not self.started_, 'timer has already been started'
# torch.cuda.synchronize() torch.cuda.synchronize()
self.start_time = time.time() self.start_time = time.time()
self.started_ = True self.started_ = True
def stop(self): def stop(self):
"""Stop the timer.""" """Stop the timer."""
assert self.started_, 'timer is not started' assert self.started_, 'timer is not started'
# torch.cuda.synchronize() torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time) self.elapsed_ += (time.time() - self.start_time)
self.started_ = False self.started_ = False
......
...@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule): ...@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
if not no_scale and not reduce_after: if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group) coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group) dist.all_reduce(coalesced, group=self.data_parallel_group)
# torch.cuda.synchronize() torch.cuda.synchronize()
if not no_scale and reduce_after: if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group) coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
......
...@@ -103,11 +103,11 @@ class REALMBertModel(MegatronModule): ...@@ -103,11 +103,11 @@ class REALMBertModel(MegatronModule):
# print("\nAttention: ", det_attention, '\n', flush=True) # print("\nAttention: ", det_attention, '\n', flush=True)
# print("pad id: ", dset.pad_id, flush=True) # print("pad id: ", dset.pad_id, flush=True)
assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens) # assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
if 0 in det_attention: # if 0 in det_attention:
idx_padid = det_tokens.index(dset.pad_id) # idx_padid = det_tokens.index(dset.pad_id)
idx_attn = det_attention.index(0) # idx_attn = det_attention.index(0)
assert idx_padid == idx_attn, (idx_padid, idx_attn) # assert idx_padid == idx_attn, (idx_padid, idx_attn)
# text = dset.decode_tokens(det_tokens) # text = dset.decode_tokens(det_tokens)
# print(text, flush=True) # print(text, flush=True)
...@@ -135,12 +135,12 @@ class REALMBertModel(MegatronModule): ...@@ -135,12 +135,12 @@ class REALMBertModel(MegatronModule):
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1) fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
# print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True) # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)
# [batch_size x embed_size x 1] # [batch_size x 1 x embed_size]
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2) query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1)
# print('Query logits shape: ', query_logits.shape, flush=True) # print('Query logits shape: ', query_logits.shape, flush=True)
# [batch_size x k] # [batch_size x k]
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze() fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
# print('Block score shape: ', fresh_block_scores.shape, flush=True) # print('Block score shape: ', fresh_block_scores.shape, flush=True)
block_probs = F.softmax(fresh_block_scores, dim=1) block_probs = F.softmax(fresh_block_scores, dim=1)
...@@ -175,11 +175,11 @@ class REALMBertModel(MegatronModule): ...@@ -175,11 +175,11 @@ class REALMBertModel(MegatronModule):
b_start = block_starts[row_num] b_start = block_starts[row_num]
b_end = block_ends[row_num] b_end = block_ends[row_num]
# new tokens = CLS + query + SEP + block + SEP # new tokens = CLS + query + SEP + block + SEP
new_tokens_length = q_len + b_end - b_start # new_tokens_length = q_len + b_end - b_start
new_tokens_length = q_len
# splice query and block tokens accordingly # splice query and block tokens accordingly
all_tokens[row_num, :q_len] = tokens[row_num, :q_len] all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end] # all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
# print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True) # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
...@@ -226,9 +226,8 @@ class REALMRetriever(MegatronModule): ...@@ -226,9 +226,8 @@ class REALMRetriever(MegatronModule):
def reload_index(self): def reload_index(self):
args = get_args() args = get_args()
self.block_data = BlockData.load_from_file(args.block_data_path)
print("resetting index", flush=True)
self.hashed_index.reset_index() self.hashed_index.reset_index()
self.block_data = BlockData.load_from_file(args.block_data_path)
self.hashed_index.add_block_embed_data(self.block_data) self.hashed_index.add_block_embed_data(self.block_data)
def prep_query_text_for_retrieval(self, query_text): def prep_query_text_for_retrieval(self, query_text):
......
...@@ -244,7 +244,7 @@ def backward_step(optimizer, model, loss): ...@@ -244,7 +244,7 @@ def backward_step(optimizer, model, loss):
"""Backward step.""" """Backward step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
# torch.cuda.synchronize() torch.cuda.synchronize()
# Backward pass. # Backward pass.
# optimizer.zero_grad(set_grads_to_None=True) # optimizer.zero_grad(set_grads_to_None=True)
...@@ -392,39 +392,36 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -392,39 +392,36 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True) recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration last_reload_iteration = iteration
while iteration < args.train_iters: while iteration < args.train_iters:
if args.max_training_rank is not None and iteration >= last_reload_iteration + 500 and not recv_handle.is_completed(): if args.max_training_rank is not None and iteration >= last_reload_iteration + 500:
time.sleep(5) if recv_handle.is_completed():
continue # should add check that INDEX_READY == 1 but what else could be happening
true_model = model
# this only applies for realm right here
if args.max_training_rank is not None and recv_handle.is_completed():
# should add check that INDEX_READY == 1 but what else could be happening
true_model = model
if hasattr(true_model, 'module'):
true_model = true_model.module
if hasattr(true_model, 'module'): if hasattr(true_model, 'module'):
true_model = true_model.module true_model = true_model.module
if hasattr(true_model, 'module'):
true_model = true_model.module
print("> Saving model and reloading index", flush=True) print("> Saving model and reloading index", flush=True)
if args.rank == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, lr_scheduler)
true_model.retriever.reload_index() if args.rank == 0:
INDEX_READY = 1 - INDEX_READY
# send handle
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
true_model.retriever.reload_index()
if args.rank == 0: torch.cuda.synchronize()
INDEX_READY = 1 - INDEX_READY
torch.cuda.synchronize() recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
else:
time.sleep(5)
continue
# send handle
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
torch.distributed.barrier(get_data_parallel_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
elif iteration < 20: elif iteration < 20:
print("moving right along", flush=True) print("moving right along", flush=True)
# report_memory("iteration {}".format(iteration)) report_memory("iteration {}".format(iteration))
loss_dict, skipped_iter = train_step(forward_step_func, loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator, train_data_iterator,
model, model,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment