Commit f3d2426e authored by Neel Kant
Browse files

Embedding and hashing docs script works

parent 662dc982
......@@ -33,6 +33,7 @@ def main():
all_block_tokens = []
all_block_logits = []
my_rank = args.rank
i = 0
while True:
try:
......@@ -40,6 +41,8 @@ def main():
block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
except StopIteration:
break
# TODO: make sure input is still in block
input_logits, block_logits, _ = model.module.module.forward(
input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
......@@ -54,29 +57,29 @@ def main():
all_block_tokens.append(block_tokens.detach().cpu().numpy())
all_block_logits.append(block_logits.detach().cpu().numpy())
if i % 100 == 0:
if i % 10 == 0:
print(i, flush=True)
print(len(all_block_tokens), flush=True)
print(block_tokens.shape, flush=True)
i += 1
print(block_tokens[0])
if i == 10:
if i == 100:
break
i += 1
all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
all_input_logits = np.array(all_input_logits).reshape(-1, 128)
all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
all_block_logits = np.array(all_block_logits).reshape(-1, 128)
np.save('input_tokens.npy', all_input_tokens)
np.save('input_logits.npy', all_input_logits)
np.save('block_tokens.npy', all_block_tokens)
np.save('block_logits.npy', all_block_logits)
np.save(f'input_tokens{my_rank}.npy', all_input_tokens)
np.save(f'input_logits{my_rank}.npy', all_input_logits)
np.save(f'block_tokens{my_rank}.npy', all_block_tokens)
np.save(f'block_logits{my_rank}.npy', all_block_logits)
for hash, block_indices in hash_data.items():
hash_data[hash] = np.array(block_indices)
hash_data['matrix'] = hash_matrix
with open('hash_data.pkl', 'wb') as hash_file:
with open(f'hash_data{my_rank}.pkl', 'wb') as hash_file:
pickle.dump(hash_data, hash_file)
......
......@@ -59,7 +59,7 @@ class InverseClozeDataset(Dataset):
rand_sent_idx = self.rng.randint(1, len(context) - 2)
# keep the query in the context 10% of the time.
if self.rng.random() < 0.1:
if self.rng.random() < 1:
input = context[rand_sent_idx].copy()
else:
input = context.pop(rand_sent_idx)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment