Commit c150c48b authored by Neel Kant

Debug and run salient span masking

parent 451d1325
@@ -15,6 +15,7 @@ from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider
+from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready
 # TODO re: main()
@@ -115,45 +116,6 @@ def main():
     set_model_com_file_not_ready()

-INDEX_COM_FILE = 'ready.index'
-MODEL_COM_FILE = 'ready.model'
-
-
-def set_index_com_file_not_ready():
-    with open(INDEX_COM_FILE, 'w') as com_file:
-        com_file.write('0')
-
-
-def set_index_com_file_ready():
-    with open(INDEX_COM_FILE, 'w') as com_file:
-        com_file.write('1')
-
-
-def check_index_com_file_ready():
-    if not os.path.exists(INDEX_COM_FILE):
-        set_index_com_file_not_ready()
-    with open(INDEX_COM_FILE, 'r') as com_file:
-        return bool(com_file.readline())
-
-
-def set_model_com_file_not_ready():
-    with open(MODEL_COM_FILE, 'w') as com_file:
-        com_file.write('0')
-
-
-def set_model_com_file_ready():
-    with open(MODEL_COM_FILE, 'w') as com_file:
-        com_file.write('1')
-
-
-def check_model_com_file_ready():
-    if not os.path.exists(MODEL_COM_FILE):
-        set_index_com_file_not_ready()
-    with open(MODEL_COM_FILE, 'r') as com_file:
-        return bool(com_file.readline())
 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
     args = get_args()
@@ -210,6 +172,7 @@ def get_ict_dataset(use_titles=True):
         max_seq_length=288,  # doesn't matter
         short_seq_prob=0.0001,  # doesn't matter
         seed=1,
+        query_in_block_prob=1,
         use_titles=use_titles
     )
     dataset = ICTDataset(**kwargs)
...
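The helpers deleted above (and re-imported from `indexer_utils` in this commit) carry two latent bugs worth flagging: `bool(com_file.readline())` is truthy for both `'0'` and `'1'`, since any non-empty string is truthy, and `check_model_com_file_ready` falls back to `set_index_com_file_not_ready` instead of the model variant. A minimal corrected sketch of the same file-flag protocol (an assumption on my part, not necessarily what `indexer_utils.py` actually contains):

```python
import os

INDEX_COM_FILE = 'ready.index'
MODEL_COM_FILE = 'ready.model'


def _write_flag(path, value):
    with open(path, 'w') as com_file:
        com_file.write(value)


def _check_flag(path, set_not_ready):
    if not os.path.exists(path):
        set_not_ready()
    with open(path, 'r') as com_file:
        # compare against '1' explicitly: bool('0') is True
        return com_file.readline().strip() == '1'


def set_index_com_file_not_ready():
    _write_flag(INDEX_COM_FILE, '0')


def set_index_com_file_ready():
    _write_flag(INDEX_COM_FILE, '1')


def check_index_com_file_ready():
    return _check_flag(INDEX_COM_FILE, set_index_com_file_not_ready)


def set_model_com_file_not_ready():
    _write_flag(MODEL_COM_FILE, '0')


def set_model_com_file_ready():
    _write_flag(MODEL_COM_FILE, '1')


def check_model_com_file_ready():
    # note: the original fell back to set_index_com_file_not_ready here
    return _check_flag(MODEL_COM_FILE, set_model_com_file_not_ready)
```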
@@ -375,6 +375,7 @@ def create_masked_lm_predictions(tokens,
     for p in masked_lms:
         masked_lm_positions.append(p.index)
         masked_lm_labels.append(p.label)
     return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
@@ -387,7 +388,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     padding_length = max_seq_length - num_tokens
     assert padding_length >= 0
     assert len(tokentypes) == num_tokens
-    assert len(masked_positions) == len(masked_labels)
+    assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))

     # Tokens and token types.
     filler = [pad_id] * padding_length
...
@@ -25,14 +25,13 @@ def build_realm_training_sample(sample, max_seq_length,
     except TypeError:
         # this means the above returned None, and None isn't iterable.
         # TODO: consider coding style.
-        print("No salient span found.", flush=True)
         max_predictions_per_seq = masked_lm_prob * max_seq_length
         masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
             tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
             cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
-        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+        = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
                                    masked_labels, pad_id, max_seq_length)

     train_sample = {
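The substantive fix in this hunk is the `pad_and_convert_to_numpy` call: it previously padded the original `tokens`, so the `[MASK]` ids produced upstream never reached the model input. A toy illustration of the failure mode (the ids follow the usual BERT-uncased vocabulary but are only for show):

```python
tokens = [101, 7592, 2088, 102]   # [CLS] hello world [SEP]
mask_id = 103
masked_tokens = list(tokens)
masked_tokens[1] = mask_id        # mask "hello"

def pad(seq, pad_id=0, max_len=8):
    return seq + [pad_id] * (max_len - len(seq))

model_input = pad(masked_tokens)  # correct: the mask id is in the input
buggy_input = pad(tokens)         # old behavior: the mask is silently dropped
assert model_input != buggy_input
```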
@@ -84,7 +83,7 @@ def id_to_str_pos_map(token_ids, tokenizer):
     # make sure total size is correct
     offset = -2 if token_strs[-1].startswith("##") else 0
     total_len = pos_map[-1] + len(token_strs[-1]) + offset
-    assert total_len == len(join_str_list(token_strs))
+    assert total_len == len(join_str_list(token_strs)) - 1, (total_len, len(join_str_list(token_strs)))
     return pos_map
@@ -93,25 +92,34 @@ def salient_span_mask(tokens, mask_id):
     """Creates the predictions for the masked LM objective.
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
-    tokens_str = join_str_list(tokenizer.tokenize(tokens))
+    tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))

     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents
+    entities = [e for e in entities if e.text != "CLS"]
     if len(entities) == 0:
         return None
-    selected_entity = np.random.choice(entities)
+    entity_idx = np.random.randint(0, len(entities))
+    selected_entity = entities[entity_idx]

     token_pos_map = id_to_str_pos_map(tokens, tokenizer)
-    mask_start = mask_end = token_pos_map.index(selected_entity.start_char)
+    mask_start = mask_end = 0
+    set_mask_start = False
     while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
+        if token_pos_map[mask_start] > selected_entity.start_char:
+            set_mask_start = True
+        if not set_mask_start:
+            mask_start += 1
         mask_end += 1

-    labels = tokens.copy()
+    masked_positions = list(range(mask_start, mask_end + 1))
+    labels = []
     output_tokens = tokens.copy()
-    for id_idx in range(mask_start, mask_end):
+    for id_idx in masked_positions:
+        labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id

-    return output_tokens, list(range(mask_start, mask_end)), labels
+    return output_tokens, masked_positions, labels


 def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
...
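For reference, salient span masking as implemented above picks one named entity and masks every token it covers. A self-contained sketch of the same idea over whitespace tokens, assuming a stock spaCy pipeline (`en_core_web_sm`) in place of the repo's `SPACY_NER`/`join_str_list`/`id_to_str_pos_map` machinery:

```python
import numpy as np
import spacy

# assumption: a stock spaCy model stands in for the repo's SPACY_NER
nlp = spacy.load("en_core_web_sm")

def salient_span_mask_str(token_strs, mask_token="[MASK]"):
    """Mask every token covered by one randomly chosen named entity."""
    text = " ".join(token_strs)
    entities = [e for e in nlp(text).ents if e.text != "CLS"]
    if not entities:
        return None
    entity = entities[np.random.randint(0, len(entities))]

    # character offset at which each token starts in `text`
    offsets, pos = [], 0
    for tok in token_strs:
        offsets.append(pos)
        pos += len(tok) + 1  # +1 for the joining space

    masked_positions = [i for i, start in enumerate(offsets)
                        if entity.start_char <= start < entity.end_char]
    labels = [token_strs[i] for i in masked_positions]
    output = list(token_strs)
    for i in masked_positions:
        output[i] = mask_token
    return output, masked_positions, labels

# e.g. salient_span_mask_str("Obama was born in Hawaii".split())
# masks a whole entity span ("Obama" or "Hawaii"), never random tokens.
```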
@@ -108,12 +108,8 @@ class FaissMIPSIndex(object):
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")

-        if self.index_type == 'flat_l2':
-            index = faiss.IndexFlatL2(self.embed_size + 2 * self.m)
-            self.block_mips_index = faiss.IndexIDMap(index)
-        elif self.index_type == 'flat_ip':
-            index = faiss.IndexFlatIP(self.embed_size)
-            self.block_mips_index = faiss.IndexIDMap(index)
+        index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+        self.block_mips_index = faiss.IndexIDMap(index)

     def reset_index(self):
         self._set_block_index()
@@ -126,7 +122,7 @@ class FaissMIPSIndex(object):
         if self.index_type == 'flat_l2':
             block_embeds = self.alsh_block_preprocess_fn(block_embeds)

-        self.block_mips_index.add_with_ids(np.array(block_embeds), np.array(block_indices))
+        self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
@@ -138,10 +134,10 @@ class FaissMIPSIndex(object):
             query_embeds = self.alsh_query_preprocess_fn(query_embeds)

         if reconstruct:
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds.astype('float32'), top_k)
             return top_k_block_embeds
         else:
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
+            distances, block_indices = self.block_mips_index.search(query_embeds.astype('float32'), top_k)
             return distances, block_indices

     def get_norm_powers_and_halves_array(self, embeds):
@@ -176,6 +172,8 @@ class FaissMIPSIndex(object):
         return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))


+# This was the original hashing scheme, not used anymore
 class RandProjectionLSHIndex(object):
     """Class for holding hashed data"""

     def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
...
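The change above collapses the two index branches into a single flat inner-product ("MIPS") index; the `float32` casts in the add and search paths are needed because FAISS only accepts float32 vectors (and int64 ids). A minimal standalone sketch of the same pattern:

```python
import faiss
import numpy as np

embed_size = 128  # illustrative; the real value comes from the model config

# Flat inner-product index wrapped in an IDMap so blocks keep their own ids
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(index)

block_embeds = np.random.rand(1000, embed_size).astype('float32')
block_ids = np.arange(1000, dtype='int64')
index.add_with_ids(block_embeds, block_ids)

queries = np.random.rand(4, embed_size).astype('float32')
distances, ids = index.search(queries, 5)  # top-5 inner products per query
```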
@@ -39,7 +39,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
-from indexer import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready
+from indexer_utils import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready


 def pretrain(train_valid_test_dataset_provider, model_provider,
...
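The renamed `indexer_utils` module supplies the file-flag handshake between the trainer and the asynchronous indexer. A hedged sketch of how the trainer side might consume it (the polling loop is an assumption; only the helper names come from the import above):

```python
import time

from indexer_utils import (check_index_com_file_ready,
                           set_index_com_file_not_ready,
                           set_model_com_file_ready)

def wait_for_fresh_index(poll_secs=10):
    # Block until the indexer writes '1' into ready.index ...
    while not check_index_com_file_ready():
        time.sleep(poll_secs)
    # ... then consume the signal and tell the indexer that a new
    # model checkpoint is available for the next index rebuild.
    set_index_com_file_not_ready()
    set_model_com_file_ready()
```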
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Pretrain BERT for Inverse Cloze Task"""

 import torch
 import torch.nn.functional as F
...