Commit 662dc982 authored by Neel Kant
Browse files

Debug hash dump

parent 81c71789
...@@ -26,13 +26,14 @@ def main(): ...@@ -26,13 +26,14 @@ def main():
data_iter = iter(get_dataloader(dataset)) data_iter = iter(get_dataloader(dataset))
hash_data = defaultdict(list) hash_data = defaultdict(list)
hash_matrix = np.random.rand(128, 1024) hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))
all_input_tokens = [] all_input_tokens = []
all_input_logits = [] all_input_logits = []
all_block_tokens = [] all_block_tokens = []
all_block_logits = [] all_block_logits = []
i = 0
while True: while True:
try: try:
input_tokens, input_types, input_pad_mask, \ input_tokens, input_types, input_pad_mask, \
...@@ -43,8 +44,8 @@ def main(): ...@@ -43,8 +44,8 @@ def main():
input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True) input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
block_hash_pos = torch.matmul(block_logits, hash_matrix) block_hash_pos = torch.matmul(block_logits, hash_matrix)
block_hash_full = torch.concat((block_hash_pos, -block_hash_pos), axis=1) block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
block_hashes = torch.argmax(block_hash_full, axis=1) block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
for hash, idx in zip(block_hashes, block_indices): for hash, idx in zip(block_hashes, block_indices):
hash_data[int(hash)].append(int(idx)) hash_data[int(hash)].append(int(idx))
...@@ -53,6 +54,15 @@ def main(): ...@@ -53,6 +54,15 @@ def main():
all_block_tokens.append(block_tokens.detach().cpu().numpy()) all_block_tokens.append(block_tokens.detach().cpu().numpy())
all_block_logits.append(block_logits.detach().cpu().numpy()) all_block_logits.append(block_logits.detach().cpu().numpy())
if i % 100 == 0:
print(i, flush=True)
print(len(all_block_tokens), flush=True)
print(block_tokens.shape, flush=True)
i += 1
if i == 10:
break
all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length) all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
all_input_logits = np.array(all_input_logits).reshape(-1, 128) all_input_logits = np.array(all_input_logits).reshape(-1, 128)
all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length) all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
......
...@@ -79,7 +79,7 @@ class InverseClozeDataset(Dataset): ...@@ -79,7 +79,7 @@ class InverseClozeDataset(Dataset):
'context_text': np.array(context_tokens), 'context_text': np.array(context_tokens),
'context_types': np.array(context_token_types), 'context_types': np.array(context_token_types),
'context_pad_mask': np.array(context_pad_mask), 'context_pad_mask': np.array(context_pad_mask),
'context_indices': np.array([block_idx]) 'context_indices': np.array([block_idx]).astype(np.int64)
} }
return sample return sample
......
...@@ -47,7 +47,7 @@ def get_batch(data_iterator): ...@@ -47,7 +47,7 @@ def get_batch(data_iterator):
# Items and their type. # Items and their type.
keys = ['input_text', 'input_types', 'input_pad_mask', keys = ['input_text', 'input_types', 'input_pad_mask',
'context_text', 'context_types', 'context_pad_mask'] 'context_text', 'context_types', 'context_pad_mask', 'context_indices']
datatype = torch.int64 datatype = torch.int64
# Broadcast data. # Broadcast data.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment