Commit aae93362 authored by Neel Kant

Create HashedIndex class

parent 9b599754
@@ -22,7 +22,99 @@ def detach(tensor):
    return tensor.detach().cpu().numpy()


class HashedIndex(object):
"""Class for holding hashed data"""
def __init__(self, embed_size, num_buckets, seed=0):
np.random.seed(seed)
self.block_data = defaultdict(list)
self.hash_data = defaultdict(list)
self.hash_matrix = np.random.rand(embed_size, num_buckets / 2)
    def state(self):
        state = {
            'block_data': self.block_data,
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix
        }
        return state

    def get_block_bucket(self, hash):
        return self.hash_data[hash]

    def get_block_embed(self, block_idx):
        return self.block_data[block_idx]
    def hash_embeds(self, embeds, block_data=None):
        """Hash a tensor of embeddings using a random projection matrix"""
        # project onto num_buckets // 2 random directions; concatenating the scores
        # with their negations gives one score per bucket, and argmax picks the bucket
        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), dim=1)
        embed_hashes = detach(torch.argmax(embed_scores, dim=1))

        if block_data is not None:
            for hash, indices in zip(embed_hashes, block_data):
                self.hash_data[hash].append(indices)

        return embed_hashes
    def assign_block_embeds(self, block_indices, block_embeds, allow_overwrite=False):
        """Assign the embeddings for each block index into a hash map"""
        for idx, embed in zip(block_indices, block_embeds):
            if not allow_overwrite and int(idx) in self.block_data:
                raise ValueError("Attempted to overwrite a read-only HashedIndex")
            self.block_data[int(idx)] = embed
    def save_shard(self, rank):
        dir_name = 'block_hash_data'
        # multiple ranks may try to create the directory at the same time
        os.makedirs(dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)
    def consolidate_shards_and_save(self):
        """Combine all the shards made using self.save_shard()"""
        dir_name = 'block_hash_data'
        fnames = os.listdir(dir_name)
        for fname in fnames:
            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                data = pickle.load(f)

            # every shard must have been built with the same projection matrix
            assert np.array_equal(data['hash_matrix'], self.hash_matrix)

            old_size = len(self.block_data)
            shard_size = len(data['block_data'])
            self.block_data.update(data['block_data'])
            # block indices should be disjoint across shards
            assert len(self.block_data) == old_size + shard_size

            for bucket, items in data['hash_data'].items():
                self.hash_data[bucket].extend(items)

        with open('block_hash_data.pkl', 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(dir_name, ignore_errors=True)
    def clear(self):
        """Clear the data structures to save memory"""
        self.block_data = defaultdict(list)
        self.hash_data = defaultdict(list)
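
# --- Illustrative sketch (not part of this commit) ---------------------------------
# The bucketing math used by HashedIndex.hash_embeds, restated in plain numpy:
# projecting onto num_buckets // 2 random directions and concatenating the scores
# with their negations gives one score per bucket, so argmax assigns each embedding
# to the signed direction it aligns with most strongly. The shapes and the function
# name are made up for the example; it assumes numpy is imported as np at the top
# of this file.
def _bucket_assignment_sketch():
    embed_size, num_buckets = 8, 16
    hash_matrix = np.random.rand(embed_size, num_buckets // 2)  # same shape as in __init__
    embeds = np.random.randn(4, embed_size)                     # 4 fake block embeddings

    scores_pos = embeds @ hash_matrix                               # (4, num_buckets // 2)
    scores = np.concatenate((scores_pos, -scores_pos), axis=1)      # (4, num_buckets)
    buckets = np.argmax(scores, axis=1)                             # one bucket id per embedding
    assert buckets.shape == (4,) and buckets.max() < num_buckets
    return buckets
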
def main():
    # TODO
    # consider broadcasting/all-reducing everything in memory rather than using the filesystem
    # create a different process group in the same nccl world - then we don't have to use checkpoints
    #   on disk or transfer things on disk. torch.distributed.new_group takes a list of ranks and gives
    #   back a group which can be handed to the collective operations
    # create a training process group and an indexing process group
    # pass the training group to distributed DDP instead of the large world process group
    # use the indexing process group for the shard-combining
    # add a communication group between process "8" and process "0" which tells the training group
    #   that there is a new index; also, process 0 sends process 8 the new model
    # if we want to launch a separate process for indexing, we may have to work with environment
    #   variables to allocate the resources well, and subsequently assign the correct gpus to the
    #   indexing job
    # consider initializing everything in a single group and breaking off processes based on the ranks
    #   (a rough sketch of the new_group idea follows these comments)
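    # --- Illustrative sketch (not part of this commit) ------------------------------
    # Roughly what the process-group split above could look like; the rank layout is
    # hypothetical and would depend on how the job is launched:
    #
    #   world_size = torch.distributed.get_world_size()
    #   train_ranks = list(range(world_size - 1))        # e.g. ranks 0-7 do training
    #   index_ranks = [world_size - 1]                   # e.g. the last rank does indexing
    #   train_group = torch.distributed.new_group(ranks=train_ranks)
    #   index_group = torch.distributed.new_group(ranks=index_ranks)
    #   # hand train_group to DDP; use index_group (or sends/recvs between rank 0 and
    #   # the indexing rank) to announce a fresh index and ship the latest model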
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
@@ -30,68 +122,35 @@ def embed_docs():
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    hashed_index = HashedIndex(embed_size=128, num_buckets=2048)

    i = 0
    while True:
        try:
            query_tokens, query_pad_mask, \
                block_tokens, block_pad_mask, block_indices = get_batch(data_iter)
        except:
            # the data iterator is exhausted
            break

        actual_model = model.module.module
        block_indices = detach(block_indices)

        block_logits = actual_model.embed_block(block_tokens, block_pad_mask)
        hashed_index.hash_embeds(block_logits, block_indices)
        hashed_index.assign_block_embeds(block_indices, detach(block_logits))

        if i % 100 == 0:
            print(i, flush=True)
        i += 1

    hashed_index.save_shard(args.rank)
    torch.distributed.barrier()
    del model
    # rank 0 process consolidates shards and saves into final file
    if mpu.get_data_parallel_rank() == 0:
        hashed_index.consolidate_shards_and_save()
    else:
        hashed_index.clear()


def load_checkpoint():
@@ -162,4 +221,4 @@ def get_dataloader(dataset):
if __name__ == "__main__":
    main()
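
# --- Illustrative sketch (not part of this commit) ---------------------------------
# One way a downstream consumer might read back the consolidated index written by
# consolidate_shards_and_save(). The file name matches the code above; the function
# and variable names here are hypothetical, and the sketch assumes the imports at the
# top of this file (pickle, numpy as np).
def load_consolidated_index(path='block_hash_data.pkl'):
    with open(path, 'rb') as f:
        state = pickle.load(f)

    # hash_matrix has shape (embed_size, num_buckets // 2), so recover both sizes from it
    index = HashedIndex(embed_size=state['hash_matrix'].shape[0],
                        num_buckets=2 * state['hash_matrix'].shape[1])
    index.block_data.update(state['block_data'])
    index.hash_data.update(state['hash_data'])
    index.hash_matrix = state['hash_matrix']
    # example lookup: index.get_block_embed(some_block_idx) or index.get_block_bucket(some_hash)
    return index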