Commit c150c48b authored by Neel Kant

Debug and run salient span masking

parent 451d1325
@@ -15,6 +15,7 @@ from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider
+from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready

 # TODO re: main()
@@ -115,45 +116,6 @@ def main():
     set_model_com_file_not_ready()

-INDEX_COM_FILE = 'ready.index'
-MODEL_COM_FILE = 'ready.model'
-
-
-def set_index_com_file_not_ready():
-    with open(INDEX_COM_FILE, 'w') as com_file:
-        com_file.write('0')
-
-
-def set_index_com_file_ready():
-    with open(INDEX_COM_FILE, 'w') as com_file:
-        com_file.write('1')
-
-
-def check_index_com_file_ready():
-    if not os.path.exists(INDEX_COM_FILE):
-        set_index_com_file_not_ready()
-    with open(INDEX_COM_FILE, 'r') as com_file:
-        return bool(com_file.readline())
-
-
-def set_model_com_file_not_ready():
-    with open(MODEL_COM_FILE, 'w') as com_file:
-        com_file.write('0')
-
-
-def set_model_com_file_ready():
-    with open(MODEL_COM_FILE, 'w') as com_file:
-        com_file.write('1')
-
-
-def check_model_com_file_ready():
-    if not os.path.exists(MODEL_COM_FILE):
-        set_index_com_file_not_ready()
-    with open(MODEL_COM_FILE, 'r') as com_file:
-        return bool(com_file.readline())
-
-
 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
     args = get_args()
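Review note: the readiness-file helpers deleted here reappear as imports from `indexer_utils` (see the import hunk above and the `megatron.training` hunk below). Two quirks visible in the deleted code appear to travel with them and may be worth a follow-up: `check_model_com_file_ready()` resets the *index* file when the *model* file is missing, and `bool(com_file.readline())` is true even for a file containing `'0'`, since any non-empty string is truthy. A minimal corrected sketch, assuming the same one-character sentinel protocol:

```python
import os

MODEL_COM_FILE = 'ready.model'

def check_model_com_file_ready():
    # initialize the *model* file (not the index file) if it is absent
    if not os.path.exists(MODEL_COM_FILE):
        with open(MODEL_COM_FILE, 'w') as com_file:
            com_file.write('0')
    # compare against the sentinel explicitly; bool('0') would be True
    with open(MODEL_COM_FILE, 'r') as com_file:
        return com_file.readline().strip() == '1'
```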
@@ -210,6 +172,7 @@ def get_ict_dataset(use_titles=True):
         max_seq_length=288,  # doesn't matter
         short_seq_prob=0.0001,  # doesn't matter
         seed=1,
+        query_in_block_prob=1,
         use_titles=use_titles
     )
     dataset = ICTDataset(**kwargs)
...
@@ -375,6 +375,7 @@ def create_masked_lm_predictions(tokens,
     for p in masked_lms:
         masked_lm_positions.append(p.index)
         masked_lm_labels.append(p.label)
+
     return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
@@ -387,7 +388,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     padding_length = max_seq_length - num_tokens
     assert padding_length >= 0
     assert len(tokentypes) == num_tokens
-    assert len(masked_positions) == len(masked_labels)
+    assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))

     # Tokens and token types.
     filler = [pad_id] * padding_length
...
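Review note: the pattern used here (and again in `id_to_str_pos_map` below) of passing a tuple as the assert message is a cheap debugging aid: on failure the `AssertionError` carries the actual values instead of nothing. A quick illustration with made-up lengths:

```python
masked_positions, masked_labels = [4, 7], [1917]
try:
    assert len(masked_positions) == len(masked_labels), \
        (len(masked_positions), len(masked_labels))
except AssertionError as err:
    print(err)  # -> (2, 1), the mismatching lengths
```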
@@ -25,14 +25,13 @@ def build_realm_training_sample(sample, max_seq_length,
     except TypeError:
         # this means the above returned None, and None isn't iterable.
         # TODO: consider coding style.
         print("No salient span found.", flush=True)
         max_predictions_per_seq = masked_lm_prob * max_seq_length
         masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
             tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
             cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
-        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+        = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
                                    masked_labels, pad_id, max_seq_length)

     train_sample = {
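Review note: this is the substantive fix in this hunk. Previously the *unmasked* `tokens` were packed into the sample, so the `[MASK]` ids produced by `salient_span_mask` never reached the model and the LM loss was computed over fully visible inputs. A hedged illustration, using example bert-base-uncased ids:

```python
tokens        = [101, 7592, 2088, 102]  # [CLS] hello world [SEP]
masked_tokens = [101,  103, 2088, 102]  # position 1 replaced by [MASK] (id 103)

# before: pad_and_convert_to_numpy(tokens, ...)        -> nothing hidden
# after:  pad_and_convert_to_numpy(masked_tokens, ...) -> span actually masked
```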
@@ -84,7 +83,7 @@ def id_to_str_pos_map(token_ids, tokenizer):
     # make sure total size is correct
     offset = -2 if token_strs[-1].startswith("##") else 0
     total_len = pos_map[-1] + len(token_strs[-1]) + offset
-    assert total_len == len(join_str_list(token_strs))
+    assert total_len == len(join_str_list(token_strs)) - 1, (total_len, len(join_str_list(token_strs)))

     return pos_map
@@ -93,25 +92,34 @@ def salient_span_mask(tokens, mask_id):
     """Creates the predictions for the masked LM objective.
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
-    tokens_str = join_str_list(tokenizer.tokenize(tokens))
+    tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))

     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents
+    entities = [e for e in entities if e.text != "CLS"]
     if len(entities) == 0:
         return None

-    entity_idx = np.random.randint(0, len(entities))
-    selected_entity = entities[entity_idx]
+    selected_entity = np.random.choice(entities)
     token_pos_map = id_to_str_pos_map(tokens, tokenizer)

-    mask_start = mask_end = token_pos_map.index(selected_entity.start_char)
+    mask_start = mask_end = 0
+    set_mask_start = False
     while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
+        if token_pos_map[mask_start] > selected_entity.start_char:
+            set_mask_start = True
+        if not set_mask_start:
+            mask_start += 1
         mask_end += 1
+    masked_positions = list(range(mask_start, mask_end + 1))

-    labels = tokens.copy()
+    labels = []
     output_tokens = tokens.copy()
-    for id_idx in range(mask_start, mask_end):
+    for id_idx in masked_positions:
+        labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id

-    return output_tokens, list(range(mask_start, mask_end)), labels
+    return output_tokens, masked_positions, labels


 def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
...
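Review note on `np.random.choice(entities)`: spaCy `Span` objects are themselves iterable, so NumPy may try to coerce the list into a 2-D or object array and either raise "a must be 1-dimensional" or select unexpectedly. If that surfaces, plain integer indexing (as in the removed lines) avoids the coercion; a minimal sketch:

```python
import numpy as np

def choose_entity(entities):
    # select a Span without letting NumPy convert the list to an ndarray
    return entities[np.random.randint(len(entities))]
```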
@@ -108,12 +108,8 @@ class FaissMIPSIndex(object):
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")

-        if self.index_type == 'flat_l2':
-            index = faiss.IndexFlatL2(self.embed_size + 2 * self.m)
-            self.block_mips_index = faiss.IndexIDMap(index)
-        elif self.index_type == 'flat_ip':
-            index = faiss.IndexFlatIP(self.embed_size)
-            self.block_mips_index = faiss.IndexIDMap(index)
+        index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+        self.block_mips_index = faiss.IndexIDMap(index)

     def reset_index(self):
         self._set_block_index()
@@ -126,7 +122,7 @@ class FaissMIPSIndex(object):
         if self.index_type == 'flat_l2':
             block_embeds = self.alsh_block_preprocess_fn(block_embeds)

-        self.block_mips_index.add_with_ids(np.array(block_embeds), np.array(block_indices))
+        self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
@@ -138,10 +134,10 @@ class FaissMIPSIndex(object):
             query_embeds = self.alsh_query_preprocess_fn(query_embeds)

         if reconstruct:
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds.astype('float32'), top_k)
             return top_k_block_embeds
         else:
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
+            distances, block_indices = self.block_mips_index.search(query_embeds.astype('float32'), top_k)
             return distances, block_indices

     def get_norm_powers_and_halves_array(self, embeds):
@@ -176,6 +172,8 @@ class FaissMIPSIndex(object):
         return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))

+
+# This was the original hashing scheme, not used anymore
 class RandProjectionLSHIndex(object):
     """Class for holding hashed data"""

     def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
...
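Review note: the changes in this file all serve the same constraint: FAISS CPU indexes operate on contiguous float32 arrays and will reject float64 input, hence the `np.float32(...)` / `.astype('float32')` casts around `add_with_ids` and `search`. A minimal standalone sketch of the consolidated inner-product index (`embed_size` and the array shapes are example values):

```python
import faiss
import numpy as np

embed_size = 128  # example dimensionality

# flat (exact) inner-product index, wrapped so we can attach our own ids
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
mips_index = faiss.IndexIDMap(index)

block_embeds = np.random.rand(1000, embed_size).astype('float32')
block_ids = np.arange(1000)
mips_index.add_with_ids(block_embeds, block_ids)

queries = np.random.rand(4, embed_size).astype('float32')
scores, ids = mips_index.search(queries, 5)  # top-5 blocks by inner product
```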
@@ -39,7 +39,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
-from indexer import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready
+from indexer_utils import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready


 def pretrain(train_valid_test_dataset_provider, model_provider,
...
@@ -14,7 +14,6 @@
 # limitations under the License.

 """Pretrain BERT for Inverse Cloze Task"""

 import torch
-import torch.nn.functional as F
...