Commit e6f2720d authored by Neel Kant
Browse files

Include comprehensive block info when hashing

parent f3d2426e
...@@ -39,7 +39,7 @@ def main(): ...@@ -39,7 +39,7 @@ def main():
try: try:
input_tokens, input_types, input_pad_mask, \ input_tokens, input_types, input_pad_mask, \
block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter) block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
except StopIteration: except:
break break
# TODO: make sure input is still in block # TODO: make sure input is still in block
...@@ -49,20 +49,16 @@ def main(): ...@@ -49,20 +49,16 @@ def main():
block_hash_pos = torch.matmul(block_logits, hash_matrix) block_hash_pos = torch.matmul(block_logits, hash_matrix)
block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1) block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy() block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
for hash, idx in zip(block_hashes, block_indices): for hash, indices_array in zip(block_hashes, block_indices):
hash_data[int(hash)].append(int(idx)) hash_data[int(hash)].append(indicecs_array)
all_input_tokens.append(input_tokens.detach().cpu().numpy()) all_input_tokens.append(input_tokens.detach().cpu().numpy())
all_input_logits.append(input_logits.detach().cpu().numpy()) all_input_logits.append(input_logits.detach().cpu().numpy())
all_block_tokens.append(block_tokens.detach().cpu().numpy()) all_block_tokens.append(block_tokens.detach().cpu().numpy())
all_block_logits.append(block_logits.detach().cpu().numpy()) all_block_logits.append(block_logits.detach().cpu().numpy())
if i % 10 == 0: if i == 1000:
print(i, flush=True) print(i)
print(block_tokens[0])
if i == 100:
break
i += 1 i += 1
......
...@@ -79,7 +79,7 @@ class InverseClozeDataset(Dataset): ...@@ -79,7 +79,7 @@ class InverseClozeDataset(Dataset):
'context_text': np.array(context_tokens), 'context_text': np.array(context_tokens),
'context_types': np.array(context_token_types), 'context_types': np.array(context_token_types),
'context_pad_mask': np.array(context_pad_mask), 'context_pad_mask': np.array(context_pad_mask),
'context_indices': np.array([block_idx]).astype(np.int64) 'context_indices': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
} }
return sample return sample
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment