"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6a5f06488c0d88c1827c016835cd5f64abe4b52c"
Commit 0f0f60aa authored by Neel Kant's avatar Neel Kant
Browse files

Able to run REALM with terrible index sync

parent 2f7d666c
@@ -93,8 +93,6 @@ def salient_span_mask(tokens, mask_id):
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
     tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
-    print("-" * 100)
-    print("TOKEN STR\n", tokens_str)
     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents

@@ -103,7 +101,6 @@ def salient_span_mask(tokens, mask_id):
         return None
     entity_idx = np.random.randint(0, len(entities))
     selected_entity = entities[entity_idx]
-    print("SELECTED ENTITY\n", selected_entity.text)
     token_pos_map = id_to_str_pos_map(tokens, tokenizer)
     mask_start = mask_end = 0

@@ -114,14 +111,17 @@ def salient_span_mask(tokens, mask_id):
             if not set_mask_start:
                 mask_start += 1
         mask_end += 1

-    masked_positions = list(range(mask_start, mask_end + 1))
+    masked_positions = list(range(mask_start - 1, mask_end))
     labels = []
     output_tokens = tokens.copy()
     for id_idx in masked_positions:
         labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id

-    print("OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)))
+    #print("-" * 100 + '\n',
+    #      "TOKEN STR\n", tokens_str + '\n',
+    #      "SELECTED ENTITY\n", selected_entity.text + '\n',
+    #      "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)

     return output_tokens, masked_positions, labels
...
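For reference, a self-contained sketch of the salient-span-masking idea the hunks above implement, written over plain word tokens rather than vocab ids (toy code, not the repo's salient_span_mask or its helpers; assumes the en_core_web_sm spaCy model is installed, standing in for the module-level SPACY_NER pipeline):

import numpy as np
import spacy

nlp = spacy.load("en_core_web_sm")  # assumed model, analogous to SPACY_NER

def toy_salient_span_mask(words, mask_token="[MASK]"):
    """Mask every word that overlaps one randomly chosen named entity."""
    text = " ".join(words)
    entities = nlp(text).ents
    if len(entities) == 0:
        return None
    entity = entities[np.random.randint(0, len(entities))]
    output, masked_positions, labels = list(words), [], []
    char_start = 0
    for i, word in enumerate(words):
        char_end = char_start + len(word)
        # a word is masked if its character span overlaps the entity's span
        if char_start < entity.end_char and char_end > entity.start_char:
            masked_positions.append(i)
            labels.append(word)
            output[i] = mask_token
        char_start = char_end + 1  # +1 for the joining space
    return output, masked_positions, labels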
@@ -91,7 +91,7 @@ class FaissMIPSIndex(object):
         self._set_block_index()

     def _set_block_index(self):
-        INDEX_TYPES = ['flat_l2', 'flat_ip']
+        INDEX_TYPES = ['flat_ip']
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")

@@ -123,14 +123,17 @@ class FaissMIPSIndex(object):
         """
         if self.index_type == 'flat_l2':
             query_embeds = self.alsh_query_preprocess_fn(query_embeds)
+        query_embeds = np.float32(query_embeds)

         if reconstruct:
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds.astype('float32'), top_k)
+            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
             return top_k_block_embeds
         else:
-            distances, block_indices = self.block_mips_index.search(query_embeds.astype('float32'), top_k)
+            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
             return distances, block_indices

+    # functions below are for ALSH, which currently isn't being used
     def get_norm_powers_and_halves_array(self, embeds):
         norm = np.linalg.norm(embeds, axis=1)
         norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
...
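For context on the FaissMIPSIndex changes above, a minimal sketch of the underlying FAISS usage pattern (standard faiss Python API; embed_dim, the corpus size, and the top-k value are illustrative), showing why arrays are cast to float32 before add/search:

import faiss
import numpy as np

embed_dim = 128                           # illustrative embedding width
index = faiss.IndexFlatIP(embed_dim)      # exact inner-product ('flat_ip') index

block_embeds = np.float32(np.random.rand(1000, embed_dim))
index.add(block_embeds)                   # FAISS only accepts float32 arrays

query_embeds = np.float32(np.random.rand(4, embed_dim))
distances, block_indices = index.search(query_embeds, 5)            # top-5 ids per query
_, _, block_vecs = index.search_and_reconstruct(query_embeds, 5)     # also return the stored embeddings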
@@ -89,7 +89,7 @@ def forward_step(data_iterator, model):
     # Forward model.
     lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
     with torch.no_grad():
-        retrieval_utility = get_retrieval_utility(lm_logits, labels, loss_mask)
+        retrieval_utility = get_retrieval_utility(lm_logits, block_probs, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)

@@ -105,9 +105,13 @@ def forward_step(data_iterator, model):
     return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}


-def get_retrieval_utility(lm_logits, labels, loss_mask):
+def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
     """log P(y | z, x) - log P(y | null, x)"""
     # [batch x seq_len x vocab_size]
+    lm_logits = lm_logits[:, :, :labels.shape[1], :]
+    #non_null_block_probs = block_probs[:, :-1]
+    #non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
+    #non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])
     null_block_lm_logits = lm_logits[:, -1, :, :]
     null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
                                                         labels.contiguous())

@@ -119,10 +123,11 @@ def get_retrieval_utility(lm_logits, labels, loss_mask):
         retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
         retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
                                                                  labels.contiguous())
+        #retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
         retrieved_block_loss = torch.sum(
             retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
         retrieved_block_losses.append(retrieved_block_loss)

-    avg_retrieved_block_loss = torch.sum(retrieved_block_losses) / (lm_logits.shape[1] - 1)
+    avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)
     retrieval_utility = null_block_loss - avg_retrieved_block_loss
     return retrieval_utility

@@ -171,6 +176,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

 if __name__ == "__main__":
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
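For reference, the retrieval utility logged by forward_step above is, in the notation of the docstring (z a retrieved evidence block, \varnothing the null block):

    \mathrm{RU}(z \mid x, y) = \log p(y \mid z, x) - \log p(y \mid \varnothing, x)

averaged over the non-null retrieved blocks, while the marginal used for the LM loss is the sum noted in the code comment:

    p(y \mid x) = \sum_{z} p(y \mid z, x)\, p(z \mid x)

In the code the utility appears as null_block_loss - avg_retrieved_block_loss, since both terms are negative log-likelihoods.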