Commit 9b599754 authored by Neel Kant

Debug and run hashing code

parent 256eb6ed
 from collections import defaultdict
 import os
 import pickle
+import shutil
 import numpy as np
 import torch
@@ -43,9 +44,8 @@ def embed_docs():
         except:
             break
-        # TODO: make sure input is still in block
-        input_logits, block_logits, _ = model.module.module.forward(
-            input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
+        input_logits, block_logits = model.module.module.forward(
+            input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types)
         block_hash_pos = torch.matmul(block_logits, hash_matrix)
         block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
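The forward call now returns just (input_logits, block_logits), dropping the return_logits flag and the discarded third value. The two context lines after it are a signed random-projection hash: block embeddings are projected through a random matrix, and the result is concatenated with its negation so an argmax can assign each block to a bucket. A minimal sketch of that pattern follows; the hash_matrix setup, the sizes, and the argmax step are assumptions for illustration, since only the matmul/cat lines appear in this hunk:

import torch

embed_size, num_buckets = 128, 16          # illustrative sizes, not from the commit
hash_matrix = torch.randn(embed_size, num_buckets)

def hash_blocks(block_logits):
    # project each block embedding onto num_buckets random hyperplanes
    block_hash_pos = torch.matmul(block_logits, hash_matrix)              # [n, num_buckets]
    # concatenating h with -h yields 2*num_buckets signed scores; argmax
    # picks the hyperplane side with the largest margin, giving each block
    # one of 2*num_buckets hash buckets
    block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)  # [n, 2*num_buckets]
    return block_hash_full.argmax(dim=1)                                    # [n]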
@@ -67,6 +67,7 @@ def embed_docs():
     if not os.path.isdir(dir_name):
         os.mkdir(dir_name)
+    # save the data for each shard
     with open('{}/{}.pkl'.format(dir_name, args.rank), 'wb') as data_file:
         all_data = {'block_data': block_data, 'hash_data': hash_data}
         pickle.dump(all_data, data_file)
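The new comment marks the per-shard save: each data-parallel rank pickles its own slice of the results under a rank-keyed filename, so concurrent writers never collide. In isolation, the step looks roughly like this sketch (the function name and arguments are illustrative, not from the commit):

import os
import pickle

def save_shard(dir_name, rank, block_data, hash_data):
    # one file per rank; rank-keyed names keep concurrent writers apart
    # (os.makedirs(dir_name, exist_ok=True) would also avoid the race
    # when several ranks hit the isdir check at once)
    if not os.path.isdir(dir_name):
        os.mkdir(dir_name)
    with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
        pickle.dump({'block_data': block_data, 'hash_data': hash_data}, data_file)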
@@ -77,19 +78,20 @@ def embed_docs():
     del all_data
     del model
 
+    # rank 0 process consolidates shards and saves into final file
     if mpu.get_data_parallel_rank() == 0:
         all_block_data = defaultdict(dict)
         dir_name = 'block_hash_data'
         fnames = os.listdir(dir_name)
         for fname in fnames:
-            with open(fname, 'rb') as f:
+            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                 data = pickle.load(f)
                 all_block_data['hash_data'].update(data['hash_data'])
                 all_block_data['block_data'].update(data['block_data'])
         with open('block_hash_data.pkl', 'wb') as final_file:
             pickle.dump(all_block_data, final_file)
-        os.rmdir(dir_name)
+        shutil.rmtree(dir_name, ignore_errors=True)
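This hunk fixes two bugs in the rank-0 consolidation. First, os.listdir returns bare filenames, so they must be rejoined with dir_name before opening; the old open(fname) would raise FileNotFoundError unless the process happened to run inside the shard directory. Second, os.rmdir raises OSError on a non-empty directory, so shutil.rmtree is needed to remove the shard directory along with its contents. The consolidation step in isolation, as a sketch assuming the shard layout above:

import os
import pickle
import shutil
from collections import defaultdict

def consolidate_shards(dir_name='block_hash_data', out_path='block_hash_data.pkl'):
    all_block_data = defaultdict(dict)
    for fname in os.listdir(dir_name):
        # os.listdir yields bare names; join with dir_name (the first fix)
        with open(os.path.join(dir_name, fname), 'rb') as f:
            data = pickle.load(f)
        all_block_data['hash_data'].update(data['hash_data'])
        all_block_data['block_data'].update(data['block_data'])
    with open(out_path, 'wb') as final_file:
        pickle.dump(all_block_data, final_file)
    # rmtree deletes the non-empty shard dir; os.rmdir would raise OSError (the second fix)
    shutil.rmtree(dir_name, ignore_errors=True)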
 def load_checkpoint():
...
@@ -125,8 +127,8 @@ def get_dataset():
     kwargs = dict(
         name='full',
-        context_dataset=block_dataset,
-        titles_dataset=titles_dataset,
+        block_dataset=block_dataset,
+        title_dataset=titles_dataset,
         data_prefix=args.data_path,
         num_epochs=1,
         max_num_samples=None,
...
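The last hunk renames two keyword arguments to match the dataset constructor's actual parameter names; Python raises a TypeError at call time when a keyword does not appear in the callee's signature, so the old names would have broken dataset construction. Illustrated with a hypothetical signature (not the real constructor):

# hypothetical constructor signature, for illustration only
def build_dataset(name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples):
    ...

# before the rename, the call would fail with:
# TypeError: build_dataset() got an unexpected keyword argument 'context_dataset'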