Commit e59496bf authored by Neel Kant

Restructure Indexer classes

parent e0a1caba
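At a glance, the restructuring: the old AsyncIndexBuilder is split into a synchronous IndexBuilder base class plus an AsyncIndexBuilder subclass, and the old BasicIndexBuilder is folded into the base class. A condensed sketch of the resulting layout, reconstructed from the diff below (method bodies elided; all helper names are the repository's own):

class IndexBuilder(object):
    # Synchronous, one-shot path (absorbs the old BasicIndexBuilder):
    # load a fixed ICT checkpoint, embed every block, save the index.
    def load_attributes(self): ...
    def build_and_save_index(self): ...

class AsyncIndexBuilder(IndexBuilder):
    # Runs alongside REALM training: rebuilds the index in a loop and
    # coordinates with trainer ranks through the INDEX_READY flag.
    def run_async(self): ...
    def load_attributes(self): ...           # override: try the REALM checkpoint first
    def send_index_ready_signal(self): ...   # toggle and broadcast INDEX_READY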
@@ -92,39 +92,18 @@ def setup_realm_groups_and_vars():
     set_data_parallel_group(get_index_group())


-class AsyncIndexBuilder(object):
-    def __init__(self, rank):
-        self.rank = rank
+class IndexBuilder(object):
+    def __init__(self):
         args = get_args()
-        self.is_main_builder = self.rank == args.max_training_rank
-        self.main_builder_idx = args.max_training_rank
         self.debug = args.debug
+        self.rank = args.rank
         self.model = None
         self.dataloader = None
         self.block_data = None
         self.load_attributes()
-        global INDEX_READY
-        INDEX_READY = get_index_ready()
-
-    def run_async(self):
-        while True:
-            print("Starting (again!)", flush=True)
-            self.build_and_save_index()
-            self.send_index_ready_signal()
-            while INDEX_READY == 1:
-                print("Waiting for new model checkpoint.", flush=True)
-                time.sleep(5)
-            self.load_attributes()
+        self.is_main_builder = args.rank == 0

     def load_attributes(self):
-        try:
-            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
-        except:
-            print(">>>>> No realm chkpt available", flush=True)
-            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
+        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
         self.model.eval()
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
         self.block_data = BlockData()
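With this split, the base class is the synchronous, one-shot path. A hypothetical driver under that assumption (the function name below is illustrative, not part of this commit; Megatron's usual argument parsing and distributed initialization are assumed to have run already):

def build_block_index_once():
    # __init__ calls load_attributes(): loads the ICT checkpoint, wraps the
    # dataset in a one-epoch dataloader, and creates an empty BlockData store.
    builder = IndexBuilder()
    # Embed every block and write the index shards.
    builder.build_and_save_index()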
@@ -149,7 +128,7 @@ class AsyncIndexBuilder(object):
                 total += block_indices.size
                 i += 1
-                if i % 500 == 0:
+                if i % 1000 == 0:
                     print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
                 if self.debug:
                     break
@@ -162,57 +141,60 @@ class AsyncIndexBuilder(object):
             self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
         self.block_data.clear()

-    def send_index_ready_signal(self):
+
+class AsyncIndexBuilder(IndexBuilder):
+    def __init__(self, rank):
+        self.rank = rank
+        args = get_args()
+        self.is_main_builder = self.rank == args.max_training_rank
+        self.main_builder_idx = args.max_training_rank
+        self.debug = args.debug
+        self.model = None
+        self.dataloader = None
+        self.block_data = None
+        self.load_attributes()
+
         global INDEX_READY
-        if self.is_main_builder:
-            INDEX_READY = 1 - INDEX_READY
-            print("Switched INDEX_READY", flush=True)
-        torch.cuda.synchronize()
-        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
+        INDEX_READY = get_index_ready()
-        torch.distributed.barrier(get_data_parallel_group())
+
+    def run_async(self):
+        global INDEX_READY
+        # synchronize for start
+        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
+        while True:
+            print("Starting (again!)", flush=True)
+            self.build_and_save_index()
+            self.send_index_ready_signal()
+            while INDEX_READY == 1:
+                print("Waiting for new model checkpoint.", flush=True)
+                time.sleep(5)
+            self.load_attributes()
+
-class BasicIndexBuilder(object):
-    def __init__(self):
-        args = get_args()
-        self.rank = args.rank
-        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
+    def load_attributes(self):
+        try:
+            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
+        except:
+            print(">>>>> No realm chkpt available", flush=True)
+            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
         self.model.eval()
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
         self.block_data = BlockData()

-    def build_and_save_index(self):
-        i = 1
-        total = 0
-        while True:
-            with torch.no_grad():
-                try:
-                    query_tokens, query_pad_mask, \
-                        block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
-                except:
-                    break
-                block_index_data = detach(block_index_data)
-                block_indices = block_index_data[:, 3]
-                block_meta = block_index_data[:, :3]
-                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
-                self.block_data.add_block_data(block_indices, block_logits, block_meta)
-                total += block_indices.size
-                i += 1
-                if i % 2000 == 0:
-                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
-
+    def send_index_ready_signal(self):
+        global INDEX_READY
+        if self.is_main_builder:
+            INDEX_READY = 1 - INDEX_READY
+            print("Switched INDEX_READY", flush=True)
+        torch.cuda.synchronize()
-        self.block_data.save_shard(self.rank)
-        torch.distributed.barrier()
-        del self.model
-
+        # send handle
+        dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
-        if self.rank == 0:
-            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
-        self.block_data.clear()
+        # recv handle
+        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
+        torch.distributed.barrier(get_data_parallel_group())
def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
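The INDEX_READY handshake in the hunk above is the core of the async design. Below is a minimal, self-contained sketch of the same flag-toggle-and-broadcast pattern, using torch.distributed directly; MAIN_BUILDER_RANK stands in for args.max_training_rank, and the repo's get_index_ready()/get_gloo_comm_group() helpers are replaced by a plain one-element tensor and the default group, so every name here is illustrative.

import time
import torch
import torch.distributed as dist

MAIN_BUILDER_RANK = 0  # illustrative stand-in for args.max_training_rank

def send_index_ready_signal(index_ready: torch.Tensor, rank: int) -> None:
    # dist.broadcast operates on tensors, so the flag lives in a one-element
    # tensor that the source rank flips in place before broadcasting.
    if rank == MAIN_BUILDER_RANK:
        index_ready.fill_(1 - int(index_ready.item()))
    # Collective call: the source rank sends, every other rank receives.
    dist.broadcast(index_ready, src=MAIN_BUILDER_RANK)

def wait_for_new_checkpoint(index_ready: torch.Tensor) -> None:
    # Mirrors run_async's polling loop: a broadcast posted with async_op=True
    # fills index_ready in place once the source re-broadcasts, so the sleep
    # loop below eventually observes the flipped flag.
    handle = dist.broadcast(index_ready, src=MAIN_BUILDER_RANK, async_op=True)
    while int(index_ready.item()) == 1:
        time.sleep(5)
    handle.wait()

Routing the flag through a gloo group keeps this control traffic on CPU tensors, separate from the GPU collectives used for training, which is presumably why the repository maintains a dedicated get_gloo_comm_group().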