Commit 88637044 authored by Neel Kant's avatar Neel Kant

Debug hashed_index.main

parent 9a617f6c
@@ -28,7 +28,7 @@ class HashedIndex(object):
         np.random.seed(seed)
         self.block_data = defaultdict(list)
         self.hash_data = defaultdict(list)
-        self.hash_matrix = np.random.rand(embed_size, num_buckets / 2)
+        self.hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))

     def state(self):
         state = {
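Note on this hunk: in Python 3, `/` is true division, so `num_buckets / 2` is a float, and NumPy refuses float dimensions in `np.random.rand`. A minimal standalone sketch (toy sizes, not the repository's code) of the failure and the fix:

```python
import numpy as np

embed_size, num_buckets = 128, 4096

# Python 3 "/" yields a float; np.random.rand wants integer dimensions
# and raises TypeError for 2048.0.
try:
    np.random.rand(embed_size, num_buckets / 2)
except TypeError as err:
    print("float dimension rejected:", err)

# Casting (as the commit does) or floor division keeps it an int.
hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
print(hash_matrix.shape)  # (128, 2048)
```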
@@ -72,19 +72,21 @@ class HashedIndex(object):
         with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
             pickle.dump(self.state(), data_file)

-    def consolidate_shards_and_save(self):
+    def consolidate_shards_and_save(self, ignore_shard=0):
         """Combine all the shards made using self.save_shard()"""
         dir_name = 'block_hash_data'
         fnames = os.listdir(dir_name)
         for fname in fnames:
+            if str(ignore_shard) in fname:
+                continue
             with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                 data = pickle.load(f)
-                assert data['hash_matrix'] == self.hash_matrix
+                assert np.array_equal(data['hash_matrix'], self.hash_matrix)
                 old_size = len(self.block_data)
                 shard_size = len(data['block_data'])
                 self.block_data.update(data['block_data'])
-                assert len(self.block_data) == old_size + shard_size
+                assert len(self.block_data) == old_size + shard_size, (old_size, shard_size, len(self.block_data))
                 for bucket, items in data['hash_data'].items():
                     self.hash_data[bucket].extend(items)
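Note on this hunk: `data['hash_matrix'] == self.hash_matrix` compares NumPy arrays element-wise, so the old assert raises "truth value of an array ... is ambiguous" instead of checking equality; `np.array_equal` returns a single bool, and the tuple attached to the second assert surfaces the sizes when shards share keys. A standalone sketch with hypothetical counts, not the repository's data:

```python
import numpy as np

a = np.random.rand(4, 8)
b = a.copy()

# "==" on arrays is element-wise; asserting on the result raises
# "ValueError: The truth value of an array with more than one element is ambiguous."
try:
    assert a == b
except ValueError as err:
    print(err)

# np.array_equal checks shape and values and returns one bool.
assert np.array_equal(a, b)

# Attaching a tuple to the assert makes a shard-size mismatch readable.
old_size, shard_size, new_size = 10, 5, 14   # hypothetical counts
try:
    assert new_size == old_size + shard_size, (old_size, shard_size, new_size)
except AssertionError as err:
    print("overlapping shard keys:", err)    # -> (10, 5, 14)
```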
@@ -137,7 +139,7 @@ def main():
        block_logits = actual_model.embed_block(block_tokens, block_pad_mask)
        hashed_index.hash_embeds(block_logits, block_indices)
-        hashed_index.assign_block_embeds(block_indices, detach(block_logits))
+        hashed_index.assign_block_embeds(block_indices[:,3], detach(block_logits))
        if i % 100 == 0:
            print(i, flush=True)
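Note on this hunk: `block_indices[:,3]` passes a single column of the batch's index array to `assign_block_embeds` instead of the full 2-D array. What column 3 holds is not visible in this diff; assuming each row carries several metadata fields and the fourth one is the block's unique id, the slice does this:

```python
import numpy as np

# Hypothetical row layout (start_idx, end_idx, doc_idx, block_idx);
# column 3 is assumed to be the id that keys HashedIndex.block_data.
block_indices = np.array([
    [ 0, 12, 7, 101],
    [12, 30, 7, 102],
    [ 0, 25, 8, 103],
])

block_ids = block_indices[:, 3]
print(block_ids)  # [101 102 103]
```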
@@ -329,7 +329,7 @@ class ICTBertModel(MegatronModule):
            ict_head_size=ict_head_size,
            parallel_output=parallel_output
        )
-        assert not only_block_model and only_query_model
+        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model
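Note on this hunk: `not` binds tighter than `and`, so the old assert checked `(not only_block_model) and only_query_model` rather than "not both", and it would fire in perfectly valid configurations. A quick demonstration:

```python
only_block_model, only_query_model = True, False

# Old form: "not" binds tighter than "and", so this is
# (not only_block_model) and only_query_model -> False,
# and the assert fires even though only one flag is set.
print(not only_block_model and only_query_model)    # False

# New form: only the contradictory "both flags set" case is rejected.
print(not (only_block_model and only_query_model))  # True
```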
@@ -355,7 +355,7 @@ class ICTBertModel(MegatronModule):
    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model"""
        if self.use_query_model:
-            query_types = torch.zeros(query_tokens.shape).type(torch.float16).cuda()
+            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
            return query_ict_logits
        else:
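Note on this hunk: token type ids are indices into an embedding table, so they must be an integer tensor; `torch.nn.Embedding` rejects half-precision indices at lookup time. A CPU-only sketch with toy sizes, not the repository's BERT model:

```python
import torch

token_type_embedding = torch.nn.Embedding(2, 16)
query_tokens = torch.randint(0, 30000, (4, 32))   # stand-in batch of token ids

# float16 type ids cannot index the embedding table.
bad_types = torch.zeros(query_tokens.shape).type(torch.float16)
try:
    token_type_embedding(bad_types)
except RuntimeError as err:
    print("half-precision indices rejected:", err)

# int64 (torch.long) is what embedding lookups expect.
query_types = torch.zeros(query_tokens.shape).type(torch.int64)
print(token_type_embedding(query_types).shape)    # torch.Size([4, 32, 16])
```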
@@ -364,7 +364,7 @@ class ICTBertModel(MegatronModule):
    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model"""
        if self.use_block_model:
-            block_types = torch.zeros(block_tokens.shape).type(torch.float16).cuda()
+            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
            return block_ict_logits
        else:
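The block-side change is the same dtype fix. As a side note (not what the commit does), the intermediate float tensor can be avoided by choosing the dtype when the zeros are created:

```python
import torch

block_tokens = torch.randint(0, 30000, (4, 32))

# zeros_like inherits int64 from the token ids.
block_types = torch.zeros_like(block_tokens)

# Or spell out the dtype explicitly (add device='cuda' on GPU).
block_types = torch.zeros(block_tokens.shape, dtype=torch.int64)
```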