"vscode:/vscode.git/clone" did not exist on "1e651ca2c9f12bdcc5d63da8830847706e186f22"
Commit 81c71789 authored by Neel Kant

Implement reformer hashing scheme

parent 8ba76558
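Context for reviewers: the "reformer hashing scheme" here is random-hyperplane (angular) LSH as in the Reformer paper. Each block embedding is projected through a random matrix R, the negated projection is concatenated on, and argmax picks the bucket, so the (128, 1024) matrix below yields 2 * 1024 = 2048 buckets. A minimal NumPy sketch of that rule (the helper name `lsh_bucket` and the toy shapes are ours, not part of the commit):

```python
import numpy as np

def lsh_bucket(embeddings, hash_matrix):
    """Angular LSH: project onto random directions and pick the closest
    of the 2*n signed directions (each column and its negation)."""
    projected = embeddings @ hash_matrix                     # (batch, n_hash)
    full = np.concatenate([projected, -projected], axis=1)   # (batch, 2*n_hash)
    return np.argmax(full, axis=1)                           # bucket in [0, 2*n_hash)

# Toy usage mirroring the shapes in main(): 128-d embeddings, 2048 buckets.
rng = np.random.default_rng(0)
buckets = lsh_bucket(rng.standard_normal((4, 128)), rng.random((128, 1024)))
```

Concatenating the negated projection splits every hyperplane into two signed buckets, so vectors at a small angle to each other tend to land in the same bucket.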
+from collections import defaultdict
+import pickle
 import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -22,20 +25,33 @@ def main():
     dataset = get_dataset()
     data_iter = iter(get_dataloader(dataset))
 
+    hash_data = defaultdict(list)
+    hash_matrix = np.random.rand(128, 1024)
+
     all_input_tokens = []
     all_input_logits = []
     all_block_tokens = []
     all_block_logits = []
-    for i in range(100):
-        input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
-        input_logits, doc_logits, _ = model.module.module.forward(
+    while True:
+        try:
+            input_tokens, input_types, input_pad_mask, \
+                block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
+        except StopIteration:
+            break
+        input_logits, block_logits, _ = model.module.module.forward(
             input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
+
+        # Bucket each block embedding with the random-hyperplane hash; move the
+        # matrix onto the logits' device/dtype so the matmul is well-typed.
+        block_hash_pos = torch.matmul(block_logits, torch.from_numpy(hash_matrix).to(block_logits))
+        block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), dim=1)
+        block_hashes = torch.argmax(block_hash_full, dim=1)
+        for hash, idx in zip(block_hashes, block_indices):
+            hash_data[int(hash)].append(int(idx))
+
         all_input_tokens.append(input_tokens.detach().cpu().numpy())
         all_input_logits.append(input_logits.detach().cpu().numpy())
         all_block_tokens.append(block_tokens.detach().cpu().numpy())
-        all_block_logits.append(doc_logits.detach().cpu().numpy())
+        all_block_logits.append(block_logits.detach().cpu().numpy())
 
     all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
     all_input_logits = np.array(all_input_logits).reshape(-1, 128)
@@ -44,7 +60,14 @@ def main():
     np.save('input_tokens.npy', all_input_tokens)
     np.save('input_logits.npy', all_input_logits)
     np.save('block_tokens.npy', all_block_tokens)
-    np.save('doc_logits.npy', all_block_logits)
+    np.save('block_logits.npy', all_block_logits)
+
+    # Store each bucket's block indices as an array, plus the hash matrix
+    # itself so the same buckets can be reproduced later.
+    for hash, block_indices in hash_data.items():
+        hash_data[hash] = np.array(block_indices)
+    hash_data['matrix'] = hash_matrix
+    with open('hash_data.pkl', 'wb') as hash_file:
+        pickle.dump(hash_data, hash_file)
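For downstream use, a sketch of reading the pickle written here; the `'matrix'` key and the integer bucket keys come from the code above, the rest is illustrative:

```python
import pickle

with open('hash_data.pkl', 'rb') as f:
    hash_data = pickle.load(f)

hash_matrix = hash_data.pop('matrix')  # the (128, 1024) projection used for hashing
# Remaining keys are bucket ids; values are arrays of block indices in that bucket.
for bucket in sorted(hash_data):
    print(f'bucket {bucket}: {len(hash_data[bucket])} blocks')
```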
def load_checkpoint():
@@ -78,16 +101,13 @@ def get_dataset():
     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)
 
-    doc_idx_ptr = block_dataset.get_doc_idx()
-    total_num_documents = block_dataset.doc_idx.shape[0] - 1
-    block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])
-
     kwargs = dict(
         name='full',
         context_dataset=block_dataset,
         titles_dataset=titles_dataset,
         data_prefix=args.data_path,
-        num_epochs=None,
-        max_num_samples=total_num_documents * 3,
+        num_epochs=1,
+        max_num_samples=None,
         max_seq_length=288,  # doesn't matter
         short_seq_prob=0.0001,  # doesn't matter
         seed=1
@@ -363,6 +363,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Current map index.
     uint64_t map_index = 0;
+    int32_t block_id = 0;
 
     // For each epoch:
     for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
@@ -425,14 +426,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                     // Populate the map.
                     if (second) {
-                        const auto map_index_0 = 3 * map_index;
+                        const auto map_index_0 = 4 * map_index;
                         maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                         maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                         maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
+                        maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                     }
 
                     // Update indices / counters.
                     ++map_index;
+                    ++block_id;
                     prev_start_index = sent_index + 1;
                     seq_len = 0;
                     num_sent = 0;
@@ -440,6 +443,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                 } // for (auto sent_index=sent_index_first; ...
             } // if (num_remain_sent > 1) {
         } // for (int doc=0; doc < num_docs; ++doc) {
+        block_id = 0;
     } // for (int epoch=0; epoch < num_epochs; ++epoch) {
    if (!second) {
@@ -449,7 +453,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
         }
         assert(maps == NULL);
         assert(num_samples < 0);
-        maps = new DocIdx[3*map_index];
+        maps = new DocIdx[4*map_index];
         num_samples = static_cast<int64_t>(map_index);
     }
@@ -461,12 +465,13 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
         std::mt19937_64 rand64_gen(seed + 1);
         for (auto i=(num_samples - 1); i > 0; --i) {
             const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
-            const auto i0 = 3 * i;
-            const auto j0 = 3 * j;
+            const auto i0 = 4 * i;
+            const auto j0 = 4 * j;
             // Swap values.
             swap(maps[i0], maps[j0]);
             swap(maps[i0 + 1], maps[j0 + 1]);
             swap(maps[i0 + 2], maps[j0 + 2]);
+            swap(maps[i0 + 3], maps[j0 + 3]);
         }
    // Method to deallocate memory.
@@ -477,8 +482,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Return the numpy array.
     const auto byte_size = sizeof(DocIdx);
-    return py::array(std::vector<int64_t>{num_samples, 3},  // shape
-                     {3*byte_size, byte_size},  // C-style contiguous strides
+    return py::array(std::vector<int64_t>{num_samples, 4},  // shape
+                     {4*byte_size, byte_size},  // C-style contiguous strides
                      maps,  // the data pointer
                      free_when_done);  // numpy array references
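Each mapping row built above is now four wide, `[prev_start_index, sent_index + 1, doc, block_id]`, and `block_id` is reset at the end of every epoch so a given block keeps the same id across epochs. A hedged NumPy rendering of the row layout and the whole-row Fisher-Yates shuffle (illustrative only; not bit-identical to the `std::mt19937_64` stream in the C++):

```python
import numpy as np

# Rows: (start_sentence, end_sentence_exclusive, doc, block_id), as built above.
maps = np.array([[0, 3, 0, 0],
                 [3, 7, 0, 1],
                 [0, 5, 1, 2]], dtype=np.int64)

def shuffle_rows(maps, seed):
    """Whole-row Fisher-Yates, mirroring the four-element swaps in the C++."""
    rng = np.random.default_rng(seed + 1)
    for i in range(len(maps) - 1, 0, -1):
        j = int(rng.integers(0, i + 1))
        maps[[i, j]] = maps[[j, i]]  # swap all four columns together
    return maps
```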
@@ -47,7 +47,7 @@ class InverseClozeDataset(Dataset):
         return self.samples_mapping.shape[0]
 
     def __getitem__(self, idx):
-        start_idx, end_idx, doc_idx = self.samples_mapping[idx]
+        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         title = list(self.titles_dataset[int(doc_idx)])
         context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
         assert len(context) > 1
@@ -78,7 +78,8 @@ class InverseClozeDataset(Dataset):
             'input_pad_mask': np.array(input_pad_mask),
             'context_text': np.array(context_tokens),
             'context_types': np.array(context_token_types),
-            'context_pad_mask': np.array(context_pad_mask)
+            'context_pad_mask': np.array(context_pad_mask),
+            'context_indices': np.array([block_idx])
         }
         return sample
@@ -64,9 +64,10 @@ def get_batch(data_iterator):
     context_tokens = data_b['context_text'].long()
     context_types = data_b['context_types'].long()
     context_pad_mask = data_b['context_pad_mask'].long()
+    context_indices = data_b['context_indices'].long()
 
     return input_tokens, input_types, input_pad_mask,\
-           context_tokens, context_types, context_pad_mask
+           context_tokens, context_types, context_pad_mask, context_indices
def forward_step(data_iterator, model):