"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c0e7a3b90ec64101ee2999f200f71f019d447f77"
Commit f3d2426e authored by Neel Kant's avatar Neel Kant
Browse files

Embedding and hashing docs script works

parent 662dc982
...@@ -33,6 +33,7 @@ def main(): ...@@ -33,6 +33,7 @@ def main():
all_block_tokens = [] all_block_tokens = []
all_block_logits = [] all_block_logits = []
my_rank = args.rank
i = 0 i = 0
while True: while True:
try: try:
...@@ -40,6 +41,8 @@ def main(): ...@@ -40,6 +41,8 @@ def main():
block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter) block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
except StopIteration: except StopIteration:
break break
# TODO: make sure input is still in block
input_logits, block_logits, _ = model.module.module.forward( input_logits, block_logits, _ = model.module.module.forward(
input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True) input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
...@@ -54,29 +57,29 @@ def main(): ...@@ -54,29 +57,29 @@ def main():
all_block_tokens.append(block_tokens.detach().cpu().numpy()) all_block_tokens.append(block_tokens.detach().cpu().numpy())
all_block_logits.append(block_logits.detach().cpu().numpy()) all_block_logits.append(block_logits.detach().cpu().numpy())
if i % 100 == 0: if i % 10 == 0:
print(i, flush=True) print(i, flush=True)
print(len(all_block_tokens), flush=True) print(block_tokens[0])
print(block_tokens.shape, flush=True)
i += 1
if i == 10: if i == 100:
break break
i += 1
all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length) all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
all_input_logits = np.array(all_input_logits).reshape(-1, 128) all_input_logits = np.array(all_input_logits).reshape(-1, 128)
all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length) all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
all_block_logits = np.array(all_block_logits).reshape(-1, 128) all_block_logits = np.array(all_block_logits).reshape(-1, 128)
np.save('input_tokens.npy', all_input_tokens) np.save(f'input_tokens{my_rank}.npy', all_input_tokens)
np.save('input_logits.npy', all_input_logits) np.save(f'input_logits{my_rank}.npy', all_input_logits)
np.save('block_tokens.npy', all_block_tokens) np.save(f'block_tokens{my_rank}.npy', all_block_tokens)
np.save('block_logits.npy', all_block_logits) np.save(f'block_logits{my_rank}.npy', all_block_logits)
for hash, block_indices in hash_data.items(): for hash, block_indices in hash_data.items():
hash_data[hash] = np.array(block_indices) hash_data[hash] = np.array(block_indices)
hash_data['matrix'] = hash_matrix hash_data['matrix'] = hash_matrix
with open('hash_data.pkl', 'wb') as hash_file: with open(f'hash_data{my_rank}.pkl', 'wb') as hash_file:
pickle.dump(hash_data, hash_file) pickle.dump(hash_data, hash_file)
......
...@@ -59,7 +59,7 @@ class InverseClozeDataset(Dataset): ...@@ -59,7 +59,7 @@ class InverseClozeDataset(Dataset):
rand_sent_idx = self.rng.randint(1, len(context) - 2) rand_sent_idx = self.rng.randint(1, len(context) - 2)
# keep the query in the context 10% of the time. # keep the query in the context 10% of the time.
if self.rng.random() < 0.1: if self.rng.random() < 1:
input = context[rand_sent_idx].copy() input = context[rand_sent_idx].copy()
else: else:
input = context.pop(rand_sent_idx) input = context.pop(rand_sent_idx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment