Commit 1a3f5663 authored by Neel Kant's avatar Neel Kant
Browse files

Rename variables

parent 360885ee
...@@ -24,27 +24,27 @@ def main(): ...@@ -24,27 +24,27 @@ def main():
all_input_tokens = [] all_input_tokens = []
all_input_logits = [] all_input_logits = []
all_doc_tokens = [] all_block_tokens = []
all_doc_logits = [] all_block_logits = []
for i in range(100): for i in range(100):
input_tokens, input_types, input_pad_mask, doc_tokens, doc_token_types, doc_pad_mask = get_batch(data_iter) input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
input_logits, doc_logits, _ = model.module.module.forward( input_logits, doc_logits, _ = model.module.module.forward(
input_tokens, input_types, input_pad_mask, doc_tokens, doc_pad_mask, doc_token_types, return_logits=True) input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
all_input_tokens.append(input_tokens.detach().cpu().numpy()) all_input_tokens.append(input_tokens.detach().cpu().numpy())
all_input_logits.append(input_logits.detach().cpu().numpy()) all_input_logits.append(input_logits.detach().cpu().numpy())
all_doc_tokens.append(doc_tokens.detach().cpu().numpy()) all_block_tokens.append(block_tokens.detach().cpu().numpy())
all_doc_logits.append(doc_logits.detach().cpu().numpy()) all_block_logits.append(doc_logits.detach().cpu().numpy())
all_inputs_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length) all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
all_inputs_logits = np.array(all_input_logits).reshape(-1, 128) all_input_logits = np.array(all_input_logits).reshape(-1, 128)
all_doc_tokens = np.array(all_doc_tokens).reshape(-1, args.seq_length) all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
all_doc_logits = np.array(all_doc_logits).reshape(-1, 128) all_block_logits = np.array(all_block_logits).reshape(-1, 128)
np.save('input_tokens.npy', all_input_tokens) np.save('input_tokens.npy', all_input_tokens)
np.save('input_logits.npy', all_input_logits) np.save('input_logits.npy', all_input_logits)
np.save('doc_tokens.npy', all_doc_tokens) np.save('block_tokens.npy', all_block_tokens)
np.save('doc_logits.npy', all_doc_logits) np.save('doc_logits.npy', all_block_logits)
def load_checkpoint(): def load_checkpoint():
...@@ -75,17 +75,19 @@ def load_checkpoint(): ...@@ -75,17 +75,19 @@ def load_checkpoint():
def get_dataset(): def get_dataset():
args = get_args() args = get_args()
indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)
doc_idx_ptr = indexed_dataset.get_doc_idx() doc_idx_ptr = block_dataset.get_doc_idx()
total_num_documents = indexed_dataset.doc_idx.shape[0] - 1 total_num_documents = block_dataset.doc_idx.shape[0] - 1
indexed_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents]) block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
kwargs = dict( kwargs = dict(
name='full', name='full',
indexed_dataset=indexed_dataset, context_dataset=block_dataset,
titles_dataset=titles_dataset,
data_prefix=args.data_path, data_prefix=args.data_path,
num_epochs=None, num_epochs=None,
max_num_samples=total_num_documents, max_num_samples=total_num_documents * 3,
max_seq_length=288, # doesn't matter max_seq_length=288, # doesn't matter
short_seq_prob=0.0001, # doesn't matter short_seq_prob=0.0001, # doesn't matter
seed=1 seed=1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment