"src/vscode:/vscode.git/clone" did not exist on "fa3c86beaf04e297d4b0e824692e3bd4edfb5f22"
Commit 1a3f5663 authored by Neel Kant

Rename variables

parent 360885ee
@@ -24,27 +24,27 @@ def main():
     all_input_tokens = []
     all_input_logits = []
-    all_doc_tokens = []
-    all_doc_logits = []
+    all_block_tokens = []
+    all_block_logits = []
     for i in range(100):
-        input_tokens, input_types, input_pad_mask, doc_tokens, doc_token_types, doc_pad_mask = get_batch(data_iter)
+        input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
         input_logits, doc_logits, _ = model.module.module.forward(
-            input_tokens, input_types, input_pad_mask, doc_tokens, doc_pad_mask, doc_token_types, return_logits=True)
+            input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
         all_input_tokens.append(input_tokens.detach().cpu().numpy())
         all_input_logits.append(input_logits.detach().cpu().numpy())
-        all_doc_tokens.append(doc_tokens.detach().cpu().numpy())
-        all_doc_logits.append(doc_logits.detach().cpu().numpy())
+        all_block_tokens.append(block_tokens.detach().cpu().numpy())
+        all_block_logits.append(doc_logits.detach().cpu().numpy())
-    all_inputs_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
-    all_inputs_logits = np.array(all_input_logits).reshape(-1, 128)
-    all_doc_tokens = np.array(all_doc_tokens).reshape(-1, args.seq_length)
-    all_doc_logits = np.array(all_doc_logits).reshape(-1, 128)
+    all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
+    all_input_logits = np.array(all_input_logits).reshape(-1, 128)
+    all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
+    all_block_logits = np.array(all_block_logits).reshape(-1, 128)
     np.save('input_tokens.npy', all_input_tokens)
     np.save('input_logits.npy', all_input_logits)
-    np.save('doc_tokens.npy', all_doc_tokens)
-    np.save('doc_logits.npy', all_doc_logits)
+    np.save('block_tokens.npy', all_block_tokens)
+    np.save('doc_logits.npy', all_block_logits)
 def load_checkpoint():
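The hunk above dumps 100 batches of query and block tokens along with their 128-dimensional logits to .npy files. As a quick sanity check, not part of this commit, the saved arrays could be reloaded and their shapes verified as follows; the shapes stated in the comments follow directly from the reshape calls above. Note that the renamed code still writes the block logits to 'doc_logits.npy'.

import numpy as np

# Reload the arrays written by main(); token arrays are (100 * batch_size, seq_length),
# logit arrays are (100 * batch_size, 128).
input_tokens = np.load('input_tokens.npy')
input_logits = np.load('input_logits.npy')
block_tokens = np.load('block_tokens.npy')
block_logits = np.load('doc_logits.npy')  # the commit keeps the old filename here

print(input_tokens.shape, input_logits.shape)
print(block_tokens.shape, block_logits.shape)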
@@ -75,17 +75,19 @@ def load_checkpoint():
 def get_dataset():
     args = get_args()
-    indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
+    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)
-    doc_idx_ptr = indexed_dataset.get_doc_idx()
-    total_num_documents = indexed_dataset.doc_idx.shape[0] - 1
-    indexed_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
+    doc_idx_ptr = block_dataset.get_doc_idx()
+    total_num_documents = block_dataset.doc_idx.shape[0] - 1
+    block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
     kwargs = dict(
         name='full',
-        indexed_dataset=indexed_dataset,
+        context_dataset=block_dataset,
         titles_dataset=titles_dataset,
         data_prefix=args.data_path,
         num_epochs=None,
-        max_num_samples=total_num_documents,
+        max_num_samples=total_num_documents * 3,
         max_seq_length=288, # doesn't matter
         short_seq_prob=0.0001, # doesn't matter
         seed=1
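For orientation only, and not shown in this commit, the dataset returned by get_dataset() would typically be wrapped in a PyTorch DataLoader and iterated with get_batch(), mirroring the loop in main(). The batch size and shuffle flag below are assumptions rather than values from the repository, and get_dataset/get_batch are the functions defined in the file being diffed.

import torch

dataset = get_dataset()
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
data_iter = iter(loader)

# Same unpacking order as the renamed call in main()
(input_tokens, input_types, input_pad_mask,
 block_tokens, block_token_types, block_pad_mask) = get_batch(data_iter)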