Commit 5511c258 authored by Neel Kant

Update indexer.py

parent 15d0d55b
@@ -17,6 +17,23 @@ from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider
+# TODO re: main()
+# consider broadcasting/all-reducing everything in memory rather than using the filesystem
+# create a different process group in the same NCCL world - then there is no need to keep checkpoints on disk or transfer things via disk
+# torch.distributed.new_group takes a list of ranks and gives back a group that can be handed to the collective operations
+# create a training process group and an indexing process group
+# pass the training group to the distributed DDP instead of the large world process group
+# use the indexing process group for the shard-combining
+# create a communication group between process "8" and process "0" which tells the training group that there's a new index
+# also, process 0 sends process 8 the new model
+# if I want to launch a separate process for indexing, I may have to work with environment variables to
+# allocate the resources well, and subsequently assign the correct GPUs to the indexing job
+# consider initializing everything in a single group and breaking off processes based on the ranks
+# for debugging purposes, have the training process group check every so many intervals,
+# and if the index isn't ready, wait so that state stays consistent. Start with using the filesystem
 def test_retriever():
     # TODO: Update this because it's outdated and definitely won't run.
     initialize_megatron(extra_args_provider=None,
@@ -41,22 +58,6 @@ def test_retriever():
 def main():
-    # TODO
-    # consider broadcasting/all-reducing everything in memory rather than using the filesystem
-    # create a different process group in the same NCCL world - then there is no need to keep checkpoints on disk or transfer things via disk
-    # torch.distributed.new_group takes a list of ranks and gives back a group that can be handed to the collective operations
-    # create a training process group and an indexing process group
-    # pass the training group to the distributed DDP instead of the large world process group
-    # use the indexing process group for the shard-combining
-    # create a communication group between process "8" and process "0" which tells the training group that there's a new index
-    # also, process 0 sends process 8 the new model
-    # if I want to launch a separate process for indexing, I may have to work with environment variables to
-    # allocate the resources well, and subsequently assign the correct GPUs to the indexing job
-    # consider initializing everything in a single group and breaking off processes based on the ranks
-    # for debugging purposes, have the training process group check every so many intervals,
-    # and if the index isn't ready, wait so that state stays consistent. Start with using the filesystem
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -118,10 +119,6 @@ INDEX_COM_FILE = 'ready.index'
 MODEL_COM_FILE = 'ready.model'
-def setup_index_com_file():
-    set_index_com_file_not_ready()
 def set_index_com_file_not_ready():
     with open(INDEX_COM_FILE, 'w') as com_file:
         com_file.write('0')
@@ -133,15 +130,11 @@ def set_index_com_file_ready():
 def check_index_com_file_ready():
-    if os.path.exists(INDEX_COM_FILE):
-        with open(INDEX_COM_FILE, 'r') as com_file:
-            return bool(com_file.readline())
-    return False
+    if not os.path.exists(INDEX_COM_FILE):
+        set_index_com_file_not_ready()
+
+    with open(INDEX_COM_FILE, 'r') as com_file:
+        return bool(com_file.readline())
 
-def setup_model_com_file():
-    set_model_com_file_not_ready()
 def set_model_com_file_not_ready():
@@ -155,11 +148,11 @@ def set_model_com_file_ready():
 def check_model_com_file_ready():
-    if os.path.exists(MODEL_COM_FILE):
-        with open(MODEL_COM_FILE, 'r') as com_file:
-            return bool(com_file.readline())
-    return False
+    if not os.path.exists(MODEL_COM_FILE):
+        set_model_com_file_not_ready()
+
+    with open(MODEL_COM_FILE, 'r') as com_file:
+        return bool(com_file.readline())
 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
...
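The TODO block above proposes replacing the on-disk handshake with NCCL process groups created via torch.distributed.new_group. Below is a minimal sketch of that layout, not part of this commit: the 9-rank split (ranks 0-7 training, rank 8 indexing) and the helper name init_groups are assumptions taken from the "process 0"/"process 8" wording.

# Sketch only: the rank layout and helper names are hypothetical.
import torch
import torch.distributed as dist

def init_groups():
    # Single NCCL world spanning both jobs. Note that every rank must call
    # new_group for every group, even for groups it does not belong to.
    dist.init_process_group(backend='nccl')
    world_size = dist.get_world_size()

    train_ranks = list(range(world_size - 1))  # e.g. ranks 0-7 train
    index_rank = world_size - 1                # e.g. rank 8 builds the index

    train_group = dist.new_group(ranks=train_ranks)
    index_group = dist.new_group(ranks=[index_rank])
    # Two-member group so rank 0 can send the indexer a fresh model and the
    # indexer can tell training there is a new index, all in memory.
    bridge_group = dist.new_group(ranks=[0, index_rank])
    return train_group, index_group, bridge_group

DDP accepts the smaller group directly (torch.nn.parallel.DistributedDataParallel(model, process_group=train_group)), the shard-combining collectives can run with group=index_group, and broadcasts between rank 0 and the indexer use group=bridge_group instead of checkpoints on disk.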
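Until that lands, the ready.index/ready.model files act as the filesystem mailbox the TODO mentions. Here is one sketch of a training-side consumer, assuming set_index_com_file_ready writes '1'; the poll interval and the wait_for_index name are hypothetical. Note that bool(com_file.readline()) in the functions above is truthy even for the '0' written by set_index_com_file_not_ready, since bool('0') is True in Python, so an explicit comparison is safer:

import os
import time

INDEX_COM_FILE = 'ready.index'

def wait_for_index(poll_seconds=10.0):
    # Block until the indexing job flips the flag, then reset it so the
    # next rebuilt index is awaited in turn.
    while True:
        if os.path.exists(INDEX_COM_FILE):
            with open(INDEX_COM_FILE, 'r') as com_file:
                # Compare contents explicitly rather than bool(readline()):
                # bool('0') is True, which would report "ready" spuriously.
                if com_file.readline().strip() == '1':
                    break
        time.sleep(poll_seconds)
    with open(INDEX_COM_FILE, 'w') as com_file:
        com_file.write('0')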