"examples/sampling/vscode:/vscode.git/clone" did not exist on "2f47c2413eb0764066651c56a9727f456b7794bf"
Commit 4170fc69 authored by Neel Kant's avatar Neel Kant
Browse files

Merge hashing into realm-mlm

parents ee2490d5 43d5d84b
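
This commit adds a standalone script that runs a trained ICT model over every block in the dataset and buckets the resulting block embeddings with a random-projection hash, writing per-rank block_data{rank}.pkl and hash_data{rank}.pkl files. The bucketing rule is easiest to see in isolation; the numpy sketch below is an editorial illustration of the same computation, not code from the commit:

    # Editorial sketch of the bucketing rule used by the script below:
    # an embedding is assigned the index of its largest signed random
    # projection, giving 2 * 1024 possible buckets.
    import numpy as np

    rng = np.random.default_rng(0)
    W = rng.random((128, 1024))     # random projection, as in the script
    e = rng.standard_normal(128)    # stand-in for a 128-d block embedding
    proj = e @ W
    bucket = int(np.argmax(np.concatenate([proj, -proj])))  # in [0, 2048)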
from collections import defaultdict
import pickle

import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider


def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    # Buckets of block indices keyed by hash bucket, plus the random
    # projection used to assign block embeddings to buckets.
    hash_data = defaultdict(list)
    hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))

    my_rank = args.rank
    block_file = open(f'block_data{my_rank}.pkl', 'wb')

    i = 0
    while True:
        try:
            input_tokens, input_types, input_pad_mask, \
                block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
        except Exception:
            # The data iterator is exhausted.
            break

        # TODO: make sure input is still in block
        input_logits, block_logits, _ = model.module.module.forward(
            input_tokens, input_types, input_pad_mask,
            block_tokens, block_pad_mask, block_token_types, return_logits=True)

        # Assign each block embedding to the bucket of its largest signed
        # projection: argmax over [xW, -xW] picks a coordinate and a sign.
        block_hash_pos = torch.matmul(block_logits, hash_matrix)
        block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), dim=1)
        block_hashes = torch.argmax(block_hash_full, dim=1).detach().cpu().numpy()

        block_logits = block_logits.detach().cpu().numpy()
        block_indices = block_indices.detach().cpu().numpy()

        # File the (start_idx, end_idx, doc_idx, block_idx) rows under
        # their buckets.
        for bucket, indices_array in zip(block_hashes, block_indices):
            hash_data[int(bucket)].append(indices_array)

        # Stream one {block_idx: logits} record per block to disk.
        for logits, idx in zip(block_logits, block_indices[:, 3]):
            pickle.dump({idx: logits}, block_file)

        if i % 100 == 0:
            print(i)
        i += 1
    block_file.close()

    for bucket, indices in hash_data.items():
        hash_data[bucket] = np.array(indices)
    # Keep the projection with the buckets so queries can be hashed the same
    # way later; moved to CPU so the pickle loads on machines without CUDA.
    hash_data['matrix'] = hash_matrix.cpu()
    with open(f'hash_data{my_rank}.pkl', 'wb') as hash_file:
        pickle.dump(hash_data, hash_file)
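
Because main() streams many small pickles into block_data{rank}.pkl rather than dumping a single dict, reading the file back requires calling pickle.load repeatedly until the stream is exhausted. A minimal reader sketch under that assumption (read_block_logits is a hypothetical helper, not part of the commit):

    import pickle

    def read_block_logits(path='block_data0.pkl'):
        # Re-assemble the {block_idx: logits} records streamed by main().
        records = {}
        with open(path, 'rb') as f:
            while True:
                try:
                    records.update(pickle.load(f))
                except EOFError:
                    break
        return records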
def load_checkpoint():
    """Load the ICT model weights from the checkpoint specified by --load."""
    args = get_args()
    model = get_model(model_provider)

    if isinstance(model, torchDDP):
        model = model.module

    # Read the iteration of the latest checkpoint from the tracker file.
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
    assert iteration > 0

    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_dataset():
    """Build an InverseClozeDataset over the indexed block and title data."""
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

    kwargs = dict(
        name='full',
        context_dataset=block_dataset,
        titles_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=288,     # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1
    )
    return InverseClozeDataset(**kwargs)


def get_dataloader(dataset):
    """Data loader over the full dataset, sharded across data-parallel ranks."""
    args = get_args()
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    # Sequential sampling so each block is visited exactly once.
    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    main()
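
For context, one plausible way to consume the saved hash data at query time is to hash a query embedding with the stored projection and return the block index arrays filed under its bucket. This is a hypothetical sketch (candidate_blocks and the CPU/float casts are editorial assumptions, not code from this commit):

    import pickle
    import numpy as np
    import torch

    def candidate_blocks(query_logits, hash_data_path='hash_data0.pkl'):
        # Hypothetical lookup, not part of this commit: bucket the query
        # exactly as main() bucketed the blocks, then return candidates.
        with open(hash_data_path, 'rb') as f:
            hash_data = pickle.load(f)
        matrix = hash_data['matrix'].float().cpu()   # (128, 1024) projection saved above
        proj = torch.matmul(query_logits.detach().float().cpu(), matrix)
        buckets = torch.argmax(torch.cat((proj, -proj), dim=1), dim=1)
        empty = np.empty((0, 4), dtype=np.int64)
        return [hash_data.get(int(b), empty) for b in buckets]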
megatron/data/ict_dataset.py:

@@ -41,6 +41,7 @@ class InverseClozeDataset(Dataset):

     def __getitem__(self, idx):
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
+        title = list(self.title_dataset[int(doc_idx)])
         block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
         assert len(block) > 1
@@ -50,8 +51,8 @@ class InverseClozeDataset(Dataset):
         else:
             rand_sent_idx = self.rng.randint(1, len(block) - 2)

-        # keep the query in the block 10% of the time.
-        if self.rng.random() < 0.1:
+        # keep the query in the context 10% of the time.
+        if self.rng.random() < 1:
             query = block[rand_sent_idx].copy()
         else:
             query = block.pop(rand_sent_idx)
@@ -71,7 +72,7 @@ class InverseClozeDataset(Dataset):
             'block_tokens': np.array(block_tokens),
             'block_types': np.array(block_token_types),
             'block_pad_mask': np.array(block_pad_mask),
-            'block_indices': np.array([start_idx, end_idx, doc_idx, block_idx])
+            'block_indices': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
         }
         return sample
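
The .astype(np.int64) cast matters because of the hunks that follow: get_batch broadcasts every key across ranks (presumably via mpu.broadcast_data) at datatype = torch.int64, so block_indices has to arrive as 64-bit integers like the other fields.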
pretrain_bert_ict.py:

@@ -46,7 +46,7 @@ def get_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_types', 'query_pad_mask',
-            'block_tokens', 'block_types', 'block_pad_mask']
+            'block_tokens', 'block_types', 'block_pad_mask', 'block_indices']
     datatype = torch.int64

     # Broadcast data.
@@ -63,9 +63,10 @@ def get_batch(data_iterator):
     block_tokens = data_b['block_tokens'].long()
     block_types = data_b['block_types'].long()
     block_pad_mask = data_b['block_pad_mask'].long()
+    block_indices = data_b['block_indices'].long()

     return query_tokens, query_types, query_pad_mask,\
-           block_tokens, block_types, block_pad_mask
+           block_tokens, block_types, block_pad_mask, block_indices


 def forward_step(data_iterator, model):
@@ -75,7 +76,7 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     query_tokens, query_types, query_pad_mask,\
-        block_tokens, block_types, block_pad_mask = get_batch(data_iterator)
+        block_tokens, block_types, block_pad_mask, block_indices = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
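
Note that forward_step does not itself use block_indices in the visible hunk; the field is threaded through get_batch so that the hashing script above can recover (start_idx, end_idx, doc_idx, block_idx) for every block it encodes.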