Commit e7045139 authored by Neel Kant

Add stanza NER salient span masking

parent b8bb0b49
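For context, a minimal sketch of the idea this commit wires in, before the diff itself: run a cased NER model over the original text, then flag every wordpiece whose character span overlaps a detected entity so it can later be replaced with [MASK]. This is a hedged illustration, not the code in the diff; `wordpiece_char_spans` is a hypothetical stand-in for the character-offset bookkeeping the repo does with `id_to_str_pos_map`.

```python
# Sketch of NER-based salient span masking (illustrative, not the commit's code).
import numpy as np
import stanza

stanza.download('en', processors={'ner': 'conll03'})        # one-time model download
nlp = stanza.Pipeline('en', processors={'tokenize': 'default',
                                        'mwt': 'default',
                                        'ner': 'conll03'})

def ner_span_mask(text, wordpiece_char_spans, target_frac=0.12, seed=0):
    """Return a 0/1 mask over wordpieces, covering roughly target_frac of them."""
    rng = np.random.RandomState(seed)
    ents = list(nlp(text).ents)                              # stanza entity spans
    mask = [0] * len(wordpiece_char_spans)
    budget = int(target_frac * len(wordpiece_char_spans))
    for i in rng.permutation(len(ents)):
        if sum(mask) >= budget:
            break
        ent = ents[i]
        for j, (start, end) in enumerate(wordpiece_char_spans):
            # flag wordpieces whose characters overlap the entity's span
            if start < ent.end_char and end > ent.start_char:
                mask[j] = 1
    return mask
```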
...@@ -394,6 +394,10 @@ def _add_data_args(parser):
group.add_argument('--use-random-spans', action='store_true')
group.add_argument('--allow-trivial-doc', action='store_true')
group.add_argument('--ner-data-path', type=str, default=None)
group.add_argument('--cased-data-path', type=str, default=None,
help='path to cased data to use for NER salient span masking')
group.add_argument('--cased-vocab', type=str, default=None,
help='path to cased vocab file to use for NER salient span masking')
return parser
...
...@@ -387,7 +387,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
num_tokens = len(tokens)
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens, (len(tokentypes), num_tokens)
assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))
# Tokens and token types.
...@@ -491,6 +491,12 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
data_impl,
skip_warmup)
kwargs.update({'ner_dataset': ner_dataset})
elif args.cased_data_path is not None:
cased_dataset = get_indexed_dataset_(args.cased_data_path,
data_impl,
skip_warmup)
kwargs.update({'cased_block_dataset': cased_dataset,
'cased_vocab': args.cased_vocab})
dataset = REALMDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
...
...@@ -20,7 +20,7 @@ class REALMDataset(Dataset):
"""
def __init__(self, name, block_dataset, title_dataset,
data_prefix, num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed, ner_dataset=None, cased_block_dataset=None, cased_vocab=None):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
...@@ -29,7 +29,13 @@ class REALMDataset(Dataset):
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.ner_dataset = ner_dataset
self.cased_block_dataset = cased_block_dataset
self.cased_tokenizer = None
if self.cased_block_dataset is not None:
from megatron.tokenizer.tokenizer import BertWordPieceTokenizer
self.cased_tokenizer = BertWordPieceTokenizer(vocab_file=cased_vocab, lower_case=False)
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
...@@ -49,7 +55,6 @@ class REALMDataset(Dataset):
def __getitem__(self, idx):
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
# print([len(list(self.block_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
assert len(block) > 1
block_ner_mask = None
...@@ -57,6 +62,10 @@ class REALMDataset(Dataset):
block_ner_mask = [list(self.ner_dataset[i]) for i in range(start_idx, end_idx)]
# print([len(list(self.ner_dataset[i])) for i in range(start_idx, end_idx)], flush=True)
cased_tokens = None
if self.cased_block_dataset is not None:
cased_tokens = [list(self.cased_block_dataset[i]) for i in range(start_idx, end_idx)]
np_rng = np.random.RandomState(seed=(self.seed + idx))
sample = build_realm_training_sample(block,
...@@ -69,6 +78,8 @@ class REALMDataset(Dataset):
self.pad_id,
self.masked_lm_prob,
block_ner_mask,
cased_tokens,
self.cased_tokenizer,
np_rng)
sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
return sample
...
...@@ -6,6 +6,12 @@ import time
import numpy as np
import spacy
import torch
try:
import stanza
processors_dict = {'tokenize': 'default', 'mwt': 'default', 'ner': 'conll03'}
stanza_pipeline = stanza.Pipeline('en', processors=processors_dict, use_gpu=True)
except:
pass
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu
...@@ -16,7 +22,8 @@ SPACY_NER = spacy.load('en_core_web_lg')
def build_realm_training_sample(sample, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, block_ner_mask, cased_tokens,
cased_tokenizer, np_rng):
tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
...@@ -35,8 +42,20 @@ def build_realm_training_sample(sample, max_seq_length,
masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
else:
try:
if args.cased_data_path is not None:
total_len = sum(len(l) for l in sample)
# truncate the last sentence to make it so that the whole thing has length max_seq_length - 2
if total_len > max_seq_length - 2:
offset = -(total_len - (max_seq_length - 2))
sample[-1] = sample[-1][:offset]
masked_tokens, masked_positions, masked_labels = get_stanza_ner_mask(sample, cased_tokens, cased_tokenizer,
cls_id, sep_id, mask_id)
else:
masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
except:
# print("+" * 100, flush=True)
# print('could not create salient span', flush=True)
# print("+" * 100, flush=True)
# this means the above returned None, and None isn't iterable.
# TODO: consider coding style.
max_predictions_per_seq = masked_lm_prob * max_seq_length
...@@ -57,6 +76,67 @@ def build_realm_training_sample(sample, max_seq_length,
return train_sample
def get_stanza_ner_mask(tokens, cased_tokens, cased_tokenizer, cls_id, sep_id, mask_id):
"""Use stanza to generate NER salient span masks in the loop"""
# assuming that the default tokenizer is uncased.
uncased_tokenizer = get_tokenizer()
block_ner_mask = []
for cased_sent_ids, uncased_sent_ids in zip(cased_tokens, tokens):
# print('>')
token_pos_map = id_to_str_pos_map(uncased_sent_ids, uncased_tokenizer)
# get the cased string and do NER with both toolkits
cased_sent_str = join_str_list(cased_tokenizer.tokenizer.convert_ids_to_tokens(cased_sent_ids))
entities = stanza_pipeline(cased_sent_str).ents
spacy_entities = SPACY_NER(cased_sent_str).ents
# CoNLL doesn't do dates, so we scan with spacy to get the dates.
entities = [e for e in entities if e.text != 'CLS']
entities.extend([e for e in spacy_entities if (e.text != 'CLS' and e.label_ == 'DATE')])
# randomize which entities to look at, and set a target of 12% of tokens being masked
entity_indices = np.arange(len(entities))
np.random.shuffle(entity_indices)
target_num_masks = int(len(cased_sent_ids) * 0.12)
masked_positions = []
for entity_idx in entity_indices[:3]:
# if we have enough masks then break.
if len(masked_positions) > target_num_masks:
break
selected_entity = entities[entity_idx]
# print(">> selected entity: {}".format(selected_entity.text), flush=True)
mask_start = mask_end = 0
set_mask_start = False
# loop for checking where mask should start and end.
while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
if token_pos_map[mask_start] > selected_entity.start_char:
set_mask_start = True
if not set_mask_start:
mask_start += 1
mask_end += 1
# add offset to indices since our input was list of sentences
masked_positions.extend(range(mask_start - 1, mask_end))
ner_mask = [0] * len(uncased_sent_ids)
for pos in masked_positions:
ner_mask[pos] = 1
block_ner_mask.extend(ner_mask)
# len_tokens = [len(l) for l in tokens]
# print(len_tokens, flush=True)
# print([sum(len_tokens[:i + 1]) for i in range(len(tokens))], flush=True)
tokens = list(itertools.chain(*tokens))
tokens = [cls_id] + tokens + [sep_id]
block_ner_mask = [0] + block_ner_mask + [0]
return get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
tokenizer = get_tokenizer()
tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
...@@ -65,16 +145,17 @@ def get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id):
masked_positions = []
masked_labels = []
for i in range(len(tokens)):
if block_ner_mask[i] == 1:
masked_positions.append(i)
masked_labels.append(tokens[i])
masked_tokens[i] = mask_id
# print("-" * 100 + '\n', # print("\nTOKEN STR\n", tokens_str + '\n',
# "TOKEN STR\n", tokens_str + '\n', # "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)) + '\n',
# "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(masked_tokens)), flush=True) # "FRAC_MASKED: {}\n".format(len(masked_labels) / len(tokens)),
# "-" * 100 + '\n',
# flush=True)
return masked_tokens, masked_positions, masked_labels
...
...@@ -31,11 +31,11 @@ def build_tokenizer(args):
# Select and instantiate the tokenizer.
assert args.vocab_file is not None
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True)
elif args.tokenizer_type == 'BertWordPieceCase':
tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=False)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
...@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
'tokenizer'.format(self.name))
class BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True):
...
...@@ -87,6 +87,11 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
args = get_args()
timers = get_timers()
if args.rank == 0 and args.cased_data_path is not None:
import stanza
stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
# Model, optimizer, and learning rate.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
...
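One implicit requirement of the diff above: `__getitem__` indexes the new cased dataset with the same `start_idx`/`end_idx` as the uncased block dataset, so the two preprocessed corpora must be sentence-aligned. Below is a hedged sketch of an offline check one could run, assuming both indexed datasets and both wordpiece tokenizers load the same way they do in the diff; `check_alignment` is an illustrative name, not part of the commit.

```python
def check_alignment(uncased_ds, cased_ds, uncased_tok, cased_tok, num_samples=1000):
    """Report sentences whose cased/uncased detokenizations disagree (sketch)."""
    total = min(num_samples, len(uncased_ds))
    mismatches = 0
    for i in range(total):
        u = uncased_tok.tokenizer.convert_ids_to_tokens(list(uncased_ds[i]))
        c = cased_tok.tokenizer.convert_ids_to_tokens(list(cased_ds[i]))
        # Wordpiece segmentation differs between the two vocabs, but the detokenized
        # text should agree up to case, accent stripping, and [UNK] replacements.
        if "".join(u).replace("##", "").lower() != "".join(c).replace("##", "").lower():
            mismatches += 1
    print("{} / {} sentences disagree".format(mismatches, total))
```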