Commit 0058b1a5 authored by Jared Casper

Plugging new dataset stuff together.

Some cleanup of types, etc. in the helpers cpp code.
parent b1714c14
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)

%$(LIBEXT): %.cpp
	$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
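Running make here uses pybind11 for the include flags and python3-config for the platform-specific extension suffix, so the output is a shared object named something like helpers.cpython-37m-x86_64-linux-gnu.so (the exact name depends on the Python build). Once built, the module imports like any other; a minimal sanity check, assuming the Makefile was run inside the megatron/data package directory:

# Minimal check that the compiled extension is importable (hedged sketch).
from megatron.data import helpers
print(callable(helpers.build_mapping))  # expected: True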
from . import indexed_dataset
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .dataset import AlbertDataset
@@ -7,27 +7,36 @@ import numpy as np
import torch
from torch.utils.data import Dataset

from .dataset_utils import build_training_sample
#from data.mapping import build_training_samples_mapping
from . import helpers
from megatron.data import FullBertTokenizer, indexed_dataset


class AlbertDataset(Dataset):

    def __init__(self, indexed_dataset, tokenizer, num_epochs, max_num_samples,
                 masked_lm_prob, max_seq_length, short_seq_prob, seed):

        # Params to store.
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length
        self.tokenizer = tokenizer

        # Indexed dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        if not max_num_samples:
            max_num_samples = len(indexed_dataset) * num_epochs
        self.samples_mapping = helpers.build_mapping(
            indexed_dataset.doc_idx,
            indexed_dataset.sizes,
            num_epochs,
            max_num_samples,
            self.max_seq_length - 3,  # account for added tokens
            short_seq_prob,
            self.seed)
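Each row of the samples_mapping array returned by helpers.build_mapping is a triple of (first sentence index, one-past-last sentence index, target sequence length); __getitem__ below reads one row and gathers those sentences from the indexed dataset. An illustrative sketch of that lookup, using the names from this class:

# Illustrative only: how one row of the mapping drives sample construction.
start_index, end_index, seq_length = self.samples_mapping[idx]
sentences = [self.indexed_dataset[i] for i in range(start_index, end_index)]
# build_training_sample then trims and masks these sentences down to seq_length tokens.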
@@ -40,8 +49,17 @@ class AlbertDataSet(Dataset):
        self.pad_id = tokenizer.vocab['[PAD]']

    @classmethod
    def from_paths(cls, vocab, data_prefix, data_impl,
                   num_epochs, max_num_samples, masked_lm_prob,
                   max_seq_length, short_seq_prob, seed):
        tokenizer = FullBertTokenizer(vocab, do_lower_case=True)
        idx_ds = indexed_dataset.make_dataset(data_prefix, data_impl)
        return cls(idx_ds, tokenizer, num_epochs, max_num_samples, masked_lm_prob,
                   max_seq_length, short_seq_prob, seed)
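Since AlbertDataset subclasses torch.utils.data.Dataset, from_paths gives a one-call way to build something a DataLoader can consume. A hedged sketch (the data prefix is a placeholder; the vocab path matches the test files referenced later in this diff):

# Hedged sketch: construct the dataset from paths and wrap it in a DataLoader.
from torch.utils.data import DataLoader
from megatron.data import AlbertDataset

ds = AlbertDataset.from_paths('test/vocab.txt', 'my_corpus_prefix', 'mmap',
                              num_epochs=10, max_num_samples=None,
                              masked_lm_prob=0.15, max_seq_length=512,
                              short_seq_prob=0.1, seed=1234)
loader = DataLoader(ds, batch_size=8)  # default collate stacks the per-sample dicts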
    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):

        rng = random.Random(self.seed + idx)
@@ -49,6 +67,9 @@ class AlbertDataSet(Dataset):
        sample = []
        for index in range(start_index, end_index):
            sample.append(self.indexed_dataset[index])
        for s in sample:
            if len(s) > 1000:
                print(self.tokenizer.convert_ids_to_tokens(s))
        return build_training_sample(sample, seq_length,
                                     self.max_seq_length,
                                     self.vocab_id_list,
@@ -186,7 +207,6 @@ class JaredDataset(object):

if __name__ == '__main__':

    print('dataset ...')
    from bert_tokenization import FullTokenizer
@@ -207,8 +227,8 @@ if __name__ == '__main__':
                sentences.extend(sent)
            yield sentences

    input_file = 'test/samples_10000.json'
    vocab_file = 'test/vocab.txt'
    tokenizer = FullTokenizer(vocab_file, do_lower_case=True)
    document_generator = document_generator_provider(input_file)
@@ -35,10 +35,9 @@ def build_training_sample(sample,
    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, rng)

    # Truncate to `target_sequence_length`.
    max_num_tokens = target_seq_length
    truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a), len(tokens_b),
                                  max_num_tokens, rng)

    # Build tokens and tokentypes.
    tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
@@ -48,7 +47,7 @@ def build_training_sample(sample,
    max_predictions_per_seq = masked_lm_prob * max_num_tokens
    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, rng)

    # Padding.
    tokens_np, tokentypes_np, labels, padding_mask, loss_mask \
@@ -61,7 +60,8 @@ def build_training_sample(sample,
        'labels': labels,
        'is_random': int(is_next_random),
        'loss_mask': loss_mask,
        'padding_mask': padding_mask,
        'truncated': int(truncated)}
    return train_sample
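The new 'truncated' entry makes it cheap to measure how often sentence pairs overflow the target length. A hedged sketch, assuming an AlbertDataset instance named ds as constructed elsewhere in this change:

# Rough sketch: estimate the truncation rate over the first samples.
truncated, total = 0, 0
for i in range(100):
    sample = ds[i]
    truncated += sample['truncated']
    total += 1
print('truncation rate: {:.2%}'.format(truncated / total))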
@@ -99,11 +99,12 @@ def get_a_and_b_segments(sample, rng):

def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    #print(len_a, len_b, max_num_tokens)
    assert len_a > 0
    assert len_b > 0
    if len_a + len_b <= max_num_tokens:
        return False
    while len_a + len_b > max_num_tokens:
        if len_a > len_b:
            len_a -= 1
            tokens = tokens_a
@@ -114,8 +115,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, rng):
            del tokens[0]
        else:
            tokens.pop()
    return True
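truncate_segments now loops instead of recursing and reports whether anything was dropped. For example, with illustrative values:

# Illustrative: segments of 5 and 4 tokens truncated to at most 6 tokens total.
import random

rng = random.Random(0)
tokens_a = list(range(5))
tokens_b = list(range(100, 104))
was_truncated = truncate_segments(tokens_a, tokens_b,
                                  len(tokens_a), len(tokens_b), 6, rng)
# was_truncated is True and len(tokens_a) + len(tokens_b) is now at most 6,
# with tokens removed from the front or back of the longer segment.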
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""

@@ -161,6 +161,7 @@ def create_masked_lm_predictions(tokens,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
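Threading the per-sample rng through create_masked_lm_predictions keeps sample construction deterministic: the randomness used in __getitem__ now derives from random.Random(self.seed + idx). A quick hedged check, again assuming a dataset instance ds:

# Determinism check: the same index should yield identical masked labels.
import numpy as np

a = ds[7]
b = ds[7]
print(np.array_equal(a['labels'], b['labels']))  # expected: True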
@@ -468,4 +469,3 @@ if __name__ == '__main__':
        string += '{:5d}'.format(tokentype)
        string += '{:5d}'.format(padding_mask)
        print(string)
@@ -3,6 +3,7 @@
#include <iostream>
#include <limits>
#include <math.h>
#include <stdexcept>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
@@ -11,192 +12,204 @@ using namespace std;

inline uint32_t get_sample_len(const int short_seq_ratio,
                               const uint32_t max_length) {
    /* Training sample length. */
    const auto random_number = rand();
    if ((random_number % short_seq_ratio) == 0) {
        return 2 + random_number % (max_length - 1);
    }
    return max_length;
}
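For roughly one sample in short_seq_ratio, get_sample_len draws a shorter target length between 2 and max_length; otherwise it returns max_length. The same logic rendered in Python, for illustration only:

# Python rendering of get_sample_len, for illustration only.
import random

def get_sample_len(short_seq_ratio, max_length):
    random_number = random.randrange(2**31)  # stand-in for C's rand()
    if random_number % short_seq_ratio == 0:
        return 2 + random_number % (max_length - 1)
    return max_length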
template<typename DocIdx>
py::array build_mapping_impl(const py::array_t<uint32_t>& docs_,
                             const py::array_t<uint16_t>& sizes_,
                             const int num_epochs,
                             const uint64_t max_num_samples,
                             const int max_seq_length,
                             const double short_seq_prob,
                             const int seed) {

    cout << "> building dataset mapping for " << docs_.shape(0) - 1 <<
        " documents with " << sizes_.shape(0) << " sentences ..." << endl;

    // For efficiency, convert probability to ratio.
    const auto short_seq_ratio = static_cast<int>(round(1.0 / short_seq_prob));

    // Remove bound checks.
    auto docs = docs_.unchecked<1>();
    auto sizes = sizes_.unchecked<1>();

    // Check for consistency.
    if (docs[docs.shape(0) - 1] != sizes.shape(0)) {
        cout << "document values is not consistent with length of sizes: " <<
            docs[docs.shape(0) - 1] << " != " << sizes.shape(0) << endl;
        throw(-1);
    }

    // Mapping and its length (1D).
    int64_t num_samples = -1;
    DocIdx* maps = NULL;

    // Perform two iterations, in the first iteration get the size
    // and allocate memory and in the second iteration populate the map.
    bool second = false;
    for (int iteration=0; iteration < 2; ++iteration) {

        // Set the seed so both iterations produce the same results.
        srand(seed);

        // Set the flag on second iteration.
        second = iteration == 1;

        // Counters:
        uint32_t empty_docs = 0;
        uint32_t one_sent_docs = 0;

        // Current map index.
        uint64_t map_index = 0;

        // For each epoch:
        for (int epoch=0; epoch < num_epochs; ++epoch) {
            if (map_index >= max_num_samples && !second) {
                cout << " > reached " << max_num_samples << " samples after " <<
                    epoch << " epochs ..." << endl;
                break;
            }
            // For each document:
            for (int doc=0; doc < (docs.shape(0) - 1); ++doc) {

                // Document sentences are in [sent_index_first, sent_index_last).
                const auto sent_index_first = docs[doc];
                const auto sent_index_last = docs[doc + 1];

                // At the beginning of the document previous index is the start index.
                auto prev_start_index = sent_index_first;

                // Remaining sentences.
                auto num_remain_sent = sent_index_last - sent_index_first;

                // Some bookkeeping.
                if ((epoch == 0) && (!second)) {
                    if (num_remain_sent == 0) {
                        cout << "***WARNING*** document " << doc << " is empty" << endl;
                        empty_docs += 1;
                    }
                    if (num_remain_sent == 1) {
                        cout << "***WARNING*** document " << doc <<
                            " has one sentence" << endl;
                        one_sent_docs += 1;
                    }
                }

                // If we have more than two sentences.
                if (num_remain_sent > 1) {

                    // Set values.
                    auto size = uint32_t{0};
                    auto num_sent = uint32_t{0};
                    auto seq_len = get_sample_len(short_seq_ratio, max_seq_length);

                    // Loop through sentences.
                    for (auto sent_index=sent_index_first;
                         sent_index < sent_index_last; ++sent_index) {

                        // Add the size and number of sentences.
                        size += sizes[sent_index];
                        num_sent += 1;
                        num_remain_sent -= 1;

                        // If we have reached the target length,
                        // and more than one sentence is left in the document,
                        // and we have at least two sentences,
                        // or we have reached the end of the document.
                        if (((size >= seq_len) && (num_remain_sent > 1) &&
                             (num_sent > 1)) || (num_remain_sent == 0)) {

                            // Populate the map.
                            if (second) {
                                const auto map_index_0 = 3 * map_index;
                                maps[map_index_0] = prev_start_index;
                                maps[map_index_0 + 1] = sent_index + 1;
                                maps[map_index_0 + 2] = seq_len;
                            }

                            // Update indices / counters.
                            // Check for overflow.
                            if (map_index == std::numeric_limits<DocIdx>::max()) {
                                cout << "number of samples exceeded maximum allowed by type: "
                                     << std::numeric_limits<DocIdx>::max() << endl;
                                throw std::overflow_error("Number of samples");
                            }
                            map_index += 1;
                            prev_start_index = sent_index + 1;
                            seq_len = get_sample_len(short_seq_ratio, max_seq_length);
                            size = 0;
                            num_sent = 0;
                        }
                    }
                } // if (num_remain_sent > 1) {
            } // for (int doc=0; doc < num_docs; ++doc) {
        } // for (int epoch=0; epoch < num_epochs; ++epoch) {

        if (!second) {
            cout << " number of samples: " << map_index << endl;
            cout << " number of empty documents: " << empty_docs << endl;
            cout << " number of documents with one sentence: " << one_sent_docs << endl;
            maps = new DocIdx[3*map_index];
            num_samples = map_index;
        }

    } // for (int iteration=0; iteration < 2; ++iteration) {

    // Shuffle.
    for (auto i=(num_samples - 1); i > 0; --i) {
        const auto j = rand() % (i + 1);
        const auto i0 = 3 * i;
        const auto j0 = 3 * j;
        // Swap values.
        swap(maps[i0], maps[j0]);
        swap(maps[i0 + 1], maps[j0 + 1]);
        swap(maps[i0 + 2], maps[j0 + 2]);
    }

    cout << " > done building the mapping." << endl;

    // Method to deallocate memory.
    py::capsule free_when_done(maps, [](void *mem_) {
        DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
        cout << "freeing memory for the dataset mapping" << endl;
        delete[] mem;
    });

    // Return the numpy array.
    return py::array(std::vector<int64_t>{num_samples, 3},       // shape
                     {3*sizeof(DocIdx), sizeof(DocIdx)},          // C-style contiguous strides (in bytes)
                     maps,                                        // the data pointer
                     free_when_done);                             // numpy array references this capsule
}


py::array build_mapping(const py::array& docs_,
                        const py::array& sizes_,
                        const int num_epochs,
                        const uint64_t max_num_samples,
                        const int max_seq_length,
                        const double short_seq_prob,
                        const int seed) {
    if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
        return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs, max_num_samples,
                                            max_seq_length, short_seq_prob, seed);
    } else {
        return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs, max_num_samples,
                                            max_seq_length, short_seq_prob, seed);
    }
}

PYBIND11_MODULE(helpers, m) {
    m.def("build_mapping", &build_mapping);
}
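With the module built by the Makefile above, helpers.build_mapping can be exercised directly on a tiny synthetic corpus. A hedged sketch (the array values are made up; the argument order follows the signature above):

# Hedged sketch: call the compiled helper on a tiny synthetic corpus.
import numpy as np
from megatron.data import helpers

# Two documents with 3 and 2 sentences; doc_idx holds cumulative sentence offsets,
# so its last entry must equal len(sizes) (the consistency check above).
doc_idx = np.array([0, 3, 5], dtype=np.uint32)
sizes = np.array([20, 35, 18, 40, 25], dtype=np.uint16)  # sentence lengths in tokens

mapping = helpers.build_mapping(doc_idx, sizes,
                                10,    # num_epochs
                                100,   # max_num_samples
                                509,   # max_seq_length (512 minus 3 added tokens)
                                0.1,   # short_seq_prob
                                1234)  # seed
# Each row is (first sentence index, one-past-last sentence index, target length).
print(mapping.shape, mapping.dtype)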
@@ -458,7 +458,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
            if self._index.dtype != np.int64:
                np_array = np_array.astype(np.int64)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
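With this change MMapIndexedDataset.__getitem__ hands back the raw numpy array rather than wrapping it in a torch tensor; callers that still want tensors convert explicitly. A hedged sketch (assumes ds is an MMapIndexedDataset):

# Callers that need a torch tensor now convert at the call site.
import torch
ids = ds[0]                         # numpy array of token ids for the first sentence
ids_tensor = torch.from_numpy(ids)  # explicit conversion where a tensor is needed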
@@ -7,7 +7,7 @@ import torch
script_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(script_dir, "../../../"))
from megatron.data import indexed_dataset, FullBertTokenizer, AlbertDataset


def test_indexed_dataset(args):
    ds = indexed_dataset.make_dataset(args.data, args.dataset_impl)
@@ -31,18 +31,47 @@ def test_indexed_dataset(args):
                print("Newline in string!")
            print(i)
def test_albert_dataset(args):
    # tokenizer = FullBertTokenizer(args.vocab, do_lower_case=True)
    # idataset = indexed_dataset.make_dataset(args.data, args.dataset_impl)
    # ds = AlbertDataset(idataset, tokenizer)
    ds = AlbertDataset.from_paths(args.vocab, args.data, args.dataset_impl,
                                  args.epochs, args.max_num_samples,
                                  args.masked_lm_prob, args.seq_length,
                                  args.short_seq_prob, args.seed)
    truncated = 0
    total = 0
    for s in ds:
        ids = s['text']
        tokens = ds.tokenizer.convert_ids_to_tokens(ids)
        print(tokens)
        exit()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, help='prefix to data files')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--dataset-impl', type=str, default='infer',
                        choices=['lazy', 'cached', 'mmap', 'infer'])
    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs to plan for')
    parser.add_argument('--max-num-samples', type=int, default=None,
                        help='Maximum number of samples to plan for')
    parser.add_argument('--masked-lm-prob', type=float, default=0.15,
                        help='probability of masking tokens')
    parser.add_argument('--seq-length', type=int, default=512,
                        help='maximum sequence length')
    parser.add_argument('--short-seq-prob', type=float, default=0.1,
                        help='probability of creating a short sequence')
    parser.add_argument('--seed', type=int, default=1234,
                        help='random seed')
    args = parser.parse_args()

    if args.dataset_impl == "infer":
        args.dataset_impl = indexed_dataset.infer_dataset_impl(args.data)

    test_albert_dataset(args)
    # test_indexed_dataset(args)


if __name__ == "__main__":
    main()