Commit 0f0f60aa authored by Neel Kant

Able to run REALM with terrible index sync

parent 2f7d666c
......@@ -93,8 +93,6 @@ def salient_span_mask(tokens, mask_id):
Note: Tokens here are vocab ids and not text tokens."""
tokenizer = get_tokenizer()
tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
print("-" * 100)
print("TOKEN STR\n", tokens_str)
# need to get all named entities
entities = SPACY_NER(tokens_str).ents
......@@ -103,7 +101,6 @@ def salient_span_mask(tokens, mask_id):
return None
entity_idx = np.random.randint(0, len(entities))
selected_entity = entities[entity_idx]
print("SELECTED ENTITY\n", selected_entity.text)
token_pos_map = id_to_str_pos_map(tokens, tokenizer)
mask_start = mask_end = 0
......@@ -114,14 +111,17 @@ def salient_span_mask(tokens, mask_id):
if not set_mask_start:
mask_start += 1
mask_end += 1
masked_positions = list(range(mask_start, mask_end + 1))
masked_positions = list(range(mask_start - 1, mask_end))
labels = []
output_tokens = tokens.copy()
for id_idx in masked_positions:
labels.append(tokens[id_idx])
output_tokens[id_idx] = mask_id
print("OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)))
#print("-" * 100 + '\n',
# "TOKEN STR\n", tokens_str + '\n',
# "SELECTED ENTITY\n", selected_entity.text + '\n',
# "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)
return output_tokens, masked_positions, labels
......
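For context on the hunk above: salient span masking picks one named entity in the passage and masks all of its wordpieces, so the model has to lean on retrieved evidence to recover it. Below is a minimal sketch of that idea, assuming spaCy for NER and a HuggingFace-style BERT tokenizer; the names here (mask_salient_span, nlp, etc.) are illustrative and do not match the repo's helpers such as join_str_list or id_to_str_pos_map.

# Minimal sketch of salient span masking, assuming spaCy for NER and a
# BERT-style wordpiece tokenizer. Names are illustrative, not the repo's API.
import numpy as np
import spacy
from transformers import BertTokenizer

nlp = spacy.load("en_core_web_sm")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def mask_salient_span(token_ids, mask_id):
    """Mask every wordpiece belonging to one randomly chosen named entity."""
    text = tokenizer.decode(token_ids)
    entities = nlp(text).ents
    if not entities:
        return None
    entity = entities[np.random.randint(len(entities))]
    # Re-tokenize the entity surface form to find its wordpiece ids.
    entity_piece_ids = tokenizer.encode(entity.text, add_special_tokens=False)
    output = list(token_ids)
    masked_positions, labels = [], []
    # Naive search for the entity's wordpiece span in the original sequence.
    for start in range(len(output) - len(entity_piece_ids) + 1):
        if output[start:start + len(entity_piece_ids)] == entity_piece_ids:
            for pos in range(start, start + len(entity_piece_ids)):
                masked_positions.append(pos)
                labels.append(output[pos])
                output[pos] = mask_id
            break
    return output, masked_positions, labels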
......@@ -91,7 +91,7 @@ class FaissMIPSIndex(object):
self._set_block_index()
def _set_block_index(self):
INDEX_TYPES = ['flat_l2', 'flat_ip']
INDEX_TYPES = ['flat_ip']
if self.index_type not in INDEX_TYPES:
raise ValueError("Invalid index type specified")
......@@ -123,14 +123,17 @@ class FaissMIPSIndex(object):
"""
if self.index_type == 'flat_l2':
query_embeds = self.alsh_query_preprocess_fn(query_embeds)
query_embeds = np.float32(query_embeds)
if reconstruct:
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds.astype('float32'), top_k)
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
return top_k_block_embeds
else:
distances, block_indices = self.block_mips_index.search(query_embeds.astype('float32'), top_k)
distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
return distances, block_indices
# functions below are for ALSH, which currently isn't being used
def get_norm_powers_and_halves_array(self, embeds):
norm = np.linalg.norm(embeds, axis=1)
norm_powers = [np.multiply(norm, norm)] # squared L2 norms of all
......
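The index hunks above restrict FaissMIPSIndex to the exact inner-product index ('flat_ip') and cast query embeddings to float32 once before searching. A minimal sketch of that search path using the faiss package directly, not the repo's wrapper class:

# Minimal sketch of a flat inner-product (MIPS) FAISS index.
import faiss
import numpy as np

dim = 128
block_embeds = np.random.rand(10000, dim).astype('float32')
index = faiss.IndexFlatIP(dim)      # exact maximum inner product search
index.add(block_embeds)             # FAISS requires float32 inputs

query_embeds = np.float32(np.random.rand(4, dim))
distances, block_indices = index.search(query_embeds, 5)   # top-5 blocks per query
# search_and_reconstruct additionally returns the stored block embeddings
distances, block_indices, top_k_block_embeds = index.search_and_reconstruct(query_embeds, 5)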
......@@ -89,7 +89,7 @@ def forward_step(data_iterator, model):
# Forward model.
lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
with torch.no_grad():
retrieval_utility = get_retrieval_utility(lm_logits, labels, loss_mask)
retrieval_utility = get_retrieval_utility(lm_logits, block_probs, labels, loss_mask)
# P(y|x) = sum_z(P(y|z, x) * P(z|x))
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
......@@ -105,9 +105,13 @@ def forward_step(data_iterator, model):
return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}
def get_retrieval_utility(lm_logits, labels, loss_mask):
def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
"""log P(y | z, x) - log P(y | null, x)"""
# [batch x seq_len x vocab_size]
lm_logits = lm_logits[:, :, :labels.shape[1], :]
#non_null_block_probs = block_probs[:, :-1]
#non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
# non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])
null_block_lm_logits = lm_logits[:, -1, :, :]
null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
labels.contiguous())
......@@ -119,10 +123,11 @@ def get_retrieval_utility(lm_logits, labels, loss_mask):
retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
labels.contiguous())
#retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
retrieved_block_loss = torch.sum(
retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
retrieved_block_losses.append(retrieved_block_loss)
avg_retrieved_block_loss = torch.sum(retrieved_block_losses) / (lm_logits.shape[1] - 1)
avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)
retrieval_utility = null_block_loss - avg_retrieved_block_loss
return retrieval_utility
......@@ -171,6 +176,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
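On the retrieval utility in the last file: it measures log P(y | z, x) - log P(y | null, x), i.e. how much the retrieved blocks lower the masked-LM loss relative to the null block. A minimal sketch in plain PyTorch, using F.cross_entropy in place of mpu.vocab_parallel_cross_entropy and assuming lm_logits has shape [batch, num_blocks, seq_len, vocab] with the null block last:

# Minimal sketch of the retrieval utility; not the repo's mpu-parallel implementation.
import torch
import torch.nn.functional as F

def retrieval_utility_sketch(lm_logits, labels, loss_mask):
    """Approximate log P(y | z, x) - log P(y | null, x), averaged over retrieved blocks."""
    batch, num_blocks, seq_len, vocab = lm_logits.shape

    def masked_loss(logits):
        # Token-level cross entropy, averaged over non-masked positions.
        loss = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1), reduction='none')
        return torch.sum(loss * loss_mask.reshape(-1)) / loss_mask.sum()

    null_block_loss = masked_loss(lm_logits[:, -1])
    retrieved_losses = [masked_loss(lm_logits[:, b]) for b in range(num_blocks - 1)]
    avg_retrieved_loss = torch.stack(retrieved_losses).mean()
    # Positive utility means the retrieved blocks help predict the masked tokens.
    return null_block_loss - avg_retrieved_loss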