Commit e7045139 authored by Neel Kant

Add stanza NER salient span masking

parent b8bb0b49
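
This change feeds a cased copy of each block through REALMDataset so that stanza's CoNLL-03 NER model (with spaCy supplying DATE spans) can find entity character spans on the cased text; those spans are then mapped back onto the uncased wordpieces and masked. A minimal sketch of that span-to-wordpiece step follows. It is illustrative only: char_offsets stands in for what id_to_str_pos_map computes in the real code, and the helper shown here is not part of this commit.

# Illustrative sketch only, not part of this commit: mark the wordpieces whose
# character offsets fall inside a stanza CoNLL-03 entity span.
import stanza

# Assumes stanza and its English CoNLL-03 model are already downloaded,
# as pretrain() now does on rank 0 in this diff.
pipeline = stanza.Pipeline('en', processors={'tokenize': 'default',
                                             'mwt': 'default',
                                             'ner': 'conll03'},
                           use_gpu=False)

def ner_mask_from_char_offsets(cased_text, char_offsets):
    """Return a 0/1 mask over wordpieces; char_offsets[i] is the character
    position of wordpiece i in cased_text (the role id_to_str_pos_map plays
    in realm_dataset_utils.py)."""
    mask = [0] * len(char_offsets)
    for ent in pipeline(cased_text).ents:
        for i, pos in enumerate(char_offsets):
            if ent.start_char <= pos < ent.end_char:
                mask[i] = 1
    return mask

The committed get_stanza_ner_mask additionally caps masking at roughly 12% of each sentence's tokens and considers at most three randomly chosen entities per sentence.
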
@@ -394,6 +394,10 @@ def _add_data_args(parser):
     group.add_argument('--use-random-spans', action='store_true')
     group.add_argument('--allow-trivial-doc', action='store_true')
     group.add_argument('--ner-data-path', type=str, default=None)
+    group.add_argument('--cased-data-path', type=str, default=None,
+                       help='path to cased data to use for NER salient span masking')
+    group.add_argument('--cased-vocab', type=str, default=None,
+                       help='path to cased vocab file to use for NER salient span masking')
     return parser
......
@@ -387,7 +387,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     num_tokens = len(tokens)
     padding_length = max_seq_length - num_tokens
     assert padding_length >= 0
-    assert len(tokentypes) == num_tokens
+    assert len(tokentypes) == num_tokens, (len(tokentypes), num_tokens)
     assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))

     # Tokens and token types.
@@ -491,6 +491,12 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                            data_impl,
                                            skip_warmup)
         kwargs.update({'ner_dataset': ner_dataset})
+    elif args.cased_data_path is not None:
+        cased_dataset = get_indexed_dataset_(args.cased_data_path,
+                                             data_impl,
+                                             skip_warmup)
+        kwargs.update({'cased_block_dataset': cased_dataset,
+                       'cased_vocab': args.cased_vocab})
     dataset = REALMDataset(
         block_dataset=indexed_dataset,
         title_dataset=title_dataset,
......
@@ -20,7 +20,7 @@ class REALMDataset(Dataset):
     """
     def __init__(self, name, block_dataset, title_dataset,
                  data_prefix, num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed, ner_dataset=None):
+                 max_seq_length, short_seq_prob, seed, ner_dataset=None, cased_block_dataset=None, cased_vocab=None):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
@@ -29,7 +29,13 @@
         self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)
         self.ner_dataset = ner_dataset
+        self.cased_block_dataset = cased_block_dataset
+        self.cased_tokenizer = None
+        if self.cased_block_dataset is not None:
+            from megatron.tokenizer.tokenizer import BertWordPieceTokenizer
+            self.cased_tokenizer = BertWordPieceTokenizer(vocab_file=cased_vocab, lower_case=False)
         self.samples_mapping = get_block_samples_mapping(
             block_dataset, title_dataset, data_prefix, num_epochs,
@@ -49,7 +55,6 @@
     def __getitem__(self, idx):
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
         block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
-        # print([len(list(self.block_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
         assert len(block) > 1
         block_ner_mask = None
@@ -57,6 +62,10 @@
             block_ner_mask = [list(self.ner_dataset[i]) for i in range(start_idx, end_idx)]
             # print([len(list(self.ner_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
+        cased_tokens = None
+        if self.cased_block_dataset is not None:
+            cased_tokens = [list(self.cased_block_dataset[i]) for i in range(start_idx, end_idx)]
         np_rng = np.random.RandomState(seed=(self.seed + idx))
         sample = build_realm_training_sample(block,
@@ -69,6 +78,8 @@
                                              self.pad_id,
                                              self.masked_lm_prob,
                                              block_ner_mask,
+                                             cased_tokens,
+                                             self.cased_tokenizer,
                                              np_rng)
         sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
         return sample
......
@@ -6,6 +6,12 @@ import time
 import numpy as np
 import spacy
 import torch
+try:
+    import stanza
+    processors_dict = {'tokenize': 'default', 'mwt': 'default', 'ner': 'conll03'}
+    stanza_pipeline = stanza.Pipeline('en', processors=processors_dict, use_gpu=True)
+except:
+    pass
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
@@ -16,7 +22,8 @@ SPACY_NER = spacy.load('en_core_web_lg')
 def build_realm_training_sample(sample, max_seq_length,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 cls_id, sep_id, mask_id, pad_id,
-                                masked_lm_prob, block_ner_mask, np_rng):
+                                masked_lm_prob, block_ner_mask, cased_tokens,
+                                cased_tokenizer, np_rng):
     tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
     tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
@@ -35,8 +42,20 @@ def build_realm_training_sample(sample, max_seq_length,
         masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
     else:
         try:
-            masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
-        except TypeError:
+            if args.cased_data_path is not None:
+                total_len = sum(len(l) for l in sample)
+                # truncate the last sentence to make it so that the whole thing has length max_seq_length - 2
+                if total_len > max_seq_length - 2:
+                    offset = -(total_len - (max_seq_length - 2))
+                    sample[-1] = sample[-1][:offset]
+                masked_tokens, masked_positions, masked_labels = get_stanza_ner_mask(sample, cased_tokens, cased_tokenizer,
+                                                                                     cls_id, sep_id, mask_id)
+            else:
+                masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
+        except:
+            # print("+" * 100, flush=True)
+            # print('could not create salient span', flush=True)
+            # print("+" * 100, flush=True)
             # this means the above returned None, and None isn't iterable.
             # TODO: consider coding style.
             max_predictions_per_seq = masked_lm_prob * max_seq_length
@@ -57,6 +76,67 @@ def build_realm_training_sample(sample, max_seq_length,
     return train_sample


+def get_stanza_ner_mask(tokens, cased_tokens, cased_tokenizer, cls_id, sep_id, mask_id):
+    """Use stanza to generate NER salient span masks in the loop"""
+    # assuming that the default tokenizer is uncased.
+    uncased_tokenizer = get_tokenizer()
+    block_ner_mask = []
+    for cased_sent_ids, uncased_sent_ids in zip(cased_tokens, tokens):
+        # print('>')
+        token_pos_map = id_to_str_pos_map(uncased_sent_ids, uncased_tokenizer)
+
+        # get the cased string and do NER with both toolkits
+        cased_sent_str = join_str_list(cased_tokenizer.tokenizer.convert_ids_to_tokens(cased_sent_ids))
+        entities = stanza_pipeline(cased_sent_str).ents
+        spacy_entities = SPACY_NER(cased_sent_str).ents
+
+        # CoNLL doesn't do dates, so we scan with spacy to get the dates.
+        entities = [e for e in entities if e.text != 'CLS']
+        entities.extend([e for e in spacy_entities if (e.text != 'CLS' and e.label_ == 'DATE')])
+
+        # randomize which entities to look at, and set a target of 12% of tokens being masked
+        entity_indices = np.arange(len(entities))
+        np.random.shuffle(entity_indices)
+        target_num_masks = int(len(cased_sent_ids) * 0.12)
+
+        masked_positions = []
+        for entity_idx in entity_indices[:3]:
+            # if we have enough masks then break.
+            if len(masked_positions) > target_num_masks:
+                break
+            selected_entity = entities[entity_idx]
+            # print(">> selected entity: {}".format(selected_entity.text), flush=True)
+
+            mask_start = mask_end = 0
+            set_mask_start = False
+            # loop for checking where mask should start and end.
+            while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
+                if token_pos_map[mask_start] > selected_entity.start_char:
+                    set_mask_start = True
+                if not set_mask_start:
+                    mask_start += 1
+                mask_end += 1
+
+            # add offset to indices since our input was list of sentences
+            masked_positions.extend(range(mask_start - 1, mask_end))
+
+        ner_mask = [0] * len(uncased_sent_ids)
+        for pos in masked_positions:
+            ner_mask[pos] = 1
+        block_ner_mask.extend(ner_mask)
+
+    # len_tokens = [len(l) for l in tokens]
+    # print(len_tokens, flush=True)
+    # print([sum(len_tokens[:i + 1]) for i in range(len(tokens))], flush=True)
+    tokens = list(itertools.chain(*tokens))
+    tokens = [cls_id] + tokens + [sep_id]
+    block_ner_mask = [0] + block_ner_mask + [0]
+    return get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
+
+
 def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
     tokenizer = get_tokenizer()
     tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
@@ -65,16 +145,17 @@ def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
     masked_positions = []
     masked_labels = []
     for i in range(len(tokens)):
         if block_ner_mask[i] == 1:
             masked_positions.append(i)
             masked_labels.append(tokens[i])
             masked_tokens[i] = mask_id
-    # print("-" * 100 + '\n',
-    #       "TOKEN STR\n", tokens_str + '\n',
-    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)), flush=True)
+    # print("\nTOKEN STR\n", tokens_str + '\n',
+    #       "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)) + '\n',
+    #       "FRAC_MASKED: {}\n".format(len(masked_labels) / len(tokens)),
+    #       "-" * 100 + '\n',
+    #       flush=True)
     return masked_tokens, masked_positions, masked_labels
......
@@ -31,11 +31,11 @@ def build_tokenizer(args):
     # Select and instantiate the tokenizer.
     assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                            lower_case=True)
+        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                           lower_case=True)
     elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                            lower_case=False)
+        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                           lower_case=False)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                   'tokenizer'.format(self.name))


-class _BertWordPieceTokenizer(AbstractTokenizer):
+class BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""

     def __init__(self, vocab_file, lower_case=True):
......
@@ -87,6 +87,11 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()
+
+    if args.rank == 0 and args.cased_data_path is not None:
+        import stanza
+        stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
+
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
......