Commit 9b599754 authored by Neel Kant
Browse files

Debug and run hashing code

parent 256eb6ed
from collections import defaultdict
import os
import pickle
import shutil
import numpy as np
import torch
......@@ -43,9 +44,8 @@ def embed_docs():
except:
break
# TODO: make sure input is still in block
input_logits, block_logits, _ = model.module.module.forward(
input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
input_logits, block_logits = model.module.module.forward(
input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types)
block_hash_pos = torch.matmul(block_logits, hash_matrix)
block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
......@@ -67,6 +67,7 @@ def embed_docs():
if not os.path.isdir(dir_name):
os.mkdir(dir_name)
# save the data for each shard
with open('{}/{}.pkl'.format(dir_name, args.rank), 'wb') as data_file:
all_data = {'block_data': block_data, 'hash_data': hash_data}
pickle.dump(all_data, data_file)
......@@ -77,19 +78,20 @@ def embed_docs():
del all_data
del model
# rank 0 process consolidates shards and saves into final file
if mpu.get_data_parallel_rank() == 0:
all_block_data = defaultdict(dict)
dir_name = 'block_hash_data'
fnames = os.listdir(dir_name)
for fname in fnames:
with open(fname, 'rb') as f:
with open('{}/{}'.format(dir_name, fname), 'rb') as f:
data = pickle.load(f)
all_block_data['hash_data'].update(data['hash_data'])
all_block_data['block_data'].update(data['block_data'])
with open('block_hash_data.pkl', 'wb') as final_file:
pickle.dump(all_block_data, final_file)
os.rmdir(dir_name)
shutil.rmtree(dir_name, ignore_errors=True)
def load_checkpoint():
......@@ -125,8 +127,8 @@ def get_dataset():
kwargs = dict(
name='full',
context_dataset=block_dataset,
titles_dataset=titles_dataset,
block_dataset=block_dataset,
title_dataset=titles_dataset,
data_prefix=args.data_path,
num_epochs=1,
max_num_samples=None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment