Commit 51204a4d authored by Neel Kant

Misc changes

parent dfb907fe
@@ -167,10 +167,6 @@ class AsyncIndexBuilder(IndexBuilder):
print("Starting (again!)", flush=True)
self.build_and_save_index()
self.send_index_ready_signal()
while INDEX_READY == 1:
print("Waiting for new model checkpoint.", flush=True)
time.sleep(5)
self.load_attributes()
def load_attributes(self):
@@ -195,7 +191,6 @@ class AsyncIndexBuilder(IndexBuilder):
# recv handle
dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
torch.distributed.barrier(get_data_parallel_group())
def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
......
@@ -96,7 +96,8 @@ def salient_span_mask(tokens, mask_id):
# need to get all named entities
entities = SPACY_NER(tokens_str).ents
entities = [e for e in entities if e.text != "CLS"]
undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
if len(entities) == 0:
return None
entity_idx = np.random.randint(0, len(entities))
......
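For context on the hunk above: the new `undesired_types` list filters out numeric and date-like spaCy entity labels before an entity is sampled for masking. A minimal standalone sketch of that filtering (the `en_core_web_sm` model name and the function wrapper are assumptions for illustration, not the repo's `salient_span_mask`):

```python
# Sketch only: mirrors the entity filtering above with a local spaCy pipeline.
import numpy as np
import spacy

SPACY_NER = spacy.load("en_core_web_sm")
UNDESIRED_TYPES = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']

def pick_salient_entity(tokens_str):
    """Return a randomly chosen named entity worth masking, or None."""
    entities = SPACY_NER(tokens_str).ents
    # Drop the literal "CLS" token and numeric/date-like entity types.
    entities = [e for e in entities
                if e.text != "CLS" and e.label_ not in UNDESIRED_TYPES]
    if len(entities) == 0:
        return None
    entity = entities[np.random.randint(0, len(entities))]
    return entity.text, entity.label_, entity.start_char, entity.end_char
```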
@@ -29,7 +29,7 @@ class BlockData(object):
def clear(self):
"""Clear the data structures to save memory"""
self.embed_data = dict()
self.meta_data = dict()
# self.meta_data = dict()
@classmethod
def load_from_file(cls, fname):
@@ -100,7 +100,7 @@ class FaissMIPSIndex(object):
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
if not self.use_gpu:
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
print(">> Finished building index", flush=True)
print(">> Finished building index\n", flush=True)
if self.use_gpu:
res = faiss.StandardGpuResources()
@@ -109,9 +109,10 @@ class FaissMIPSIndex(object):
config.device = torch.cuda.current_device()
config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
print(">>> Loaded Faiss index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
def reset_index(self):
del self.block_mips_index
self._set_block_index()
def add_block_embed_data(self, all_block_data, clear_block_data=False):
@@ -120,7 +121,7 @@ class FaissMIPSIndex(object):
if self.use_gpu:
for i, idx in enumerate(block_indices):
self.id_map[i] = idx
if clear_block_data:
if True:
all_block_data.clear()
if self.use_gpu:
@@ -134,8 +135,6 @@ class FaissMIPSIndex(object):
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
if False: return [num_queries x k] array of distances, and another for indices
"""
if self.index_type == 'flat_l2':
query_embeds = self.alsh_query_preprocess_fn(query_embeds)
query_embeds = np.float32(detach(query_embeds))
# query_embeds = query_embeds.float()
......
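For reference, the CPU branch above builds a flat inner-product index wrapped in `faiss.IndexIDMap`, and the docstring describes a search returning [num_queries x k] distances and ids. A self-contained sketch of that pattern with made-up sizes and data (not the repo's `FaissMIPSIndex`):

```python
# Sketch only: flat inner-product (MIPS) index with explicit ids.
import faiss
import numpy as np

embed_size = 128
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(index)

block_embeds = np.float32(np.random.randn(1000, embed_size))
block_ids = np.arange(1000, dtype=np.int64)
index.add_with_ids(block_embeds, block_ids)

query_embeds = np.float32(np.random.randn(4, embed_size))
# distances: [num_queries x k] inner products; ids: [num_queries x k] block ids
distances, ids = index.search(query_embeds, 5)
print(distances.shape, ids.shape)
```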
@@ -164,14 +164,14 @@ class _Timer:
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
# torch.cuda.synchronize()
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
# torch.cuda.synchronize()
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
......
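The timer hunk re-enables `torch.cuda.synchronize()` because CUDA kernels launch asynchronously: without a synchronize, `time.time()` measures only launch overhead, not the kernels themselves. A minimal sketch of the same pattern outside the `_Timer` class:

```python
# Sketch only: why the synchronize calls matter when timing GPU work.
import time
import torch

def timed_matmul(a, b):
    torch.cuda.synchronize()   # drain any previously launched kernels
    start = time.time()
    c = a @ b                  # launches asynchronously on the GPU
    torch.cuda.synchronize()   # wait for the matmul to actually finish
    return c, time.time() - start

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device='cuda')
    _, elapsed = timed_matmul(x, x)
    print(f"matmul took {elapsed:.4f}s")
```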
@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
# torch.cuda.synchronize()
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
......
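The surrounding method flattens the gradients into one buffer, all-reduces it once over the data-parallel group, and copies the averaged values back. A sketch of that coalesced all-reduce in isolation (the function name and arguments are illustrative, and it assumes an already-initialized process group):

```python
# Sketch only: coalesce gradients into one flat buffer, all-reduce once,
# then copy the averaged values back, mirroring the hunk above.
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_gradients(grads, group, reduce_after=False, no_scale=False):
    """Average a list of gradient tensors across `group` in a single all-reduce."""
    coalesced = _flatten_dense_tensors(grads)
    if not no_scale and not reduce_after:
        coalesced /= dist.get_world_size(group=group)
    dist.all_reduce(coalesced, group=group)
    if not no_scale and reduce_after:
        coalesced /= dist.get_world_size(group=group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
```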
@@ -103,11 +103,11 @@ class REALMBertModel(MegatronModule):
# print("\nAttention: ", det_attention, '\n', flush=True)
# print("pad id: ", dset.pad_id, flush=True)
assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
if 0 in det_attention:
idx_padid = det_tokens.index(dset.pad_id)
idx_attn = det_attention.index(0)
assert idx_padid == idx_attn, (idx_padid, idx_attn)
# assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
# if 0 in det_attention:
# idx_padid = det_tokens.index(dset.pad_id)
# idx_attn = det_attention.index(0)
# assert idx_padid == idx_attn, (idx_padid, idx_attn)
# text = dset.decode_tokens(det_tokens)
# print(text, flush=True)
@@ -135,12 +135,12 @@ class REALMBertModel(MegatronModule):
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
# print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)
# [batch_size x embed_size x 1]
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
# [batch_size x 1 x embed_size]
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1)
# print('Query logits shape: ', query_logits.shape, flush=True)
# [batch_size x k]
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
# print('Block score shape: ', fresh_block_scores.shape, flush=True)
block_probs = F.softmax(fresh_block_scores, dim=1)
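The shape comments carry the reasoning for this hunk: with `unsqueeze(1)` the query embedding is [batch_size x 1 x embed_size], the top-k block embeddings are [batch_size x k x embed_size], and a batched matmul against their transpose yields [batch_size x 1 x k], squeezed to [batch_size x k] before the softmax over retrieved blocks. A toy shape check with dummy tensors and made-up sizes:

```python
# Sketch only: checks the batched score shapes used above with dummy tensors.
import torch
import torch.nn.functional as F

batch_size, top_k, embed_size = 8, 5, 128
fresh_block_logits = torch.randn(batch_size, top_k, embed_size)   # [B x k x E]
query_logits = torch.randn(batch_size, embed_size).unsqueeze(1)   # [B x 1 x E]

# [B x 1 x E] @ [B x E x k] -> [B x 1 x k] -> squeeze -> [B x k]
fresh_block_scores = torch.matmul(
    query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
block_probs = F.softmax(fresh_block_scores, dim=1)

assert fresh_block_scores.shape == (batch_size, top_k)
assert torch.allclose(block_probs.sum(dim=1), torch.ones(batch_size))
```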
@@ -175,11 +175,11 @@ class REALMBertModel(MegatronModule):
b_start = block_starts[row_num]
b_end = block_ends[row_num]
# new tokens = CLS + query + SEP + block + SEP
new_tokens_length = q_len + b_end - b_start
# new_tokens_length = q_len + b_end - b_start
new_tokens_length = q_len
# splice query and block tokens accordingly
all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
# all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
# print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
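The `CLS + query + SEP + block + SEP` comment describes how a retrieved block is spliced after the query tokens and the tail padded with `pad_id`. A small illustrative sketch of that layout (the helper name and sizes are hypothetical, not the repo's):

```python
# Sketch only: lay out CLS + query + SEP + block + SEP in a fixed-width row,
# padding the tail with pad_id.
import numpy as np

def splice_query_and_block(query_ids, block_ids, max_len, pad_id):
    row = np.full(max_len, pad_id, dtype=np.int64)
    q_len = len(query_ids)                        # query_ids already end with SEP
    b_len = min(len(block_ids), max_len - q_len)  # truncate the block if needed
    row[:q_len] = query_ids
    row[q_len:q_len + b_len] = block_ids[:b_len]
    return row
```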
@@ -226,9 +226,8 @@ class REALMRetriever(MegatronModule):
def reload_index(self):
args = get_args()
self.block_data = BlockData.load_from_file(args.block_data_path)
print("resetting index", flush=True)
self.hashed_index.reset_index()
self.block_data = BlockData.load_from_file(args.block_data_path)
self.hashed_index.add_block_embed_data(self.block_data)
def prep_query_text_for_retrieval(self, query_text):
......
@@ -244,7 +244,7 @@ def backward_step(optimizer, model, loss):
"""Backward step."""
args = get_args()
timers = get_timers()
# torch.cuda.synchronize()
torch.cuda.synchronize()
# Backward pass.
# optimizer.zero_grad(set_grads_to_None=True)
@@ -392,39 +392,36 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
while iteration < args.train_iters:
if args.max_training_rank is not None and iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
time.sleep(5)
continue
# this only applies for realm right here
if args.max_training_rank is not None and recv_handle.is_completed():
# should add check that INDEX_READY == 1 but what else could be happening
true_model = model
if hasattr(true_model, 'module'):
true_model = true_model.module
if args.max_training_rank is not None and iteration >= last_reload_iteration + 500:
if recv_handle.is_completed():
# should add check that INDEX_READY == 1 but what else could be happening
true_model = model
if hasattr(true_model, 'module'):
true_model = true_model.module
if hasattr(true_model, 'module'):
true_model = true_model.module
print("> Saving model and reloading index", flush=True)
if args.rank == 0:
print("> Saving model and reloading index", flush=True)
save_checkpoint(iteration, model, optimizer, lr_scheduler)
true_model.retriever.reload_index()
if args.rank == 0:
INDEX_READY = 1 - INDEX_READY
# send handle
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
true_model.retriever.reload_index()
if args.rank == 0:
INDEX_READY = 1 - INDEX_READY
torch.cuda.synchronize()
torch.cuda.synchronize()
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
else:
time.sleep(5)
continue
# send handle
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
torch.distributed.barrier(get_data_parallel_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
elif iteration < 20:
print("moving right along", flush=True)
# report_memory("iteration {}".format(iteration))
report_memory("iteration {}".format(iteration))
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
......
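The reload logic in this hunk is driven by `torch.distributed.broadcast(..., async_op=True)`, which returns a work handle that the training loop polls with `is_completed()` instead of blocking on the index-builder ranks. A runnable single-process sketch of that handshake pattern (the gloo setup and loop body here are illustrative, not the repo's):

```python
# Sketch only: poll an async broadcast instead of blocking the training loop.
import os
import time
import torch
import torch.distributed as dist

def main():
    # Single-process gloo group just so the sketch runs standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    INDEX_READY = torch.zeros(1)   # flag tensor flipped by the index-builder rank
    src_rank = 0                   # stand-in for args.max_training_rank

    # Post the receive; async_op=True returns a work handle instead of blocking.
    recv_handle = dist.broadcast(INDEX_READY, src_rank, async_op=True)

    for iteration in range(3):
        if recv_handle.is_completed():
            # Reload the index here, then re-arm the handshake.
            recv_handle = dist.broadcast(INDEX_READY, src_rank, async_op=True)
        time.sleep(0.1)            # stand-in for one training step

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```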