Commit 9e95338b authored by Neel Kant
Browse files

Organize block embed data by block_idx

parent e6f2720d
...@@ -28,11 +28,12 @@ def main(): ...@@ -28,11 +28,12 @@ def main():
hash_data = defaultdict(list) hash_data = defaultdict(list)
hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024)) hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))
all_input_tokens = [] #all_input_tokens = []
all_input_logits = [] #all_input_logits = []
all_block_tokens = [] #all_block_tokens = []
block_data = defaultdict(list)
all_block_logits = [] all_block_logits = []
all_block_indices = []
my_rank = args.rank my_rank = args.rank
i = 0 i = 0
while True: while True:
...@@ -52,24 +53,32 @@ def main(): ...@@ -52,24 +53,32 @@ def main():
for hash, indices_array in zip(block_hashes, block_indices): for hash, indices_array in zip(block_hashes, block_indices):
hash_data[int(hash)].append(indices_array) hash_data[int(hash)].append(indices_array)
all_input_tokens.append(input_tokens.detach().cpu().numpy()) #all_input_tokens.append(input_tokens.detach().cpu().numpy())
all_input_logits.append(input_logits.detach().cpu().numpy()) #all_input_logits.append(input_logits.detach().cpu().numpy())
all_block_tokens.append(block_tokens.detach().cpu().numpy()) #all_block_tokens.append(block_tokens.detach().cpu().numpy())
all_block_logits.append(block_logits.detach().cpu().numpy())
all_block_logits.append(block_logits.detach().cpu().numpy())
all_block_indices.append(block_indices.detach().cpu().numpy()[:, 3])
if i == 1000: if i == 1000:
print(i) print(i)
i += 1 i += 1
all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length) #all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
all_input_logits = np.array(all_input_logits).reshape(-1, 128) #all_input_logits = np.array(all_input_logits).reshape(-1, 128)
all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length) #all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
all_block_logits = np.array(all_block_logits).reshape(-1, 128) all_block_logits = np.array(all_block_logits).reshape(-1, 128)
np.save(f'input_tokens{my_rank}.npy', all_input_tokens) all_block_indices = np.array(all_block_indices).reshape(all_block_logits.shape[0])
np.save(f'input_logits{my_rank}.npy', all_input_logits) for logits, idx in zip(all_block_logits, all_block_indices):
np.save(f'block_tokens{my_rank}.npy', all_block_tokens) block_data[idx] = logits
np.save(f'block_logits{my_rank}.npy', all_block_logits)
with open(f'block_data{my_rank}.pkl', 'wb') as block_file:
pickle.dump(block_data, block_file)
#np.save(f'input_tokens{my_rank}.npy', all_input_tokens)
#np.save(f'input_logits{my_rank}.npy', all_input_logits)
#np.save(f'block_tokens{my_rank}.npy', all_block_tokens)
#np.save(f'block_logits{my_rank}.npy', all_block_logits)
for hash, block_indices in hash_data.items(): for hash, block_indices in hash_data.items():
hash_data[hash] = np.array(block_indices) hash_data[hash] = np.array(block_indices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment