Commit 81c71789 authored by Neel Kant

Implement reformer hashing scheme

parent 8ba76558
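
The new code below buckets evidence blocks with a Reformer-style angular LSH: block logits are projected through a fixed random matrix, the projection is concatenated with its negation, and the argmax over the concatenation selects the bucket, so the 128 x 1024 matrix used here yields 2 * 1024 = 2048 buckets. A minimal NumPy sketch of that bucketing (the helper name and toy inputs are illustrative only, not part of this commit):

import numpy as np

def lsh_bucket(embeddings, hash_matrix):
    # Project onto random directions, append the negated projection, and take the
    # argmax: rows pointing in similar directions land in the same bucket.
    projected = embeddings @ hash_matrix
    full = np.concatenate((projected, -projected), axis=1)
    return np.argmax(full, axis=1)

hash_matrix = np.random.rand(128, 1024)                      # same shape as in the commit
buckets = lsh_bucket(np.random.rand(4, 128), hash_matrix)    # four toy 128-dim embeddings
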
from collections import defaultdict
import pickle
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -22,20 +25,33 @@ def main():
dataset = get_dataset()
data_iter = iter(get_dataloader(dataset))
hash_data = defaultdict(list)
hash_matrix = np.random.rand(128, 1024)
all_input_tokens = []
all_input_logits = []
all_block_tokens = []
all_block_logits = []
for i in range(100):
input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
input_logits, doc_logits, _ = model.module.module.forward(
while True:
try:
input_tokens, input_types, input_pad_mask, \
block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
except StopIteration:
break
input_logits, block_logits, _ = model.module.module.forward(
input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
# Reformer-style hashing: concatenate the projection with its negation and take the argmax as the bucket id.
block_hash_pos = torch.matmul(block_logits, torch.from_numpy(hash_matrix).to(block_logits.device, block_logits.dtype))
block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), dim=1)
block_hashes = torch.argmax(block_hash_full, dim=1)
for hash, idx in zip(block_hashes, block_indices):
hash_data[int(hash)].append(int(idx))
all_input_tokens.append(input_tokens.detach().cpu().numpy())
all_input_logits.append(input_logits.detach().cpu().numpy())
all_block_tokens.append(block_tokens.detach().cpu().numpy())
all_block_logits.append(doc_logits.detach().cpu().numpy())
all_block_logits.append(block_logits.detach().cpu().numpy())
all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
all_input_logits = np.array(all_input_logits).reshape(-1, 128)
@@ -44,7 +60,14 @@ def main():
np.save('input_tokens.npy', all_input_tokens)
np.save('input_logits.npy', all_input_logits)
np.save('block_tokens.npy', all_block_tokens)
np.save('doc_logits.npy', all_block_logits)
np.save('block_logits.npy', all_block_logits)
for hash, block_indices in hash_data.items():
hash_data[hash] = np.array(block_indices)
hash_data['matrix'] = hash_matrix
with open('hash_data.pkl', 'wb') as hash_file:
pickle.dump(hash_data, hash_file)
def load_checkpoint():
@@ -78,16 +101,13 @@ def get_dataset():
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)
doc_idx_ptr = block_dataset.get_doc_idx()
total_num_documents = block_dataset.doc_idx.shape[0] - 1
block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
kwargs = dict(
name='full',
context_dataset=block_dataset,
titles_dataset=titles_dataset,
data_prefix=args.data_path,
num_epochs=None,
max_num_samples=total_num_documents * 3,
num_epochs=1,
max_num_samples=None,
max_seq_length=288, # doesn't matter
short_seq_prob=0.0001, # doesn't matter
seed=1
......
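
The hash_data.pkl written above maps each hash bucket (an int key) to a NumPy array of the block indices that landed in that bucket, plus the projection matrix under the 'matrix' key. A short sketch of reading it back (this consumer script is not part of the commit; it only assumes the structure dumped by the loop above):

import pickle

with open('hash_data.pkl', 'rb') as hash_file:
    hash_data = pickle.load(hash_file)

hash_matrix = hash_data.pop('matrix')   # the 128 x 1024 projection saved above
# Inspect how evenly the blocks spread across the non-empty buckets.
bucket_sizes = {bucket: len(indices) for bucket, indices in hash_data.items()}
print(max(bucket_sizes.values()), min(bucket_sizes.values()))
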
@@ -363,6 +363,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Current map index.
uint64_t map_index = 0;
int32_t block_id = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
@@ -425,14 +426,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Populate the map.
if (second) {
const auto map_index_0 = 3 * map_index;
const auto map_index_0 = 4 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
@@ -440,6 +443,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
block_id = 0;
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
@@ -449,7 +453,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[3*map_index];
maps = new DocIdx[4*map_index];
num_samples = static_cast<int64_t>(map_index);
}
@@ -461,12 +465,13 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 3 * i;
const auto j0 = 3 * j;
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
@@ -477,8 +482,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
{3*byte_size, byte_size}, // C-style contiguous strides
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
......
@@ -47,7 +47,7 @@ class InverseClozeDataset(Dataset):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, doc_idx = self.samples_mapping[idx]
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
title = list(self.titles_dataset[int(doc_idx)])
context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
assert len(context) > 1
@@ -78,7 +78,8 @@ class InverseClozeDataset(Dataset):
'input_pad_mask': np.array(input_pad_mask),
'context_text': np.array(context_tokens),
'context_types': np.array(context_token_types),
'context_pad_mask': np.array(context_pad_mask)
'context_pad_mask': np.array(context_pad_mask),
'context_indices': np.array([block_idx])
}
return sample
......
@@ -64,9 +64,10 @@ def get_batch(data_iterator):
context_tokens = data_b['context_text'].long()
context_types = data_b['context_types'].long()
context_pad_mask = data_b['context_pad_mask'].long()
context_indices = data_b['context_indices'].long()
return input_tokens, input_types, input_pad_mask,\
context_tokens, context_types, context_pad_mask
context_tokens, context_types, context_pad_mask, context_indices
def forward_step(data_iterator, model):
......