Commit 22c0e300 authored by Jared Casper

Merge branch 'staging' into 'master'

April 2020 Release

See merge request ADLR/megatron-lm!69
parents 70174ae3 acf8780b
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""ALBERT Style dataset.""" """BERT Style dataset."""
import os import os
import time import time
...@@ -22,24 +22,18 @@ import numpy as np ...@@ -22,24 +22,18 @@ import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.data import helpers
from megatron.data import FullBertTokenizer
from megatron.data.dataset_utils import build_training_sample from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.utils import print_rank_0 from megatron import print_rank_0
def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl, def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
splits_string, train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup): short_seq_prob, seed, skip_warmup):
# Tokenizer is the same
tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
tokenizer.vocab_size()))
# Indexed dataset. # Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix, indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl, data_impl,
...@@ -53,6 +47,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl, ...@@ -53,6 +47,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
# Print stats about the splits. # Print stats about the splits.
print_rank_0(' > dataset split:') print_rank_0(' > dataset split:')
def print_split_stats(name, index): def print_split_stats(name, index):
print_rank_0(' {}:'.format(name)) print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} ' print_rank_0(' document indices in [{}, {}) total of {} '
...@@ -79,10 +74,9 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl, ...@@ -79,10 +74,9 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
# New doc_idx view. # New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly. # Build the dataset accordingly.
dataset = AlbertDataset( dataset = BertDataset(
name=name, name=name,
indexed_dataset=indexed_dataset, indexed_dataset=indexed_dataset,
tokenizer=tokenizer,
data_prefix=data_prefix, data_prefix=data_prefix,
num_epochs=None, num_epochs=None,
max_num_samples=train_valid_test_num_samples[index], max_num_samples=train_valid_test_num_samples[index],
...@@ -105,9 +99,9 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl, ...@@ -105,9 +99,9 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
return (train_dataset, valid_dataset, test_dataset) return (train_dataset, valid_dataset, test_dataset)
class AlbertDataset(Dataset): class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, tokenizer, data_prefix, def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob, num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed): max_seq_length, short_seq_prob, seed):
...@@ -117,11 +111,9 @@ class AlbertDataset(Dataset): ...@@ -117,11 +111,9 @@ class AlbertDataset(Dataset):
self.masked_lm_prob = masked_lm_prob self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
# Tokenizer and dataset. # Dataset.
self.tokenizer = tokenizer
self.indexed_dataset = indexed_dataset self.indexed_dataset = indexed_dataset
# Build the samples mapping. # Build the samples mapping.
self.samples_mapping = get_samples_mapping_(self.indexed_dataset, self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
data_prefix, data_prefix,
...@@ -133,22 +125,17 @@ class AlbertDataset(Dataset): ...@@ -133,22 +125,17 @@ class AlbertDataset(Dataset):
self.name) self.name)
# Vocab stuff. # Vocab stuff.
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) tokenizer = get_tokenizer()
self.vocab_id_to_token_dict = self.tokenizer.inv_vocab self.vocab_id_list = list(tokenizer.inv_vocab.keys())
self.cls_id = self.tokenizer.vocab['[CLS]'] self.vocab_id_to_token_dict = tokenizer.inv_vocab
self.sep_id = self.tokenizer.vocab['[SEP]'] self.cls_id = tokenizer.cls
self.mask_id = self.tokenizer.vocab['[MASK]'] self.sep_id = tokenizer.sep
self.pad_id = self.tokenizer.vocab['[PAD]'] self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
def num_tokens(self):
return self.tokenizer.vocab_size()
def __len__(self): def __len__(self):
return self.samples_mapping.shape[0] return self.samples_mapping.shape[0]
def __getitem__(self, idx): def __getitem__(self, idx):
start_index, end_index, seq_length = self.samples_mapping[idx] start_index, end_index, seq_length = self.samples_mapping[idx]
...@@ -159,7 +146,7 @@ class AlbertDataset(Dataset): ...@@ -159,7 +146,7 @@ class AlbertDataset(Dataset):
# python randint is inclusive whereas the numpy one is exclusive. # python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx)) np_rng = np.random.RandomState(seed=(self.seed + idx))
return build_training_sample(sample, seq_length, return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding self.max_seq_length, # needed for padding
self.vocab_id_list, self.vocab_id_list,
self.vocab_id_to_token_dict, self.vocab_id_to_token_dict,
self.cls_id, self.sep_id, self.cls_id, self.sep_id,
...@@ -203,7 +190,7 @@ def get_train_valid_test_split_(splits_string, size): ...@@ -203,7 +190,7 @@ def get_train_valid_test_split_(splits_string, size):
splits = splits[:3] splits = splits[:3]
splits_sum = sum(splits) splits_sum = sum(splits)
assert splits_sum > 0.0 assert splits_sum > 0.0
splits = [split/splits_sum for split in splits] splits = [split / splits_sum for split in splits]
splits_index = [0] splits_index = [0]
for index, split in enumerate(splits): for index, split in enumerate(splits):
splits_index.append(splits_index[index] + splits_index.append(splits_index[index] +
...@@ -259,12 +246,16 @@ def get_samples_mapping_(indexed_dataset, ...@@ -259,12 +246,16 @@ def get_samples_mapping_(indexed_dataset,
start_time = time.time() start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format( print_rank_0(' > building samples index mapping for {} ...'.format(
name)) name))
# First compile and then import.
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
samples_mapping = helpers.build_mapping( samples_mapping = helpers.build_mapping(
indexed_dataset.doc_idx, indexed_dataset.doc_idx,
indexed_dataset.sizes, indexed_dataset.sizes,
num_epochs, num_epochs,
max_num_samples, max_num_samples,
max_seq_length-3, # account for added tokens max_seq_length - 3, # account for added tokens
short_seq_prob, short_seq_prob,
seed, seed,
verbose) verbose)
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." % (actual_flag, init_checkpoint,
model_name, case_name, opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
def vocab_size(self):
return len(self.vocab)
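# A minimal usage sketch of FullTokenizer (illustrative only; "vocab.txt" is a
# hypothetical path to a BERT WordPiece vocabulary file, not shipped here):
#
#   tokenizer = FullTokenizer("vocab.txt", do_lower_case=True)
#   tokens = tokenizer.tokenize("Megatron pretrains BERT models.")
#   ids = tokenizer.convert_tokens_to_ids(tokens)
#   assert tokenizer.convert_ids_to_tokens(ids) == tokens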
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
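# Illustrative sketch of the greedy longest-match-first behaviour described
# above, using a toy vocabulary chosen only for this example:
#
#   toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
#   wp = WordpieceTokenizer(vocab=toy_vocab)
#   wp.tokenize("unaffable")   # -> ["un", "##aff", "##able"]
#   wp.tokenize("xyzzy")       # -> ["[UNK]"], since no pieces match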
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
...@@ -14,10 +14,27 @@ ...@@ -14,10 +14,27 @@
# limitations under the License. # limitations under the License.
# Most of the code here has been copied from:
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import collections import collections
import numpy as np import numpy as np
def compile_helper():
"""Compile helper function ar runtime. Make sure this
is invoked on a single process."""
import os
import subprocess
path = os.path.abspath(os.path.dirname(__file__))
ret = subprocess.run(['make', '-C', path])
if ret.returncode != 0:
print("Making C++ dataset helpers module failed, exiting.")
import sys
sys.exit(1)
def build_training_sample(sample, def build_training_sample(sample,
target_seq_length, max_seq_length, target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict, vocab_id_list, vocab_id_to_token_dict,
...@@ -132,6 +149,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): ...@@ -132,6 +149,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
tokens.pop() tokens.pop()
return True return True
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
"""Merge segments A and B, add [CLS] and [SEP] and build tokentypes.""" """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
...@@ -163,12 +181,12 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ...@@ -163,12 +181,12 @@ MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
def is_start_piece(piece): def is_start_piece(piece):
"""Check if the current word piece is the starting piece (BERT).""" """Check if the current word piece is the starting piece (BERT)."""
# When a word has been split into # When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence # WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we # tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes. # append it to the previous set of word indexes.
return not piece.startswith("##") return not piece.startswith("##")
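# For example (illustrative): "unaffable" is tokenized as
# ["un", "##aff", "##able"], so is_start_piece("un") is True while
# is_start_piece("##aff") is False; whole-word masking below uses this to
# group the ##-pieces with their starting piece.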
def create_masked_lm_predictions(tokens, def create_masked_lm_predictions(tokens,
...@@ -181,178 +199,178 @@ def create_masked_lm_predictions(tokens, ...@@ -181,178 +199,178 @@ def create_masked_lm_predictions(tokens,
do_whole_word_mask=True, do_whole_word_mask=True,
favor_longer_ngram=False, favor_longer_ngram=False,
do_permutation=False): do_permutation=False):
"""Creates the predictions for the masked LM objective. """Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens.""" Note: Tokens here are vocab ids and not text tokens."""
cand_indexes = [] cand_indexes = []
# Note(mingdachen): We create a list for recording if the piece is # Note(mingdachen): We create a list for recording if the piece is
# the starting piece of current token, where 1 means true, so that # the starting piece of current token, where 1 means true, so that
# on-the-fly whole word masking is possible. # on-the-fly whole word masking is possible.
token_boundary = [0] * len(tokens) token_boundary = [0] * len(tokens)
for (i, token) in enumerate(tokens): for (i, token) in enumerate(tokens):
if token == cls_id or token == sep_id: if token == cls_id or token == sep_id:
token_boundary[i] = 1 token_boundary[i] = 1
continue continue
# Whole Word Masking means that if we mask all of the wordpieces # Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. # corresponding to an original word.
# #
# Note that Whole Word Masking does *not* change the training code # Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed # at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary. # over the entire vocabulary.
if (do_whole_word_mask and len(cand_indexes) >= 1 and if (do_whole_word_mask and len(cand_indexes) >= 1 and
not is_start_piece(vocab_id_to_token_dict[token])): not is_start_piece(vocab_id_to_token_dict[token])):
cand_indexes[-1].append(i) cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if is_start_piece(vocab_id_to_token_dict[token]):
                token_boundary[i] = 1

    output_tokens = list(tokens)

    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    # Note(mingdachen):
    # By default, we set the probilities to favor shorter ngram sequences.
    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)

    if favor_longer_ngram:
        pvals = pvals[::-1]

    ngram_indexes = []
    for idx in range(len(cand_indexes)):
        ngram_index = []
        for n in ngrams:
            ngram_index.append(cand_indexes[idx:idx + n])
        ngram_indexes.append(ngram_index)

    np_rng.shuffle(ngram_indexes)

    masked_lms = []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        # Note(mingdachen):
        # Skip current piece if they are covered in lm masking or previous ngrams.
        for index_set in cand_index_set[0]:
            for index in index_set:
                if index in covered_indexes:
                    continue

        n = np_rng.choice(ngrams[:len(cand_index_set)],
                          p=pvals[:len(cand_index_set)] /
                          pvals[:len(cand_index_set)].sum(keepdims=True))
        index_set = sum(cand_index_set[n - 1], [])
        n -= 1
        # Note(mingdachen):
        # Repeatedly looking for a candidate that does not exceed the
        # maximum number of predictions by trying shorter ngrams.
        while len(masked_lms) + len(index_set) > num_to_predict:
            if n == 0:
                break
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1
        # If adding a whole-word mask would exceed the maximum number of
        # predictions, then just skip this candidate.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        is_any_index_covered = False
        for index in index_set:
            if index in covered_indexes:
                is_any_index_covered = True
                break
        if is_any_index_covered:
            continue
        for index in index_set:
            covered_indexes.add(index)

            masked_token = None
            # 80% of the time, replace with [MASK]
            if np_rng.random() < 0.8:
                masked_token = mask_id
            else:
                # 10% of the time, keep original
                if np_rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
    assert len(masked_lms) <= num_to_predict

    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Note(mingdachen):
            # Skip current piece if they are covered in lm masking or previous ngrams.
            for index_set in cand_index_set[0]:
                for index in index_set:
                    if index in covered_indexes or index in select_indexes:
                        continue

            n = np.random.choice(ngrams[:len(cand_index_set)],
                                 p=pvals[:len(cand_index_set)] /
                                 pvals[:len(cand_index_set)].sum(keepdims=True))
            index_set = sum(cand_index_set[n - 1], [])
            n -= 1

            while len(select_indexes) + len(index_set) > num_to_predict:
                if n == 0:
                    break
                index_set = sum(cand_index_set[n - 1], [])
                n -= 1
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes or index in select_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                select_indexes.add(index)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
...@@ -367,12 +385,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, ...@@ -367,12 +385,12 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
assert len(masked_positions) == len(masked_labels) assert len(masked_positions) == len(masked_labels)
# Tokens and token types. # Tokens and token types.
filler = [pad_id]*padding_length filler = [pad_id] * padding_length
tokens_np = np.array(tokens + filler, dtype=np.int64) tokens_np = np.array(tokens + filler, dtype=np.int64)
tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) tokentypes_np = np.array(tokentypes + filler, dtype=np.int64)
# Padding mask. # Padding mask.
padding_mask_np = np.array([1]*num_tokens + [0]*padding_length, padding_mask_np = np.array([1] * num_tokens + [0] * padding_length,
dtype=np.int64) dtype=np.int64)
# Labels and loss mask. # Labels and loss mask.
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 style dataset."""
import os
import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
from megatron.data.bert_dataset import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
seq_length, seed, skip_warmup):
"""Build train, valid, and test datasets."""
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
total_num_of_documents = indexed_dataset.sizes.shape[0]
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
documents = np.arange(start=splits[index], stop=splits[index + 1],
step=1, dtype=np.int32)
dataset = GPT2Dataset(name, data_prefix,
documents, indexed_dataset,
train_valid_test_num_samples[index],
seq_length, seed)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
"""Build indexed dataset."""
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' number of documents: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
class GPT2Dataset(torch.utils.data.Dataset):
def __init__(self, name, data_prefix, documents, indexed_dataset,
num_samples, seq_length, seed):
self.name = name
self.indexed_dataset = indexed_dataset
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]
# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
self.name, data_prefix, documents, self.indexed_dataset.sizes,
num_samples, seq_length, seed)
def __len__(self):
# -1 is due to the data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
return self.sample_idx.shape[0] - 1
def __getitem__(self, idx):
# Get the shuffled index.
idx = self.shuffle_idx[idx]
# Start and end documents and offsets.
doc_index_f = self.sample_idx[idx][0]
doc_index_l = self.sample_idx[idx + 1][0]
offset_f = self.sample_idx[idx][1]
offset_l = self.sample_idx[idx + 1][1]
# If we are within the same document, just extract the chunk.
if doc_index_f == doc_index_l:
sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f,
length=offset_l - offset_f + 1)
else:
# Otherwise, get the rest of the initial document.
sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],
offset=offset_f)]
# Loop over all in between documents and add the entire document.
for i in range(doc_index_f + 1, doc_index_l):
sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
# And finally add the relevant portion of last document.
sample_list.append(self.indexed_dataset.get(
self.doc_idx[doc_index_l],
length=offset_l + 1))
sample = np.concatenate(sample_list)
return {'text': np.array(sample, dtype=np.int64)}
def _build_index_mappings(name, data_prefix, documents, sizes,
num_samples, seq_length, seed):
"""Build doc-idx, sample-idx, and shuffle-idx.
doc-idx: is an array (ordered) of documents to be used in training.
sample-idx: is the start document index and document offset for each
training sample.
shuffle-idx: maps the sample index into a random index into sample-idx.
"""
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
# Filename of the index mappings.
_filename = data_prefix
_filename += '_{}_indexmap'.format(name)
_filename += '_{}ns'.format(num_samples)
_filename += '_{}sl'.format(seq_length)
_filename += '_{}s'.format(seed)
doc_idx_filename = _filename + '_doc_idx.npy'
sample_idx_filename = _filename + '_sample_idx.npy'
shuffle_idx_filename = _filename + '_shuffle_idx.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if (not os.path.isfile(doc_idx_filename)) or \
(not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):
print_rank_0(' > WARNING: could not find index map files, building '
'the indices on rank 0 ...')
# doc-idx.
start_time = time.time()
doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
np.save(doc_idx_filename, doc_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save doc-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# sample-idx.
start_time = time.time()
# Use C++ implementation for speed.
# First compile and then import.
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
assert doc_idx.dtype == np.int32
assert sizes.dtype == np.int32
sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch)
# sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,
# num_epochs, tokens_per_epoch)
np.save(sample_idx_filename, sample_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save sample-idx mapping '
'(seconds): {:4f}'.format(time.time() - start_time))
# shuffle-idx.
start_time = time.time()
# -1 is due to the data structure used to retrieve the index:
# sample i --> [sample_idx[i], sample_idx[i+1])
shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(' > elapsed time to build and save shuffle-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load mappings.
start_time = time.time()
print_rank_0(' > loading doc-idx mapping from {}'.format(
doc_idx_filename))
doc_idx = np.load(doc_idx_filename, allow_pickle=True)
print_rank_0(' > loading sample-idx mapping from {}'.format(
sample_idx_filename))
sample_idx = np.load(sample_idx_filename, allow_pickle=True)
print_rank_0(' > loading shuffle-idx mapping from {}'.format(
shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
sample_idx.shape[0]))
print_rank_0(' total number of epochs: {}'.format(num_epochs))
return doc_idx, sample_idx, shuffle_idx
def _num_tokens(documents, sizes):
"""Total number of tokens in the dataset."""
return np.sum(sizes[documents])
def _num_epochs(tokens_per_epoch, seq_length, num_samples):
"""Based on number of samples and sequence lenght, calculate how many
epochs will be needed."""
num_epochs = 0
total_tokens = 0
while True:
num_epochs += 1
total_tokens += tokens_per_epoch
# -1 is because we need to retrieve seq_length + 1 token each time
# but the last token will overlap with the first token of the next
# sample except for the last sample.
if ((total_tokens - 1) // seq_length) >= num_samples:
return num_epochs
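# Worked example of the -1 above (illustrative only): each sample consumes
# seq_length + 1 tokens but shares its last token with the next sample, so
# with tokens_per_epoch=1000 and seq_length=10 one epoch yields
# (1000 - 1) // 10 = 99 samples. Requesting num_samples=250 therefore needs
# 3 epochs, since (3 * 1000 - 1) // 10 = 299 >= 250 while two epochs give
# only (2 * 1000 - 1) // 10 = 199:
#
#   _num_epochs(tokens_per_epoch=1000, seq_length=10, num_samples=250)  # -> 3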
def _build_doc_idx(documents, num_epochs, np_rng):
"""Build an array with length = number-of-epochs * number-of-dcuments.
Each index is mapped to a corresponding document."""
doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1]
doc_idx[:] = documents
doc_idx = doc_idx.reshape(-1)
doc_idx = doc_idx.astype(np.int32)
np_rng.shuffle(doc_idx)
return doc_idx
def _build_sample_idx(sizes, doc_idx, seq_length,
num_epochs, tokens_per_epoch):
"""Sample index mapping is a 2D array with sizes
[number-of-samples + 1, 2] where [..., 0] contains
the index into `doc_idx` and [..., 1] is the
starting offset in that document."""
# Total number of samples. For -1 see comments in `_num_epochs`.
num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int32)
# Index into sample_idx.
sample_index = 0
# Index into doc_idx.
doc_idx_index = 0
# Beginning offset for each document.
doc_offset = 0
# Start with first document and no offset.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
while sample_index <= num_samples:
# Start with a fresh sequence.
remaining_seq_length = seq_length + 1
while remaining_seq_length != 0:
# Get the document length.
doc_id = doc_idx[doc_idx_index]
doc_length = sizes[doc_id] - doc_offset
# And add it to the current sequence.
remaining_seq_length -= doc_length
# If we have more than a full sequence, adjust offset and set
# remaining length to zero so we return from the while loop.
# Note that -1 here is for the same reason we have -1 in
# `_num_epochs` calculations.
if remaining_seq_length <= 0:
doc_offset += (remaining_seq_length + doc_length - 1)
remaining_seq_length = 0
else:
# Otherwise, start from the beginning of the next document.
doc_idx_index += 1
doc_offset = 0
# Record the sequence.
sample_idx[sample_index][0] = doc_idx_index
sample_idx[sample_index][1] = doc_offset
sample_index += 1
return sample_idx
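# Illustrative trace of _build_sample_idx on hypothetical toy inputs: three
# documents of 5, 3 and 4 tokens, one epoch, seq_length=4, so
# num_samples = (1 * 12 - 1) // 4 = 2 and sample_idx holds 3 rows of
# (document index, offset) marking sample boundaries:
#
#   sizes = np.array([5, 3, 4], dtype=np.int32)
#   doc_idx = np.array([0, 1, 2], dtype=np.int32)
#   _build_sample_idx(sizes, doc_idx, seq_length=4,
#                     num_epochs=1, tokens_per_epoch=12)
#   # -> [[0, 0], [0, 4], [2, 0]]
#   # sample 0: tokens 0-4 of doc 0; sample 1: doc 0 token 4, all of doc 1,
#   # and token 0 of doc 2 (adjacent samples share one boundary token).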
def _build_shuffle_idx(size, np_rng):
"""Build the range [0, size) and shuffle."""
dtype_ = np.uint32
if size >= (np.iinfo(np.uint32).max - 1):
dtype_ = np.int64
shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_)
np_rng.shuffle(shuffle_idx)
return shuffle_idx
/* /*
coding=utf-8 coding=utf-8
Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -33,6 +33,95 @@ using namespace std; ...@@ -33,6 +33,95 @@ using namespace std;
const int32_t LONG_SENTENCE_LEN = 512; const int32_t LONG_SENTENCE_LEN = 512;
py::array build_sample_idx(const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& doc_idx_,
const int32_t seq_length,
const int32_t num_epochs,
const int64_t tokens_per_epoch) {
/* Sample index (sample_idx) is used for gpt2 like dataset for which
the documents are flattened and the samples are built based on this
1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2]
where [..., 0] contains the index into `doc_idx` and [..., 1] is the
starting offset in that document.*/
// Consistency checks.
assert(seq_length > 1);
assert(num_epochs > 0);
assert(tokens_per_epoch > 1);
// Remove bound checks.
auto sizes = sizes_.unchecked<1>();
auto doc_idx = doc_idx_.unchecked<1>();
// Mapping and its length (1D).
int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length;
int32_t* sample_idx = new int32_t[2*(num_samples+1)];
cout << " using:" << endl << std::flush;
cout << " number of documents: " <<
doc_idx_.shape(0) / num_epochs << endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " sequence length: " << seq_length <<
endl << std::flush;
cout << " total number of samples: " << num_samples <<
endl << std::flush;
// Index into sample_idx.
int64_t sample_index = 0;
// Index into doc_idx.
int64_t doc_idx_index = 0;
// Beginning offset for each document.
int32_t doc_offset = 0;
// Start with first document and no offset.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
while (sample_index <= num_samples) {
// Start with a fresh sequence.
int32_t remaining_seq_length = seq_length + 1;
while (remaining_seq_length != 0) {
// Get the document length.
auto doc_id = doc_idx[doc_idx_index];
auto doc_length = sizes[doc_id] - doc_offset;
// And add it to the current sequence.
remaining_seq_length -= doc_length;
// If we have more than a full sequence, adjust offset and set
// remaining length to zero so we return from the while loop.
// Note that -1 here is for the same reason we have -1 in
// `_num_epochs` calculations.
if (remaining_seq_length <= 0) {
doc_offset += (remaining_seq_length + doc_length - 1);
remaining_seq_length = 0;
} else {
// Otherwise, start from the beginning of the next document.
++doc_idx_index;
doc_offset = 0;
}
}
// Record the sequence.
sample_idx[2 * sample_index] = doc_idx_index;
sample_idx[2 * sample_index + 1] = doc_offset;
++sample_index;
}
// Method to deallocate memory.
py::capsule free_when_done(sample_idx, [](void *mem_) {
int32_t *mem = reinterpret_cast<int32_t*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(int32_t);
return py::array(std::vector<int64_t>{num_samples+1, 2}, // shape
{2*byte_size, byte_size}, // C-style contiguous strides
sample_idx, // the data pointer
free_when_done); // numpy array references
}
inline int32_t get_target_sample_len(const int32_t short_seq_ratio, inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length, const int32_t max_length,
std::mt19937& rand32_gen) { std::mt19937& rand32_gen) {
...@@ -307,4 +396,5 @@ py::array build_mapping(const py::array_t<int64_t>& docs_, ...@@ -307,4 +396,5 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
PYBIND11_MODULE(helpers, m) { PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping); m.def("build_mapping", &build_mapping);
m.def("build_sample_idx", &build_sample_idx);
} }
...@@ -18,7 +18,8 @@ from itertools import accumulate ...@@ -18,7 +18,8 @@ from itertools import accumulate
import numpy as np import numpy as np
import torch import torch
from megatron.utils import print_rank_0 from megatron import print_rank_0
def __best_fitting_dtype(vocab_size=None): def __best_fitting_dtype(vocab_size=None):
if vocab_size is not None and vocab_size < 65500: if vocab_size is not None and vocab_size < 65500:
...@@ -42,6 +43,8 @@ def infer_dataset_impl(path): ...@@ -42,6 +43,8 @@ def infer_dataset_impl(path):
else: else:
return None return None
else: else:
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None return None
...@@ -53,6 +56,10 @@ def make_builder(out_file, impl, vocab_size=None): ...@@ -53,6 +56,10 @@ def make_builder(out_file, impl, vocab_size=None):
def make_dataset(path, impl, skip_warmup=False): def make_dataset(path, impl, skip_warmup=False):
if not IndexedDataset.exists(path):
print(f"Dataset does not exist: {path}")
print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
return None
if impl == 'infer': if impl == 'infer':
impl = infer_dataset_impl(path) impl = infer_dataset_impl(path)
if impl == 'lazy' and IndexedDataset.exists(path): if impl == 'lazy' and IndexedDataset.exists(path):
...@@ -61,6 +68,7 @@ def make_dataset(path, impl, skip_warmup=False): ...@@ -61,6 +68,7 @@ def make_dataset(path, impl, skip_warmup=False):
return IndexedCachedDataset(path) return IndexedCachedDataset(path)
elif impl == 'mmap' and MMapIndexedDataset.exists(path): elif impl == 'mmap' and MMapIndexedDataset.exists(path):
return MMapIndexedDataset(path, skip_warmup) return MMapIndexedDataset(path, skip_warmup)
print(f"Unknown dataset implementation: {impl}")
return None return None
...@@ -107,13 +115,15 @@ def index_file_path(prefix_path): ...@@ -107,13 +115,15 @@ def index_file_path(prefix_path):
def data_file_path(prefix_path): def data_file_path(prefix_path):
return prefix_path + '.bin' return prefix_path + '.bin'
def create_doc_idx(sizes): def create_doc_idx(sizes):
doc_idx = [0] doc_idx = [0]
for i, s in enumerate(sizes): for i, s in enumerate(sizes):
if s == 0: if s == 0:
doc_idx.append(i+1) doc_idx.append(i + 1)
return doc_idx return doc_idx
class IndexedDataset(torch.utils.data.Dataset): class IndexedDataset(torch.utils.data.Dataset):
"""Loader for IndexedDataset""" """Loader for IndexedDataset"""
_HDR_MAGIC = b'TNTIDX\x00\x00' _HDR_MAGIC = b'TNTIDX\x00\x00'
...@@ -153,7 +163,7 @@ class IndexedDataset(torch.utils.data.Dataset): ...@@ -153,7 +163,7 @@ class IndexedDataset(torch.utils.data.Dataset):
if self.data_file: if self.data_file:
self.data_file.close() self.data_file.close()
#@lru_cache(maxsize=8) # @lru_cache(maxsize=8)
def __getitem__(self, idx): def __getitem__(self, idx):
if not self.data_file: if not self.data_file:
self.read_data(self.path) self.read_data(self.path)
...@@ -233,7 +243,7 @@ class IndexedCachedDataset(IndexedDataset): ...@@ -233,7 +243,7 @@ class IndexedCachedDataset(IndexedDataset):
self.data_file.close() self.data_file.close()
self.data_file = None self.data_file = None
#@lru_cache(maxsize=8) # @lru_cache(maxsize=8)
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, int): if isinstance(idx, int):
i = idx i = idx
...@@ -397,13 +407,18 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -397,13 +407,18 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
self._bin_buffer = memoryview(self._bin_buffer_mmap) self._bin_buffer = memoryview(self._bin_buffer_mmap)
print_rank_0(" reading sizes...") print_rank_0(" reading sizes...")
self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset) self._sizes = np.frombuffer(
self._bin_buffer,
dtype=np.int32,
count=self._len,
offset=offset)
print_rank_0(" reading pointers...") print_rank_0(" reading pointers...")
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
offset=offset + self._sizes.nbytes) offset=offset + self._sizes.nbytes)
print_rank_0(" reading document index...") print_rank_0(" reading document index...")
self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count, self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
offset=offset + self._sizes.nbytes + self._pointers.nbytes) offset=offset + self._sizes.nbytes + self._pointers.nbytes)
def __del__(self): def __del__(self):
self._bin_buffer_mmap._mmap.close() self._bin_buffer_mmap._mmap.close()
del self._bin_buffer_mmap del self._bin_buffer_mmap
...@@ -462,13 +477,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -462,13 +477,12 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
def __len__(self): def __len__(self):
return len(self._index) return len(self._index)
#@lru_cache(maxsize=8) # @lru_cache(maxsize=8)
def __getitem__(self, idx): def __getitem__(self, idx):
if isinstance(idx, int): if isinstance(idx, int):
ptr, size = self._index[idx] ptr, size = self._index[idx]
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr) np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
if self._index.dtype != np.int64: count=size, offset=ptr)
np_array = np_array.astype(np.int64)
return np_array return np_array
elif isinstance(idx, slice): elif isinstance(idx, slice):
start, stop, step = idx.indices(len(self)) start, stop, step = idx.indices(len(self))
...@@ -478,10 +492,25 @@ class MMapIndexedDataset(torch.utils.data.Dataset): ...@@ -478,10 +492,25 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
sizes = self._index._sizes[idx] sizes = self._index._sizes[idx]
offsets = list(accumulate(sizes)) offsets = list(accumulate(sizes))
total_size = sum(sizes) total_size = sum(sizes)
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=total_size, offset=ptr) np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=total_size, offset=ptr)
sents = np.split(np_array, offsets[:-1]) sents = np.split(np_array, offsets[:-1])
return sents return sents
def get(self, idx, offset=0, length=None):
""" Retrieves a single item from the dataset with the option to only
return a portion of the item.
get(idx) is the same as [idx] but get() does not support slicing.
"""
ptr, size = self._index[idx]
if length is None:
length = size - offset
ptr += offset * np.dtype(self._index.dtype).itemsize
np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
count=length, offset=ptr)
return np_array
@property @property
def sizes(self): def sizes(self):
return self._index.sizes return self._index.sizes
......
import argparse
import json
import multiprocessing
import nltk
import sys
import time
import torch
from bert_tokenization import FullTokenizer
import indexed_dataset
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True)
splitter = nltk.load("tokenizers/punkt/english.pickle")
if self.args.keep_newlines:
    # this prevents punkt from eating newlines after sentences
    Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
        train_text=splitter._params,
        lang_vars=CustomLanguageVars())
else:
    Encoder.splitter = splitter
def encode(self, json_line):
text = json.loads(json_line)[self.args.json_key]
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
tokens = Encoder.tokenizer.tokenize(sentence)
ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
if len(ids) > 0:
doc_ids.append(ids)
return doc_ids, len(json_line)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='Path to input JSON')
parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
parser.add_argument('--json-key', type=str, default='text',
help='Key to extract from json')
parser.add_argument('--output-prefix', type=str, help='Path to binary output file without suffix')
parser.add_argument('--workers', type=int, default=20,
help='Number of worker processes to launch')
parser.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
parser.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences.')
parser.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
args = parser.parse_args()
args.keep_empty = False
startup_start = time.time()
print("Opening", args.input)
fin = open(args.input, 'r', encoding='utf-8')
nltk.download("punkt", quiet=True)
encoder = Encoder(args)
tokenizer = FullTokenizer(args.vocab, do_lower_case=True)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 25)
print(f"Vocab size: {tokenizer.vocab_size()}")
output_bin_file = "{}.bin".format(args.output_prefix)
output_idx_file = "{}.idx".format(args.output_prefix)
builder = indexed_dataset.make_builder(output_bin_file,
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size())
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for sentence in doc:
#print(sentence)
#print(tokenizer.convert_ids_to_tokens(sentence))
builder.add_item(torch.IntTensor(sentence))
builder.end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
builder.finalize(output_idx_file)
if __name__ == '__main__':
main()
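The diff does not show this script's filename, so the sketch below assumes it is saved as preprocess_data.py; the input, vocab, and output paths are placeholders. Running it produces the .bin/.idx pair consumed by the indexed-dataset code above:

# Hypothetical invocation of the preprocessing script above; the script name and all paths are placeholders.
import subprocess

subprocess.run([
    "python", "preprocess_data.py",
    "--input", "corpus.json",         # one JSON object per line, each with a "text" field
    "--vocab", "bert-vocab.txt",
    "--output-prefix", "my_corpus",   # writes my_corpus.bin and my_corpus.idx
    "--dataset-impl", "mmap",
    "--workers", "8",
], check=True)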
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Batch samplers that work with either random or sequential data samplers."""
import torch
from torch.utils import data
class RandomSampler(data.sampler.Sampler):
"""Based off of pytorch RandomSampler and DistributedSampler. Essentially
a RandomSampler, but this class lets the user set an epoch like
DistributedSampler. Samples elements randomly. If without replacement, then
sample from a shuffled dataset. If with replacement, then user can
specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``,
default=False
"""
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.epoch = -1
if self._num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not "
"be specified, since a random permute will be "
"performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(
self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
g = torch.Generator()
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,),
dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
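A minimal sketch of how the epoch hook is meant to be used: seeding the generator with the epoch makes the shuffle order reproducible, so processes that call set_epoch with the same value draw the same permutation, which the batch sampler below then slices per rank. This assumes the RandomSampler class above is in scope:

# Minimal sketch: the same epoch gives the same permutation on every process and every re-run.
sampler = RandomSampler(list(range(8)))   # anything with a length works as data_source
for epoch in range(2):
    sampler.set_epoch(epoch)
    print(epoch, list(sampler))           # order differs per epoch but is identical across ranks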
class DistributedBatchSampler(data.sampler.BatchSampler):
"""Similar to normal implementation of distributed sampler, except
implementation is at the batch sampler level, instead of just the
sampler level. This allows wrapping of arbitrary data samplers
(sequential, random, WeightedRandomSampler, etc.) with this batch
sampler."""
def __init__(self, sampler, batch_size, drop_last, rank=-1,
world_size=2, wrap_last=False):
super(DistributedBatchSampler, self).__init__(sampler, batch_size,
drop_last)
if rank == -1:
assert False, 'should not be here'
rank = torch.distributed.get_rank()
self.rank = rank
self.world_size = world_size
self.sampler.wrap_around = 0
self.wrap_around = 0
self.wrap_last = wrap_last
self.start_iter = 0
def __iter__(self):
batch = []
i = 0
for idx in self.data_iterator(self.sampler, wrap_around=False):
batch.append(idx)
if len(batch) == self.batch_size:
tbatch = self._batch(batch)
if i >= self.start_iter:
yield tbatch
self.start_iter = 0
i += 1
batch = []
batch_len = len(batch)
if batch_len > 0 and not self.drop_last:
if self.wrap_last:
self.sampler.wrap_around -= (self.batch_size)
self.wrap_around += (len(batch))
self.wrap_around %= self.batch_size
yield self._batch(batch)
if self.wrap_last:
self.sampler.wrap_around += self.batch_size
def data_iterator(self, _iter, wrap_around=False):
"""iterates through data and handles wrap around"""
for i, idx in enumerate(_iter):
if i < self.wrap_around % self.batch_size:
continue
if wrap_around:
self.wrap_around += 1
self.wrap_around %= self.batch_size
yield idx
def _batch(self, batch):
"""extracts samples only pertaining to this worker's batch"""
start = self.rank * self.batch_size // self.world_size
end = (self.rank + 1) * self.batch_size // self.world_size
return batch[start:end]
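The _batch slice is what actually distributes the work: every rank iterates the same global batch and keeps its own contiguous batch_size // world_size chunk. A small illustration with two simulated ranks (passing rank explicitly avoids any torch.distributed initialization):

# Illustration only: build the same global batch on two simulated ranks and show the per-rank slices.
from torch.utils.data import SequentialSampler

for rank in range(2):
    batch_sampler = DistributedBatchSampler(SequentialSampler(list(range(8))),
                                            batch_size=4, drop_last=True,
                                            rank=rank, world_size=2)
    print(rank, list(batch_sampler))   # rank 0 -> [[0, 1], [4, 5]], rank 1 -> [[2, 3], [6, 7]]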
# This file isn't really a formal automated test, it's just a place to
# put some code used during development and manual testing of
# indexed_dataset.
from megatron.data import indexed_dataset
from megatron.tokenizer import build_tokenizer
import argparse import argparse
import os import os
import sys import sys
...@@ -7,52 +13,90 @@ import torch ...@@ -7,52 +13,90 @@ import torch
script_dir = os.path.dirname(os.path.realpath(__file__)) script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../")) sys.path.append(os.path.join(script_dir, "../../../"))
from megatron.data import indexed_dataset, FullBertTokenizer, AlbertDataset
def test_indexed_dataset(args): def test_indexed_dataset(args):
ds = indexed_dataset.make_dataset(args.data, args.dataset_impl) ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) tokenizer = build_tokenizer(args)
print(len(ds.doc_idx)) print(len(ds.doc_idx))
print(len(ds)) print(len(ds))
print(ds.doc_idx[-1]) print(ds.doc_idx[-1])
if ds.supports_prefetch: if ds.supports_prefetch:
# just prefetch the whole thing in test (so assume it is small) # just prefetch the whole thing in test (so assume it is small)
ds.prefetch(range(len(ds))) ds.prefetch(range(len(ds)))
for i in range(len(ds.doc_idx)-1): if args.count > len(ds.doc_idx) - 1:
args.count = len(ds.doc_idx) - 1
for i in range(args.count):
start = ds.doc_idx[i] start = ds.doc_idx[i]
end = ds.doc_idx[i+1] end = ds.doc_idx[i + 1]
ids = ds[start:end] ids = ds[start:end]
print(f"Document {i}:")
print("--------------")
for s in ids: for s in ids:
assert len(s) > 0 assert len(s) > 0
l = s.data.tolist() l = s.data.tolist()
tokens = tokenizer.convert_ids_to_tokens(l) text = tokenizer.detokenize(l)
for t in tokens: print(text)
if '\n' in t: print("---")
print("Newline in string!")
print(i)
def test_indexed_dataset_get(args):
def test_albert_dataset(args): ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
# tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True) tokenizer = build_tokenizer(args)
# idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl) size = ds.sizes[0]
# ds = AlbertDataset(idataset, tokenizer) print(f"size: {size}")
ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl, full = ds.get(0)
args.epochs, args.max_num_samples, print(full)
args.masked_lm_prob, args.seq_length, # print(tokenizer.detokenize(full.data.tolist()))
args.short_seq_prob, args.seed) print("---")
truncated = 0 end = ds.get(0, offset=size - 10)
total = 0 print(end)
for s in ds: # print(tokenizer.detokenize(end.data.tolist()))
ids = s['text']
tokens = ds.tokenizer.convert_ids_to_tokens(ids) start = ds.get(0, length=10)
print(tokens) print(start)
exit() # print(tokenizer.detokenize(start.data.tolist()))
part = ds.get(0, offset=2, length=8)
print(part)
# print(tokenizer.detokenize(part.data.tolist()))
# def test_albert_dataset(args):
# # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
# # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
# # ds = AlbertDataset(idataset, tokenizer)
# ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
# args.epochs, args.max_num_samples,
# args.masked_lm_prob, args.seq_length,
# args.short_seq_prob, args.seed)
# truncated = 0
# total = 0
# for i, s in enumerate(ds):
# ids = s['text']
# tokens = ds.tokenizer.convert_ids_to_tokens(ids)
# print(tokens)
# if i >= args.count-1:
# exit()
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, help='prefix to data files') parser.add_argument('--data', type=str, help='prefix to data files')
parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
parser.add_argument('--dataset-impl', type=str, default='infer', parser.add_argument('--dataset-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer']) choices=['lazy', 'cached', 'mmap', 'infer'])
parser.add_argument('--count', type=int, default=10,
help='Number of samples/documents to print')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase',
'GPT2BPETokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
parser.add_argument('--epochs', type=int, default=5, parser.add_argument('--epochs', type=int, default=5,
help='Number of epochs to plan for') help='Number of epochs to plan for')
parser.add_argument('--max-num-samples', type=int, default=None, parser.add_argument('--max-num-samples', type=int, default=None,
...@@ -66,12 +110,16 @@ def main(): ...@@ -66,12 +110,16 @@ def main():
parser.add_argument('--seed', type=int, default=1234, parser.add_argument('--seed', type=int, default=1234,
help='random seed') help='random seed')
args = parser.parse_args() args = parser.parse_args()
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1
if args.dataset_impl == "infer": if args.dataset_impl == "infer":
args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data) args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)
test_albert_dataset(args) # test_albert_dataset(args)
# test_indexed_dataset(args) test_indexed_dataset_get(args)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
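Because the test now builds its tokenizer through build_tokenizer, driving it by hand needs the tokenizer flags plus the three attributes main() patches in. A hypothetical programmatic call, with placeholder data prefix and vocab path:

# Hypothetical programmatic call mirroring main() above; the data prefix and vocab path are placeholders.
from argparse import Namespace

args = Namespace(data="my_corpus_text_sentence", dataset_impl="mmap", count=10,
                 tokenizer_type="BertWordPieceLowerCase", vocab_file="bert-vocab.txt",
                 merge_file=None, rank=0, make_vocab_size_divisible_by=128,
                 model_parallel_size=1)
test_indexed_dataset_get(args)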
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -28,21 +28,24 @@ TRAIN_DATA = 0 ...@@ -28,21 +28,24 @@ TRAIN_DATA = 0
VAL_DATA = 1 VAL_DATA = 1
TEST_DATA = 2 TEST_DATA = 2
def should_split(split): def should_split(split):
""" """
given split proportions checks if should split given split proportions checks if should split
Examples: Examples:
>>> should_split([10,0,0]) >>> should_split([10,0,0])
False False
>>> should_split([1,.1,.2]) >>> should_split([1,.1,.2])
True True
""" """
return max(split)/sum(split) != 1. return max(split) / sum(split) != 1.
def get_ext(path): def get_ext(path):
"""gets path extension""" """gets path extension"""
return os.path.splitext(path)[1] return os.path.splitext(path)[1]
def get_dataset(path, **kwargs): def get_dataset(path, **kwargs):
"""gets dataset object based on keyword args and file at `path`""" """gets dataset object based on keyword args and file at `path`"""
if supported_corpus(path): if supported_corpus(path):
...@@ -53,17 +56,19 @@ def get_dataset(path, **kwargs): ...@@ -53,17 +56,19 @@ def get_dataset(path, **kwargs):
elif ext in ['.csv', '.tsv']: elif ext in ['.csv', '.tsv']:
text = csv_dataset(path, **kwargs) text = csv_dataset(path, **kwargs)
else: else:
raise NotImplementedError('data file type %s is not supported'%(ext)) raise NotImplementedError('data file type %s is not supported' % (ext))
return text return text
def supported_corpus(corpus_name): def supported_corpus(corpus_name):
"""checks if corpus name is defined in `corpora.py`""" """checks if corpus name is defined in `corpora.py`"""
return corpus_name in corpora.NAMED_CORPORA return corpus_name in corpora.NAMED_CORPORA
def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.], def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.],
delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None, delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None,
tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None, tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None,
model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None, model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None,
parallel_group=None, **kwargs): parallel_group=None, **kwargs):
"""function to create datasets+tokenizers for common options""" """function to create datasets+tokenizers for common options"""
if isinstance(process_fn, str): if isinstance(process_fn, str):
...@@ -71,6 +76,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N ...@@ -71,6 +76,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
if non_binary_cols is not None: if non_binary_cols is not None:
# multilabel dataset support (only for csvs) # multilabel dataset support (only for csvs)
label_key = non_binary_cols label_key = non_binary_cols
def get_dataset_from_path(path_): def get_dataset_from_path(path_):
if lazy: if lazy:
# get lazily loaded dataset # get lazily loaded dataset
...@@ -82,7 +88,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N ...@@ -82,7 +88,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'): if torch.distributed.get_rank() == 0 and not exists_lazy(path_, data_type='data'):
# create cached version of dataset for lazy loading if it doesn't exist # create cached version of dataset for lazy loading if it doesn't exist
text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent, text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose) delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose)
make_lazy(path_, text.X, data_type='data') make_lazy(path_, text.X, data_type='data')
# This should be a barrier but nccl barrier assumes # This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model # device_index=rank which is not the case for model
...@@ -96,7 +102,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N ...@@ -96,7 +102,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
else: else:
# get dataset # get dataset
text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent, text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent,
delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn) delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn)
return text return text
# get one or multiple datasets and concatenate # get one or multiple datasets and concatenate
if isinstance(path, str): if isinstance(path, str):
...@@ -108,8 +114,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N ...@@ -108,8 +114,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
ds = ConcatDataset(datasets) ds = ConcatDataset(datasets)
# make tokenizer for dataset # make tokenizer for dataset
if tokenizer is None: if tokenizer is None:
tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type, tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type,
pad_token, character_converage, **kwargs) pad_token, character_converage, **kwargs)
ds_type = '' ds_type = ''
if 'ds_type' in kwargs: if 'ds_type' in kwargs:
...@@ -121,7 +127,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N ...@@ -121,7 +127,8 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
if 'bert' in ds_type.lower(): if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
dstype = bert_sentencepair_dataset dstype = bert_sentencepair_dataset
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds] ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
if d is not None else None for d in ds]
elif ds_type.lower() == 'gpt2': elif ds_type.lower() == 'gpt2':
ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds] ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
else: else:
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,6 +21,7 @@ import torch ...@@ -21,6 +21,7 @@ import torch
from megatron import data_utils from megatron import data_utils
from megatron import mpu from megatron import mpu
class DataConfig: class DataConfig:
def __init__(self, defaults={}): def __init__(self, defaults={}):
...@@ -48,7 +49,8 @@ def make_data_loader(dataset, batch_size, args): ...@@ -48,7 +49,8 @@ def make_data_loader(dataset, batch_size, args):
shuffle = args.shuffle shuffle = args.shuffle
if shuffle: if shuffle:
sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters) sampler = data_utils.samplers.RandomSampler(
dataset, replacement=True, num_samples=batch_size * args.train_iters)
else: else:
sampler = torch.utils.data.SequentialSampler(dataset) sampler = torch.utils.data.SequentialSampler(dataset)
world_size = torch.distributed.get_world_size( world_size = torch.distributed.get_world_size(
...@@ -204,6 +206,7 @@ def make_loaders(args): ...@@ -204,6 +206,7 @@ def make_loaders(args):
return (train, valid, test), tokenizer return (train, valid, test), tokenizer
def get_split(args): def get_split(args):
""" """
Get dataset splits from comma separated string list Get dataset splits from comma separated string list
...@@ -217,7 +220,7 @@ def get_split(args): ...@@ -217,7 +220,7 @@ def get_split(args):
splits = [float(args.split)] splits = [float(args.split)]
split_total = sum(splits) split_total = sum(splits)
if split_total < 1.: if split_total < 1.:
splits.append(1-split_total) splits.append(1 - split_total)
while len(splits) < 3: while len(splits) < 3:
splits.append(0.) splits.append(0.)
splits = splits[:3] splits = splits[:3]
...@@ -226,10 +229,10 @@ def get_split(args): ...@@ -226,10 +229,10 @@ def get_split(args):
if args.test_data is not None: if args.test_data is not None:
splits[2] = 0. splits[2] = 0.
final_sum = sum(splits) final_sum = sum(splits)
return [s/final_sum for s in splits] return [s / final_sum for s in splits]
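For a concrete sense of the normalization above, the single-value path of get_split turns one train fraction into three proportions; a worked sketch of just that visible branch:

# Worked example of the single-value branch shown above, starting from a lone train fraction of 0.5.
splits = [0.5]
if sum(splits) < 1.:
    splits.append(1 - sum(splits))    # [0.5, 0.5]
while len(splits) < 3:
    splits.append(0.)                 # [0.5, 0.5, 0.0]
splits = splits[:3]
final_sum = sum(splits)
print([s / final_sum for s in splits])  # [0.5, 0.5, 0.0]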
def configure_data():
def configure_data():
"""add cmdline flags for configuring datasets""" """add cmdline flags for configuring datasets"""
# These are options that are used by data_utils, but are either # These are options that are used by data_utils, but are either
# deprecated or not meant to be exposed to the command line user. # deprecated or not meant to be exposed to the command line user.
......
...@@ -169,7 +169,7 @@ def http_get(url, temp_file): ...@@ -169,7 +169,7 @@ def http_get(url, temp_file):
total = int(content_length) if content_length is not None else None total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total) progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024): for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks if chunk: # filter out keep-alive new chunks
progress.update(len(chunk)) progress.update(len(chunk))
temp_file.write(chunk) temp_file.write(chunk)
progress.close() progress.close()
......
...@@ -16,12 +16,12 @@ output_file = sys.argv[2] ...@@ -16,12 +16,12 @@ output_file = sys.argv[2]
line_seperator = "\n" line_seperator = "\n"
with open(input_file, 'r') as ifile: with open(input_file, 'r') as ifile:
with open(output_file, "w") as ofile: with open(output_file, "w") as ofile:
for doc in ifile.readlines(): for doc in ifile.readlines():
parsed = json.loads(doc) parsed = json.loads(doc)
sent_list = [] sent_list = []
for line in parsed['text'].split('\n'): for line in parsed['text'].split('\n'):
if line != '\n': if line != '\n':
sent_list.extend(nltk.tokenize.sent_tokenize(line)) sent_list.extend(nltk.tokenize.sent_tokenize(line))
parsed['text'] = line_seperator.join(sent_list) parsed['text'] = line_seperator.join(sent_list)
ofile.write(json.dumps(parsed)+'\n') ofile.write(json.dumps(parsed) + '\n')