Commit bf599e86 authored by Neel Kant's avatar Neel Kant
Browse files

Correct retrieval utility and add salient span preprocessing

parent 91158c9b
...@@ -389,6 +389,10 @@ def _add_data_args(parser): ...@@ -389,6 +389,10 @@ def _add_data_args(parser):
group.add_argument('--query-in-block-prob', type=float, default=0.1, group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset') help='Probability of keeping query in block for ICT dataset')
group.add_argument('--faiss-use-gpu', action='store_true') group.add_argument('--faiss-use-gpu', action='store_true')
group.add_argument('--index-reload-interval', type=int, default=500)
group.add_argument('--use-regular-masking', action='store_true')
group.add_argument('--allow-trivial-doc', action='store_true')
group.add_argument('--ner-data-path', type=str, default=None)
return parser return parser
......
...@@ -417,7 +417,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -417,7 +417,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
dataset_type='standard_bert'): dataset_type='standard_bert'):
args = get_args()
if dataset_type not in DATASET_TYPES: if dataset_type not in DATASET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type) raise ValueError("Invalid dataset_type: ", dataset_type)
...@@ -427,7 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -427,7 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
skip_warmup) skip_warmup)
if dataset_type in ['ict', 'realm']: if dataset_type in ['ict', 'realm']:
title_dataset = get_indexed_dataset_(data_prefix + '-titles', title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl, data_impl,
skip_warmup) skip_warmup)
...@@ -479,7 +479,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -479,7 +479,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
) )
if dataset_type == 'ict': if dataset_type == 'ict':
args = get_args()
dataset = ICTDataset( dataset = ICTDataset(
block_dataset=indexed_dataset, block_dataset=indexed_dataset,
title_dataset=title_dataset, title_dataset=title_dataset,
...@@ -487,6 +486,11 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -487,6 +486,11 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
**kwargs **kwargs
) )
elif dataset_type == 'realm': elif dataset_type == 'realm':
if args.ner_data_path is not None:
ner_dataset = get_indexed_dataset_(args.ner_data_path,
data_impl,
skip_warmup)
kwargs.update({'ner_dataset': ner_dataset})
dataset = REALMDataset( dataset = REALMDataset(
block_dataset=indexed_dataset, block_dataset=indexed_dataset,
title_dataset=title_dataset, title_dataset=title_dataset,
......
...@@ -18,9 +18,9 @@ class REALMDataset(Dataset): ...@@ -18,9 +18,9 @@ class REALMDataset(Dataset):
Presumably Presumably
""" """
def __init__(self, name, block_dataset, title_dataset, data_prefix, def __init__(self, name, block_dataset, title_dataset,
num_epochs, max_num_samples, masked_lm_prob, data_prefix, num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed): max_seq_length, short_seq_prob, seed, ner_dataset=None):
self.name = name self.name = name
self.seed = seed self.seed = seed
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
...@@ -29,6 +29,7 @@ class REALMDataset(Dataset): ...@@ -29,6 +29,7 @@ class REALMDataset(Dataset):
self.title_dataset = title_dataset self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed) self.rng = random.Random(self.seed)
self.ner_dataset = ner_dataset
self.samples_mapping = get_block_samples_mapping( self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs, block_dataset, title_dataset, data_prefix, num_epochs,
...@@ -48,7 +49,14 @@ class REALMDataset(Dataset): ...@@ -48,7 +49,14 @@ class REALMDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx] start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)] block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
# print([len(list(self.block_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
assert len(block) > 1 assert len(block) > 1
block_ner_mask = None
if self.ner_dataset is not None:
block_ner_mask = [list(self.ner_dataset[i]) for i in range(start_idx, end_idx)]
# print([len(list(self.ner_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
np_rng = np.random.RandomState(seed=(self.seed + idx)) np_rng = np.random.RandomState(seed=(self.seed + idx))
sample = build_realm_training_sample(block, sample = build_realm_training_sample(block,
...@@ -60,6 +68,7 @@ class REALMDataset(Dataset): ...@@ -60,6 +68,7 @@ class REALMDataset(Dataset):
self.mask_id, self.mask_id,
self.pad_id, self.pad_id,
self.masked_lm_prob, self.masked_lm_prob,
block_ner_mask,
np_rng) np_rng)
sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)}) sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
return sample return sample
......
...@@ -8,7 +8,7 @@ import spacy ...@@ -8,7 +8,7 @@ import spacy
import torch import torch
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_tokenizer, print_rank_0, mpu from megatron import get_args, get_tokenizer, print_rank_0, mpu
SPACY_NER = spacy.load('en_core_web_lg') SPACY_NER = spacy.load('en_core_web_lg')
...@@ -16,19 +16,30 @@ SPACY_NER = spacy.load('en_core_web_lg') ...@@ -16,19 +16,30 @@ SPACY_NER = spacy.load('en_core_web_lg')
def build_realm_training_sample(sample, max_seq_length, def build_realm_training_sample(sample, max_seq_length,
vocab_id_list, vocab_id_to_token_dict, vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id, cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng): masked_lm_prob, block_ner_mask, np_rng):
tokens = list(itertools.chain(*sample))[:max_seq_length - 2] tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id) tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
try: args = get_args()
masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id) if args.use_regular_masking:
except TypeError:
# this means the above returned None, and None isn't iterable.
# TODO: consider coding style.
max_predictions_per_seq = masked_lm_prob * max_seq_length max_predictions_per_seq = masked_lm_prob * max_seq_length
masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions( masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob, tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng) cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
elif block_ner_mask is not None:
block_ner_mask = list(itertools.chain(*block_ner_mask))[:max_seq_length - 2]
block_ner_mask = [0] + block_ner_mask + [0]
masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
else:
try:
masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
except TypeError:
# this means the above returned None, and None isn't iterable.
# TODO: consider coding style.
max_predictions_per_seq = masked_lm_prob * max_seq_length
masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \ tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions, = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
...@@ -43,6 +54,28 @@ def build_realm_training_sample(sample, max_seq_length, ...@@ -43,6 +54,28 @@ def build_realm_training_sample(sample, max_seq_length,
return train_sample return train_sample
def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
tokenizer = get_tokenizer()
tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
masked_tokens = tokens.copy()
masked_positions = []
masked_labels = []
for i in range(len(tokens)):
if block_ner_mask[i] == 1:
masked_positions.append(i)
masked_labels.append(tokens[i])
masked_tokens[i] = mask_id
# print("-" * 100 + '\n',
# "TOKEN STR\n", tokens_str + '\n',
# "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)), flush=True)
return masked_tokens, masked_positions, masked_labels
def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id): def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
tokens = [] tokens = []
tokens.append(cls_id) tokens.append(cls_id)
...@@ -119,10 +152,10 @@ def salient_span_mask(tokens, mask_id): ...@@ -119,10 +152,10 @@ def salient_span_mask(tokens, mask_id):
for id_idx in masked_positions: for id_idx in masked_positions:
labels.append(tokens[id_idx]) labels.append(tokens[id_idx])
output_tokens[id_idx] = mask_id output_tokens[id_idx] = mask_id
#print("-" * 100 + '\n', # print("-" * 100 + '\n',
# "TOKEN STR\n", tokens_str + '\n', # "TOKEN STR\n", tokens_str + '\n',
# "SELECTED ENTITY\n", selected_entity.text + '\n', # "SELECTED ENTITY\n", selected_entity.text + '\n',
# "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True) # "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)
return output_tokens, masked_positions, labels return output_tokens, masked_positions, labels
......
...@@ -16,9 +16,11 @@ def detach(tensor): ...@@ -16,9 +16,11 @@ def detach(tensor):
class BlockData(object): class BlockData(object):
def __init__(self): def __init__(self):
args = get_args()
self.embed_data = dict() self.embed_data = dict()
self.meta_data = dict() self.meta_data = dict()
self.temp_dir_name = 'temp_block_data' block_data_path = os.path.splitext(args.block_data_path)[0]
self.temp_dir_name = block_data_path + '_tmp'
def state(self): def state(self):
return { return {
...@@ -150,12 +152,12 @@ class FaissMIPSIndex(object): ...@@ -150,12 +152,12 @@ class FaissMIPSIndex(object):
for j in range(block_indices.shape[1]): for j in range(block_indices.shape[1]):
fresh_indices[i, j] = self.id_map[block_indices[i, j]] fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices block_indices = fresh_indices
args = get_args() # args = get_args()
if args.rank == 0: # if args.rank == 0:
torch.save({'query_embeds': query_embeds, # torch.save({'query_embeds': query_embeds,
'id_map': self.id_map, # 'id_map': self.id_map,
'block_indices': block_indices, # 'block_indices': block_indices,
'distances': distances}, 'search.data') # 'distances': distances}, 'search.data')
return distances, block_indices return distances, block_indices
# functions below are for ALSH, which currently isn't being used # functions below are for ALSH, which currently isn't being used
......
...@@ -114,8 +114,15 @@ class REALMBertModel(MegatronModule): ...@@ -114,8 +114,15 @@ class REALMBertModel(MegatronModule):
# [batch_size x k x seq_length] # [batch_size x k x seq_length]
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True) args = get_args()
if args.allow_trivial_doc:
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
tokens, attention_mask, query_block_indices=None, include_null_doc=True)
else:
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
# print("Top k block shape: ", topk_block_tokens.shape, flush=True) # print("Top k block shape: ", topk_block_tokens.shape, flush=True)
batch_size = tokens.shape[0] batch_size = tokens.shape[0]
...@@ -130,15 +137,16 @@ class REALMBertModel(MegatronModule): ...@@ -130,15 +137,16 @@ class REALMBertModel(MegatronModule):
# [batch_size x k x embed_size] # [batch_size x k x embed_size]
true_model = self.retriever.ict_model.module.module true_model = self.retriever.ict_model.module.module
fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask) fresh_block_logits = true_model.embed_block(topk_block_tokens, topk_block_attention_mask)
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1).float() fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1).float()
# print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True) # print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)
# [batch_size x 1 x embed_size] # [batch_size x 1 x embed_size]
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1).float() query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(1).float()
# [batch_size x k] # [batch_size x k]
fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze() fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
# fresh_block_scores = fresh_block_scores / np.sqrt(query_logits.shape[2])
block_probs = F.softmax(fresh_block_scores, dim=1) block_probs = F.softmax(fresh_block_scores, dim=1)
# [batch_size * k x seq_length] # [batch_size * k x seq_length]
...@@ -163,7 +171,7 @@ class REALMBertModel(MegatronModule): ...@@ -163,7 +171,7 @@ class REALMBertModel(MegatronModule):
# block body ends after the second SEP # block body ends after the second SEP
block_ends = block_sep_indices[:, 1, 1] + 1 block_ends = block_sep_indices[:, 1, 1] + 1
print('-' * 100) # print('-' * 100)
for row_num in range(all_tokens.shape[0]): for row_num in range(all_tokens.shape[0]):
q_len = query_lengths[row_num] q_len = query_lengths[row_num]
b_start = block_starts[row_num] b_start = block_starts[row_num]
...@@ -176,24 +184,24 @@ class REALMBertModel(MegatronModule): ...@@ -176,24 +184,24 @@ class REALMBertModel(MegatronModule):
all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end] all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True) # print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
all_attention_mask[row_num, :new_tokens_length] = 1 all_attention_mask[row_num, :new_tokens_length] = 1
all_attention_mask[row_num, new_tokens_length:] = 0 all_attention_mask[row_num, new_tokens_length:] = 0
print('-' * 100) # print('-' * 100)
args = get_args() # args = get_args()
if args.rank == 0: # if args.rank == 0:
torch.save({'lm_tokens': all_tokens, # torch.save({'lm_tokens': all_tokens,
'lm_attn_mask': all_attention_mask, # 'lm_attn_mask': all_attention_mask,
'query_tokens': tokens, # 'query_tokens': tokens,
'query_attn_mask': attention_mask, # 'query_attn_mask': attention_mask,
'query_logits': query_logits, # 'query_logits': query_logits,
'block_tokens': topk_block_tokens, # 'block_tokens': topk_block_tokens,
'block_attn_mask': topk_block_attention_mask, # 'block_attn_mask': topk_block_attention_mask,
'block_logits': fresh_block_logits, # 'block_logits': fresh_block_logits,
'block_probs': block_probs, # 'block_probs': block_probs,
}, 'final_lm_inputs.data') # }, 'final_lm_inputs.data')
# assert all(torch.equal(all_tokens[i], all_tokens[0]) for i in range(self.top_k)) # assert all(torch.equal(all_tokens[i], all_tokens[0]) for i in range(self.top_k))
# assert all(torch.equal(all_attention_mask[i], all_attention_mask[0]) for i in range(self.top_k)) # assert all(torch.equal(all_attention_mask[i], all_attention_mask[0]) for i in range(self.top_k))
......
...@@ -394,7 +394,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, ...@@ -394,7 +394,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True) recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration last_reload_iteration = iteration
while iteration < args.train_iters: while iteration < args.train_iters:
if args.max_training_rank is not None and iteration >= last_reload_iteration + 100: if args.max_training_rank is not None and iteration >= last_reload_iteration + args.index_reload_interval:
if recv_handle.is_completed(): if recv_handle.is_completed():
# should add check that INDEX_READY == 1 but what else could be happening # should add check that INDEX_READY == 1 but what else could be happening
true_model = model true_model = model
......
...@@ -101,7 +101,7 @@ def forward_step(data_iterator, model): ...@@ -101,7 +101,7 @@ def forward_step(data_iterator, model):
# print('labels shape: ', labels.shape, flush=True) # print('labels shape: ', labels.shape, flush=True)
with torch.no_grad(): with torch.no_grad():
max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint( max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, tokens_over_batch = mpu.checkpoint(
get_retrieval_utility, lm_logits, block_probs, labels, loss_mask) get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)
# P(y|x) = sum_z(P(y|z, x) * P(z|x)) # P(y|x) = sum_z(P(y|z, x) * P(z|x))
...@@ -118,7 +118,7 @@ def forward_step(data_iterator, model): ...@@ -118,7 +118,7 @@ def forward_step(data_iterator, model):
# 'tokens': tokens.cpu(), # 'tokens': tokens.cpu(),
# 'pad_mask': pad_mask.cpu(), # 'pad_mask': pad_mask.cpu(),
# }, 'tensors.data') # }, 'tensors.data')
# torch.load('gagaga')
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(relevant_logits) block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(relevant_logits)
# print(torch.sum(block_probs, dim=1), flush=True) # print(torch.sum(block_probs, dim=1), flush=True)
...@@ -131,58 +131,59 @@ def forward_step(data_iterator, model): ...@@ -131,58 +131,59 @@ def forward_step(data_iterator, model):
l_probs = torch.log(marginalized_probs) l_probs = torch.log(marginalized_probs)
return l_probs return l_probs
log_probs = mpu.checkpoint(get_log_probs, relevant_logits, block_probs)
def get_loss(l_probs, labs): def get_loss(l_probs, labs):
vocab_size = l_probs.shape[2] vocab_size = l_probs.shape[2]
loss = torch.nn.NLLLoss(ignore_index=-1)(l_probs.reshape(-1, vocab_size), labs.reshape(-1)) loss = torch.nn.NLLLoss(ignore_index=-1)(l_probs.reshape(-1, vocab_size), labs.reshape(-1))
# loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() # loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
return loss.float() return loss.float()
lm_loss = mpu.checkpoint(get_loss, log_probs, labels) lm_loss = get_loss(get_log_probs(relevant_logits, block_probs), labels)
reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, null_block_probs, tokens_over_batch])
# marginalized_logits = torch.sum(relevant_logits * block_probs, dim=1)
# vocab_size = marginalized_logits.shape[2]
# lm_loss_ = torch.nn.CrossEntropyLoss()(marginalized_logits.reshape(-1, vocab_size), labels.reshape(-1))
# lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, null_block_probs])
# reduced_loss = reduce_losses([lm_loss]) # reduced_loss = reduce_losses([lm_loss])
# torch.cuda.synchronize() # torch.cuda.synchronize()
return lm_loss, {'lm_loss': reduced_loss[0], return lm_loss, {'lm_loss': reduced_loss[0],
'max_ru': reduced_loss[1], 'max_ru': reduced_loss[1],
'top_ru': reduced_loss[2], 'top_ru': reduced_loss[2],
'avg_ru': reduced_loss[3], 'avg_ru': reduced_loss[3],
'null_prob': reduced_loss[4]} 'null_prob': reduced_loss[4],
'mask/batch': reduced_loss[5]}
def get_retrieval_utility(lm_logits_, block_probs, labels, loss_mask): def get_retrieval_utility(lm_logits_, block_probs, labels, loss_mask):
"""log P(y | z, x) - log P(y | null, x)""" """log P(y | z, x) - log P(y | null, x)"""
# [batch x seq_len x vocab_size]
# [batch x top_k x seq_len x vocab_size]
lm_logits = lm_logits_[:, :, :labels.shape[1], :] lm_logits = lm_logits_[:, :, :labels.shape[1], :]
#non_null_block_probs = block_probs[:, :-1] batch_size, top_k = lm_logits.shape[0], lm_logits.shape[1]
#non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
# non_null_block_probs = non_null_block_probsexpand_as(lm_logits[:, :-1, :, :]) # non_null_block_probs = block_probs[:, :-1]
# non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
# non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])
null_block_lm_logits = lm_logits[:, -1, :, :] null_block_lm_logits = lm_logits[:, -1, :, :]
null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(), null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
labels.contiguous()) labels.contiguous())
null_block_loss = torch.sum( null_block_loss = torch.sum(null_block_loss_.view(-1) * loss_mask.reshape(-1)) / batch_size
null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
retrieved_block_losses = [] retrieved_block_losses = []
for block_num in range(lm_logits.shape[1] - 1):
for block_num in range(top_k - 1):
retrieved_block_lm_logits = lm_logits[:, block_num, :, :] retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(), retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
labels.contiguous()) labels.contiguous())
#retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
retrieved_block_loss = torch.sum( # retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() retrieved_block_loss = torch.sum(retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / batch_size
retrieved_block_losses.append(retrieved_block_loss) retrieved_block_losses.append(retrieved_block_loss)
avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1) avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (top_k - 1)
max_retrieval_utility = null_block_loss - min(retrieved_block_losses) max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
top_retrieval_utility = null_block_loss - retrieved_block_losses[0] top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss
return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility
tokens_over_batch = loss_mask.sum().float() / batch_size
return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, tokens_over_batch
def qa_forward_step(data_iterator, model): def qa_forward_step(data_iterator, model):
......
...@@ -24,6 +24,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ...@@ -24,6 +24,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir))) os.path.pardir)))
import time import time
import numpy as np
import torch import torch
try: try:
import nltk import nltk
...@@ -31,8 +32,11 @@ try: ...@@ -31,8 +32,11 @@ try:
except ImportError: except ImportError:
nltk_available = False nltk_available = False
from megatron.tokenizer import build_tokenizer from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset from megatron.data import indexed_dataset
from megatron.data.realm_dataset_utils import id_to_str_pos_map
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
...@@ -75,6 +79,14 @@ class Encoder(object): ...@@ -75,6 +79,14 @@ class Encoder(object):
else: else:
Encoder.splitter = IdentitySplitter() Encoder.splitter = IdentitySplitter()
try:
import spacy
print("> Loading spacy")
Encoder.spacy = spacy.load('en_core_web_lg')
print(">> Finished loading spacy")
except:
Encoder.spacy = None
def encode(self, json_line): def encode(self, json_line):
data = json.loads(json_line) data = json.loads(json_line)
ids = {} ids = {}
...@@ -90,6 +102,56 @@ class Encoder(object): ...@@ -90,6 +102,56 @@ class Encoder(object):
ids[key] = doc_ids ids[key] = doc_ids
return ids, len(json_line) return ids, len(json_line)
def encode_with_ner(self, json_line):
if self.spacy is None:
raise ValueError('Cannot do NER without spacy')
data = json.loads(json_line)
ids = {}
ner_masks = {}
for key in self.args.json_keys:
text = data[key]
doc_ids = []
doc_ner_mask = []
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
# sentence is cased?
# print(sentence)
entities = self.spacy(sentence).ents
undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
# entities = []
masked_positions = []
if len(entities) > 0:
entity_idx = np.random.randint(0, len(entities))
selected_entity = entities[entity_idx]
token_pos_map = id_to_str_pos_map(sentence_ids, Encoder.tokenizer)
mask_start = mask_end = 0
set_mask_start = False
while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
if token_pos_map[mask_start] > selected_entity.start_char:
set_mask_start = True
if not set_mask_start:
mask_start += 1
mask_end += 1
masked_positions = list(range(mask_start - 1, mask_end))
ner_mask = [0] * len(sentence_ids)
for pos in masked_positions:
ner_mask[pos] = 1
doc_ner_mask.append(ner_mask)
if self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
doc_ner_mask[-1].append(0)
ids[key] = doc_ids
ner_masks[key + '-ner'] = doc_ner_mask
return ids, ner_masks, len(json_line)
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data') group = parser.add_argument_group(title='input data')
...@@ -126,6 +188,8 @@ def get_args(): ...@@ -126,6 +188,8 @@ def get_args():
help='Number of worker processes to launch') help='Number of worker processes to launch')
group.add_argument('--log-interval', type=int, default=100, group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates') help='Interval between progress updates')
group.add_argument('--create-ner-masks', action='store_true',
help='Also create mask tensors for salient span masking')
args = parser.parse_args() args = parser.parse_args()
args.keep_empty = False args.keep_empty = False
...@@ -153,8 +217,11 @@ def main(): ...@@ -153,8 +217,11 @@ def main():
encoder = Encoder(args) encoder = Encoder(args)
tokenizer = build_tokenizer(args) tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 25) if args.create_ner_masks:
#encoded_docs = map(encoder.encode, fin) encoded_docs = pool.imap(encoder.encode_with_ner, fin, 25)
else:
encoded_docs = pool.imap(encoder.encode, fin, 25)
#encoded_docs = map(encoder.encode, fin)
level = "document" level = "document"
if args.split_sentences: if args.split_sentences:
...@@ -165,7 +232,10 @@ def main(): ...@@ -165,7 +232,10 @@ def main():
output_bin_files = {} output_bin_files = {}
output_idx_files = {} output_idx_files = {}
builders = {} builders = {}
for key in args.json_keys: output_keys = args.json_keys.copy()
if args.create_ner_masks:
output_keys.extend([key + '-ner' for key in output_keys])
for key in output_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level) key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
...@@ -179,12 +249,24 @@ def main(): ...@@ -179,12 +249,24 @@ def main():
total_bytes_processed = 0 total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start) print("Time to startup:", startup_end - startup_start)
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): # for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
for i, doc_data in enumerate(encoded_docs, start=1):
if args.create_ner_masks:
doc, ner_masks, bytes_processed = doc_data
else:
doc, bytes_processed = doc_data
total_bytes_processed += bytes_processed total_bytes_processed += bytes_processed
for key, sentences in doc.items(): for key, sentences in doc.items():
for sentence in sentences: for sentence in sentences:
builders[key].add_item(torch.IntTensor(sentence)) builders[key].add_item(torch.IntTensor(sentence))
builders[key].end_document() builders[key].end_document()
if args.create_ner_masks:
for key, sentence_masks in ner_masks.items():
for mask in sentence_masks:
builders[key].add_item(torch.IntTensor(mask))
builders[key].end_document()
if i % args.log_interval == 0: if i % args.log_interval == 0:
current = time.time() current = time.time()
elapsed = current - proc_start elapsed = current - proc_start
...@@ -193,7 +275,7 @@ def main(): ...@@ -193,7 +275,7 @@ def main():
f"({i/elapsed} docs/s, {mbs} MB/s).", f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr) file=sys.stderr)
for key in args.json_keys: for key in output_keys:
builders[key].finalize(output_idx_files[key]) builders[key].finalize(output_idx_files[key])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment