Commit dfaf674d authored by Neel Kant

Merge branch 'hashing' of https://gitlab-master.nvidia.com/ADLR/megatron-lm into hashing

parents 43d5d84b 017a943f
 from collections import defaultdict
+import os
 import pickle
 import numpy as np
@@ -16,7 +17,7 @@ from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider


-def main():
+def embed_docs():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
@@ -27,15 +28,9 @@ def main():
     hash_data = defaultdict(list)
     hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))
+    hash_data['matrix'] = hash_matrix
-    #all_input_tokens = []
-    #all_input_logits = []
-    #all_block_tokens = []
     block_data = defaultdict(list)
-    all_block_logits = []
-    all_block_indices = []
-    my_rank = args.rank
-    block_file = open(f'block_data{my_rank}.pkl', 'wb')

     i = 0
     while True:
         try:
@@ -52,47 +47,42 @@ def main():
             block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
             block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
             for hash, indices_array in zip(block_hashes, block_indices):
-                hash_data[int(hash)].append(indices_array)
+                hash_data[int(hash)].append(indices_array.detach().cpu().numpy())

-            #all_input_tokens.append(input_tokens.detach().cpu().numpy())
-            #all_input_logits.append(input_logits.detach().cpu().numpy())
-            #all_block_tokens.append(block_tokens.detach().cpu().numpy())
-            #all_block_logits.append(block_logits.detach().cpu().numpy())
-            #all_block_indices.append(block_indices.detach().cpu().numpy()[:, 3])
             block_logits = block_logits.detach().cpu().numpy()
             block_indices = block_indices.detach().cpu().numpy()[:, 3]
             for logits, idx in zip(block_logits, block_indices):
-                pickle.dump({idx: logits}, block_file)
+                block_data[int(idx)] = logits

-            if i == 100:
-                print(i)
+            if i % 100 == 0:
+                print(i, flush=True)
             i += 1

-    block_file.close()
-    #all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
-    #all_input_logits = np.array(all_input_logits).reshape(-1, 128)
-    #all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
-    #all_block_logits = np.array(all_block_logits).reshape(-1, 128)
-    #all_block_indices = np.array(all_block_indices).reshape(all_block_logits.shape[0])
-    #for logits, idx in zip(all_block_logits, all_block_indices):
-    #    block_data[idx] = logits
-    #with as block_file:
-    #    pickle.dump(block_data, block_file)
-
-    #np.save(f'input_tokens{my_rank}.npy', all_input_tokens)
-    #np.save(f'input_logits{my_rank}.npy', all_input_logits)
-    #np.save(f'block_tokens{my_rank}.npy', all_block_tokens)
-    #np.save(f'block_logits{my_rank}.npy', all_block_logits)
-
-    for hash, block_indices in hash_data.items():
-        hash_data[hash] = np.array(block_indices)
-
-    hash_data['matrix'] = hash_matrix
-    with open(f'hash_data{my_rank}.pkl', 'wb') as hash_file:
-        pickle.dump(hash_data, hash_file)
+    dir_name = 'block_hash_data'
+    if not os.path.isdir(dir_name):
+        os.mkdir(dir_name)
+
+    with open('{}/{}.pkl'.format(dir_name, args.rank), 'wb') as data_file:
+        all_data = {'block_data': block_data, 'hash_data': hash_data}
+        pickle.dump(all_data, data_file)
+
+    torch.distributed.barrier()
+
+    if mpu.get_data_parallel_rank() == 0:
+        all_block_data = defaultdict(dict)
+        dir_name = 'block_hash_data'
+        fnames = os.listdir(dir_name)
+        for fname in fnames:
+            with open(fname, 'rb') as f:
+                data = pickle.load(f)
+            all_block_data['hash_data'].update(data['hash_data'])
+            all_block_data['block_data'].update(data['block_data'])
+
+        with open('block_hash_data.pkl', 'wb') as final_file:
+            pickle.dump(all_block_data, final_file)
+
+        os.rmdir(dir_name)
+    return


 def load_checkpoint():
@@ -163,4 +153,4 @@ def get_dataloader(dataset):

 if __name__ == "__main__":
-    main()
+    embed_docs()
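For context on how the merged index might be consumed, here is a minimal sketch (not part of this merge) that loads `block_hash_data.pkl` and buckets a query embedding with the same random-projection scheme used above. It assumes the per-rank dumps carry the projection matrix under the `'matrix'` key, that embeddings are 128-dimensional (so `block_hash_pos` is the embedding projected through the `(128, 1024)` `hash_matrix`), and that a GPU is available to unpickle the half-precision CUDA tensor; `load_index`, `hash_embedding`, and `query_embed` are illustrative names, not part of the repo.

```python
# Hedged sketch: consume the index written by embed_docs() after the rank-0 merge.
import pickle

import torch


def load_index(path='block_hash_data.pkl'):
    # The rank-0 merge writes {'block_data': {...}, 'hash_data': {...}}.
    with open(path, 'rb') as f:
        return pickle.load(f)


def hash_embedding(embed, hash_matrix):
    # Mirror the bucketing in embed_docs(): project with the random matrix,
    # concatenate the negated projection, and take the argmax bucket.
    projection = embed @ hash_matrix                     # assumed [batch, 128] x [128, 1024]
    full = torch.cat((projection, -projection), dim=1)   # [batch, 2048]
    return torch.argmax(full, dim=1)


if __name__ == '__main__':
    index = load_index()
    hash_data = index['hash_data']
    block_data = index['block_data']

    # Assumes the per-rank hash_data dicts stored the projection matrix under 'matrix'.
    hash_matrix = hash_data['matrix']

    # Illustrative query; in practice this would come from the ICT query encoder.
    query_embed = torch.randn(1, hash_matrix.shape[0],
                              dtype=hash_matrix.dtype, device=hash_matrix.device)
    bucket = int(hash_embedding(query_embed, hash_matrix)[0])

    candidate_blocks = hash_data.get(bucket, [])
    print('bucket {}: {} candidate index arrays'.format(bucket, len(candidate_blocks)))
```

The `argmax` over the projection concatenated with its negation effectively assigns each embedding to the coordinate with the largest absolute projection value plus its sign, giving twice as many buckets as projection dimensions.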