Commit 9d225b44 authored by Neel Kant

Whitening code

parent 5e56e563
@@ -25,18 +25,23 @@ def detach(tensor):
class HashedIndex(object):
"""Class for holding hashed data"""
def __init__(self, embed_size, num_buckets, seed=0):
def __init__(self, embed_size, num_buckets, whiten=False, seed=0):
np.random.seed(seed)
self.block_data = defaultdict(list)
self.hash_data = defaultdict(list)
hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
self.embed_mean = None
self.embed_whitener = None
self.whiten = whiten
def state(self):
state = {
'block_data': self.block_data,
'hash_data': self.hash_data,
'hash_matrix': self.hash_matrix
'hash_matrix': self.hash_matrix,
'embed_mean': self.embed_mean,
'embed_whitener': self.embed_whitener,
}
return state
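For context on hash_matrix: each of the num_buckets/2 unit-norm random columns defines a hyperplane, and the two sides of that hyperplane map to two different buckets, which is why the matrix needs only half as many columns as buckets. A minimal standalone NumPy sketch of this bucketing (the helper name is illustrative, not part of the class):

import numpy as np

def hash_bucket(embed, hash_matrix):
    # Score against each hyperplane; negating the scores covers the other
    # side, so column j maps to buckets j and j + hash_matrix.shape[1].
    scores_pos = embed @ hash_matrix
    scores = np.concatenate([scores_pos, -scores_pos])
    return int(np.argmax(scores))

# Example with the shapes used in main() below (128-d embeds, 4096 buckets).
rng = np.random.RandomState(0)
hash_matrix = rng.rand(128, 2048)
hash_matrix /= np.linalg.norm(hash_matrix, axis=0, keepdims=True)
print(hash_bucket(rng.randn(128), hash_matrix))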
@@ -79,8 +84,6 @@ class HashedIndex(object):
dir_name = 'block_hash_data'
fnames = os.listdir(dir_name)
for fname in fnames:
if str(ignore_shard) in fname:
continue
with open('{}/{}'.format(dir_name, fname), 'rb') as f:
data = pickle.load(f)
assert np.array_equal(data['hash_matrix'], self.hash_matrix)
@@ -88,10 +91,14 @@ class HashedIndex(object):
old_size = len(self.block_data)
shard_size = len(data['block_data'])
self.block_data.update(data['block_data'])
assert len(self.block_data) == old_size + shard_size, (old_size, shard_size, len(self.block_data))
assert (len(self.block_data) == old_size + shard_size) or (str(ignore_shard) in fname)
for bucket, items in data['hash_data'].items():
self.hash_data[bucket].extend(items)
if not self.whiten:
    for bucket, items in data['hash_data'].items():
        self.hash_data[bucket].extend(items)

if self.whiten:
    self.whiten_block_embeds()
args = get_args()
with open(args.hash_data_path, 'wb') as final_file:
@@ -100,8 +107,43 @@ class HashedIndex(object):
def clear(self):
"""Clear the data structures to save memory"""
self.block_data = defaultdict(list)
self.block_data = dict()
self.hash_data = defaultdict(list)
def whiten_block_embeds(self):
"""Transform all block embeds to have zero mean and unit covariance
when treated as samples from a distribution"""
block_idx, all_embeds = zip(*self.block_data.items())
arr_embeds = np.transpose(np.array(all_embeds))
mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
centered = arr_embeds - mean
inv_cov = np.linalg.inv(np.cov(arr_embeds))
whitener = np.transpose(np.linalg.cholesky(inv_cov))
whitened = np.transpose(whitener.dot(centered))
self.embed_mean = mean.reshape(-1)
self.embed_whitener = whitener
self.block_data = dict(zip(block_idx, list(whitened)))
self.hash_data = defaultdict(list)
batch_size = 16384
i = 0
with torch.no_grad():
    hashing_tensor = torch.cuda.HalfTensor(self.hash_matrix)
    while True:
        batch_slice = slice(i * batch_size, (i + 1) * batch_size)
        batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
        batch_block_idx = block_idx[batch_slice]
        # an empty slice means every embed has been rehashed
        if batch_embed.numel() == 0:
            break
        hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
        embed_scores = torch.cat((hash_scores_pos, -hash_scores_pos), dim=1)
        embed_hashes = detach(torch.argmax(embed_scores, dim=1))
        for bucket, idx in zip(list(embed_hashes), batch_block_idx):
            # store [int] rather than [array<int>] since this is just for analysis right now
            self.hash_data[bucket].append(idx)
        i += 1
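To see why the Cholesky step above whitens: factoring the inverse covariance as inv(Σ) = L·Lᵀ and taking W = Lᵀ gives cov(W(x − μ)) = LᵀΣL = I. A self-contained NumPy check on synthetic data (not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
samples = rng.randn(10000, 8) @ rng.randn(8, 8)   # correlated 8-d embeds

arr = samples.T                                   # (embed_size, num_samples)
centered = arr - arr.mean(axis=1, keepdims=True)
whitener = np.transpose(np.linalg.cholesky(np.linalg.inv(np.cov(arr))))
whitened = np.transpose(whitener.dot(centered))

print(np.allclose(whitened.mean(axis=0), 0.0))     # zero mean
print(np.allclose(np.cov(whitened.T), np.eye(8)))  # identity covariance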
@classmethod
def load_from_file(cls, fname):
@@ -114,8 +156,26 @@ class HashedIndex(object):
new_index.block_data = state_dict['block_data']
new_index.hash_data = state_dict['hash_data']
new_index.hash_matrix = hash_matrix
return new_index
@classmethod
def whiten_and_rehash(cls, fname):
    """Load a HashedIndex from disk, whiten its block embeds and rehash them"""
    index = cls.load_from_file(fname)
    # whiten_block_embeds() already centers the embeds, whitens them and
    # rebuilds hash_data from scratch, so delegate to it rather than
    # duplicating the mean/covariance computation here.
    index.whiten_block_embeds()
    return index
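A possible invocation of this helper, with an illustrative path (the real pickle location comes from --hash-data-path):

index = HashedIndex.whiten_and_rehash('/path/to/ict_best.pkl')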
def test_retriever():
initialize_megatron(extra_args_provider=None,
@@ -163,10 +223,12 @@ def main():
model = load_ict_checkpoint(only_block_model=True, no_grad=True)
model.eval()
dataset = get_ict_dataset()
data_iter = iter(get_dataloader(dataset))
hashed_index = HashedIndex(embed_size=128, num_buckets=4096)
data_iter = iter(get_one_epoch_dataloader(dataset))
hashed_index = HashedIndex(embed_size=128, num_buckets=4096, whiten=True)
i = 0
i = 1
total = 0
whiten = hashed_index.whiten  # keep the local flag in sync with the index
while True:
try:
query_tokens, query_pad_mask, \
@@ -176,18 +238,25 @@
block_indices = detach(block_indices)
block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
hashed_index.hash_embeds(block_logits, block_indices)
hashed_index.assign_block_embeds(block_indices[:,3], detach(block_logits))
if i % 100 == 0:
print(i, flush=True)
# If whiten, then hashing needs to be done after whitening the block embeds
# which is done in consolidate_shards_and_save()
if not whiten:
    hashed_index.hash_embeds(block_logits, block_indices)
hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))
total += block_indices.size
i += 1
if i % 20 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if args.debug:
break
hashed_index.save_shard(args.rank)
torch.distributed.barrier()
del model
if mpu.get_data_parallel_rank() == 0:
if args.rank == 0:
hashed_index.consolidate_shards_and_save()
else:
hashed_index.clear()
@@ -247,7 +316,7 @@ def get_ict_dataset():
return dataset
def get_dataloader(dataset):
def get_one_epoch_dataloader(dataset):
args = get_args()
world_size = mpu.get_data_parallel_world_size()
......
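The body of get_one_epoch_dataloader is elided in this hunk; conceptually, a one-epoch dataloader wraps the dataset in a distributed sampler that makes exactly one pass per data-parallel rank instead of sampling indefinitely. A sketch under that assumption (names and arguments are illustrative, not the actual Megatron implementation):

import torch

def one_epoch_dataloader(dataset, batch_size, world_size, rank, num_workers=2):
    # Each data-parallel rank iterates once over a disjoint shard;
    # exhausting the sampler corresponds to one epoch over the data.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=False)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=True)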
@@ -184,6 +184,8 @@ def _add_training_args(parser):
def _add_initialization_args(parser):
group = parser.add_argument_group(title='initialization')
group.add_argument('--debug', action='store_true',
                   help='Run in debug mode')
group.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
......
@@ -46,9 +46,6 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
# The true REALM sequence length is twice as long, but none of the extra tokens are predicted with the LM loss
# loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)
train_sample = {
'tokens': tokens_np,
'labels': labels_np,
......
@@ -29,6 +29,7 @@ from megatron.utils import reduce_losses
num_batches = 0
def model_provider(only_query_model=False, only_block_model=False):
"""Build the model."""
args = get_args()
@@ -103,7 +104,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
'for BERT ICT...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
......
COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch python hashed_index.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--batch-size 8 \
--checkpoint-activations \
--seq-length 288 \
--max-position-embeddings 288 \
--train-iters 100000 \
--load /home/dcg-adlr-nkant-output.cosmos1203/chkpts/realm_debug \
--ict-load /home/dcg-adlr-nkant-output.cosmos1203/chkpts/ict_best \
--save /home/dcg-adlr-nkant-output.cosmos1203/chkpts/realm_debug \
--data-path /home/universal-lm-data.cosmos549/datasets/wiki-indexed/wikipedia_lines \
--titles-data-path /home/universal-lm-data.cosmos549/datasets/wiki-indexed/wikipedia_lines-titles \
--hash-data-path /home/dcg-adlr-nkant-data.cosmos1202/hash_data/ict_best.pkl \
--vocab-file /home/universal-lm-data.cosmos549/scratch/mshoeybi/data/albert/vocab.txt \
--split 58,1,1 \
--distributed-backend nccl \
--lr 0.0001 \
--num-workers 2 \
--lr-decay-style linear \
--warmup .01 \
--save-interval 3000 \
--fp16 \
--adlr-autoresume \
--adlr-autoresume-interval 100"
submit_job --image 'http://gitlab-master.nvidia.com/adlr/megatron-lm/megatron:20.03' --mounts /home/universal-lm-data.cosmos549,/home/dcg-adlr-nkant-source.cosmos1204,/home/dcg-adlr-nkant-data.cosmos1202,/home/dcg-adlr-nkant-output.cosmos1203,/home/nkant --name test_retriever --partition interactive --gpu 1 --nodes 1 --autoresume_timer 300 -c "${COMMAND}"