Commit ca0cdfaa authored by Neel Kant's avatar Neel Kant
Browse files

Merge ict-merge into indexer-merge

parents d9d4ce70 a5bfc013
...@@ -20,7 +20,8 @@ import time ...@@ -20,7 +20,8 @@ import time
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import mpu from megatron import mpu
from tasks.finetune_utils import build_data_loader from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch from tasks.finetune_utils import process_batch
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
"""GLUE finetuning/evaluation.""" """GLUE finetuning/evaluation."""
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron.model.classification import Classification from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider from tasks.eval_utils import accuracy_func_provider
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
"""Race.""" """Race."""
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron.model.multiple_choice import MultipleChoice from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider from tasks.eval_utils import accuracy_func_provider
......
...@@ -21,7 +21,8 @@ import math ...@@ -21,7 +21,8 @@ import math
import numpy as np import numpy as np
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from .detokenizer import get_detokenizer from .detokenizer import get_detokenizer
......
...@@ -19,7 +19,8 @@ import math ...@@ -19,7 +19,8 @@ import math
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
......
...@@ -20,7 +20,8 @@ import sys ...@@ -20,7 +20,8 @@ import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir))) os.path.pardir)))
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron from megatron.initialize import initialize_megatron
......
...@@ -24,7 +24,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ...@@ -24,7 +24,6 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir))) os.path.pardir)))
import time import time
import numpy as np
import torch import torch
try: try:
import nltk import nltk
...@@ -32,11 +31,8 @@ try: ...@@ -32,11 +31,8 @@ try:
except ImportError: except ImportError:
nltk_available = False nltk_available = False
from megatron.tokenizer import build_tokenizer from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset from megatron.data import indexed_dataset
from megatron.data.realm_dataset_utils import id_to_str_pos_map
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
...@@ -79,14 +75,6 @@ class Encoder(object): ...@@ -79,14 +75,6 @@ class Encoder(object):
else: else:
Encoder.splitter = IdentitySplitter() Encoder.splitter = IdentitySplitter()
try:
import spacy
print("> Loading spacy")
Encoder.spacy = spacy.load('en_core_web_lg')
print(">> Finished loading spacy")
except:
Encoder.spacy = None
def encode(self, json_line): def encode(self, json_line):
data = json.loads(json_line) data = json.loads(json_line)
ids = {} ids = {}
...@@ -102,56 +90,6 @@ class Encoder(object): ...@@ -102,56 +90,6 @@ class Encoder(object):
ids[key] = doc_ids ids[key] = doc_ids
return ids, len(json_line) return ids, len(json_line)
def encode_with_ner(self, json_line):
if self.spacy is None:
raise ValueError('Cannot do NER without spacy')
data = json.loads(json_line)
ids = {}
ner_masks = {}
for key in self.args.json_keys:
text = data[key]
doc_ids = []
doc_ner_mask = []
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
# sentence is cased?
# print(sentence)
entities = self.spacy(sentence).ents
undesired_types = ['CARDINAL', 'TIME', 'PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL']
entities = [e for e in entities if e.text != "CLS" and e.label_ not in undesired_types]
# entities = []
masked_positions = []
if len(entities) > 0:
entity_idx = np.random.randint(0, len(entities))
selected_entity = entities[entity_idx]
token_pos_map = id_to_str_pos_map(sentence_ids, Encoder.tokenizer)
mask_start = mask_end = 0
set_mask_start = False
while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
if token_pos_map[mask_start] > selected_entity.start_char:
set_mask_start = True
if not set_mask_start:
mask_start += 1
mask_end += 1
masked_positions = list(range(mask_start - 1, mask_end))
ner_mask = [0] * len(sentence_ids)
for pos in masked_positions:
ner_mask[pos] = 1
doc_ner_mask.append(ner_mask)
if self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
doc_ner_mask[-1].append(0)
ids[key] = doc_ids
ner_masks[key + '-ner'] = doc_ner_mask
return ids, ner_masks, len(json_line)
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data') group = parser.add_argument_group(title='input data')
...@@ -188,8 +126,6 @@ def get_args(): ...@@ -188,8 +126,6 @@ def get_args():
help='Number of worker processes to launch') help='Number of worker processes to launch')
group.add_argument('--log-interval', type=int, default=100, group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates') help='Interval between progress updates')
group.add_argument('--create-ner-masks', action='store_true',
help='Also create mask tensors for salient span masking')
args = parser.parse_args() args = parser.parse_args()
args.keep_empty = False args.keep_empty = False
...@@ -217,9 +153,6 @@ def main(): ...@@ -217,9 +153,6 @@ def main():
encoder = Encoder(args) encoder = Encoder(args)
tokenizer = build_tokenizer(args) tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
if args.create_ner_masks:
encoded_docs = pool.imap(encoder.encode_with_ner, fin, 25)
else:
encoded_docs = pool.imap(encoder.encode, fin, 25) encoded_docs = pool.imap(encoder.encode, fin, 25)
#encoded_docs = map(encoder.encode, fin) #encoded_docs = map(encoder.encode, fin)
...@@ -232,10 +165,7 @@ def main(): ...@@ -232,10 +165,7 @@ def main():
output_bin_files = {} output_bin_files = {}
output_idx_files = {} output_idx_files = {}
builders = {} builders = {}
output_keys = args.json_keys.copy() for key in args.json_keys:
if args.create_ner_masks:
output_keys.extend([key + '-ner' for key in output_keys])
for key in output_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level) key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
...@@ -249,24 +179,12 @@ def main(): ...@@ -249,24 +179,12 @@ def main():
total_bytes_processed = 0 total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start) print("Time to startup:", startup_end - startup_start)
# for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1): for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
for i, doc_data in enumerate(encoded_docs, start=1):
if args.create_ner_masks:
doc, ner_masks, bytes_processed = doc_data
else:
doc, bytes_processed = doc_data
total_bytes_processed += bytes_processed total_bytes_processed += bytes_processed
for key, sentences in doc.items(): for key, sentences in doc.items():
for sentence in sentences: for sentence in sentences:
builders[key].add_item(torch.IntTensor(sentence)) builders[key].add_item(torch.IntTensor(sentence))
builders[key].end_document() builders[key].end_document()
if args.create_ner_masks:
for key, sentence_masks in ner_masks.items():
for mask in sentence_masks:
builders[key].add_item(torch.IntTensor(mask))
builders[key].end_document()
if i % args.log_interval == 0: if i % args.log_interval == 0:
current = time.time() current = time.time()
elapsed = current - proc_start elapsed = current - proc_start
...@@ -275,7 +193,7 @@ def main(): ...@@ -275,7 +193,7 @@ def main():
f"({i/elapsed} docs/s, {mbs} MB/s).", f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr) file=sys.stderr)
for key in output_keys: for key in args.json_keys:
builders[key].finalize(output_idx_files[key]) builders[key].finalize(output_idx_files[key])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment