Commit 05620ee4 authored by Neel Kant

Merge branch 'ict-merge' into 'master'

ICT code

See merge request ADLR/megatron-lm!90
parents c20f4d48 5247f24c
......@@ -16,6 +16,7 @@ For BERT training, we swapped the position of the layer normalization and the re
- [BERT Pretraining](#bert-pretraining)
- [GPT-2 Pretraining](#gpt-2-pretraining)
- [Distributed BERT or GPT-2 Pretraining](#distributed-bert-or-gpt-2-pretraining)
- [REALM Pipeline](#realm)
- [Evaluation and Tasks](#evaluation-and-tasks)
- [GPT-2 Text Generation](#gpt-2-text-generation)
- [GPT-2 Evaluation](#gpt-2-evaluation)
......@@ -263,6 +264,61 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./pretrain_gpt2.py \
</pre>
<a id="realm"></a>
## REALM Pipeline
The following sections will eventually reflect the three stages of training a REALM system; for now, only the ICT code is included.
Loosely, the stages are: pretraining the retriever modules, then jointly training the language model and the retriever, and finally finetuning a question answering head on the language model with the retriever fixed.
### Inverse Cloze Task (ICT) Pretraining
1. Start with a corpus in loose JSON format; the goal is to create a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this means multiple sentences per block but also multiple blocks per document.
Run `tools/preprocess_data.py` with the `--split-sentences` argument to construct one or more indexed datasets in which sentences are the basic unit. For the original REALM system, we construct two datasets: one from the title of every document and another from the body.
Refer to the following script:
<pre>
python tools/preprocess_data.py \
--input /path/to/corpus.json \
--json-keys text title \
--split-sentences \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file /path/to/vocab.txt \
--output-prefix corpus_indexed \
--workers 5 # works well for 10 CPU cores. Scale up accordingly.
</pre>
2. Use a custom samples mapping function in place of `megatron/data/realm_dataset_utils.get_block_samples_mapping` if required. To do this, you will need to implement a new function in C++ inside of `megatron/data/helpers.cpp`. The samples mapping data structure is used to select, in advance of the training loop, the data that will constitute every training sample.
The samples mapping holds all of the metadata needed to construct a sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block; see the sketch below.
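For reference, here is a short sketch (illustrative values only, not code from the repository) of how one row of the REALM samples mapping is consumed; it mirrors `ICTDataset.__getitem__` in `megatron/data/ict_dataset.py`:
<pre>
import numpy as np

# One hypothetical row: (start sentence index, end sentence index,
# source document index, unique block id).
samples_mapping = np.array([[14, 17, 3, 0]], dtype=np.int64)

start_idx, end_idx, doc_idx, block_idx = samples_mapping[0]
# Sentences start_idx..end_idx-1 of the block dataset form one block;
# doc_idx selects the matching title; block_idx uniquely identifies the block.
</pre>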
3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task.
In REALM, this is an uncased BERT-base model trained with the standard hyperparameters; an illustrative command is sketched below.
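As a rough sketch (the values below are illustrative rather than the exact hyperparameters used in the paper; see the BERT pretraining section above for a complete walkthrough), the command mirrors the ICT script in step 4 but invokes `pretrain_bert.py`, with `--seq-length` set to the block size:
<pre>
python pretrain_bert.py \
       --num-layers 12 \
       --num-attention-heads 12 \
       --hidden-size 768 \
       --batch-size 32 \
       --seq-length 256 \
       --max-position-embeddings 256 \
       --train-iters 1000000 \
       --checkpoint-activations \
       --save /path/to/pretrained_bert \
       --load /path/to/pretrained_bert \
       --data-path /path/to/indexed_dataset \
       --vocab-file /path/to/vocab.txt \
       --lr 0.0001 \
       --lr-decay-style linear \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --warmup .01 \
       --fp16
</pre>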
4. Use `pretrain_ict.py` to train an `ICTBertModel`, which uses two BERT-based encoders to embed queries and blocks for retrieval.
The script below trains the ICT model from REALM. It references the pretrained BERT model from step 3 via the `--bert-load` argument. The batch size used in the paper is 4096, so with `--batch-size 128` this would need to be run with a data-parallel world size of 32.
<pre>
python pretrain_ict.py \
--num-layers 12 \
--num-attention-heads 12 \
--hidden-size 768 \
--batch-size 128 \
--seq-length 256 \
--max-position-embeddings 256 \
--ict-head-size 128 \
--train-iters 100000 \
--checkpoint-activations \
--bert-load /path/to/pretrained_bert \
--load checkpoints \
--save checkpoints \
--data-path /path/to/indexed_dataset \
--titles-data-path /path/to/titles_indexed_dataset \
--vocab-file /path/to/vocab.txt \
--lr 0.0001 \
--num-workers 2 \
--lr-decay-style linear \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--save-interval 3000 \
--query-in-block-prob 0.1 \
--fp16
</pre>
<a id="evaluation-and-tasks"></a>
# Evaluation and Tasks
......
......@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .package_info import (
__description__,
......@@ -30,7 +31,6 @@ from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
import torch
def print_rank_0(message):
"""If distributed is initialized print only on rank 0."""
......@@ -38,4 +38,4 @@ def print_rank_0(message):
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
print(message, flush=True)
\ No newline at end of file
......@@ -37,6 +37,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_validation_args(parser)
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
# Custom arguments.
if extra_args_provider is not None:
......@@ -390,3 +391,32 @@ def _add_autoresume_args(parser):
'termination signal')
return parser
def _add_realm_args(parser):
group = parser.add_argument_group(title='realm')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing a BertModel checkpoint (needed to start ICT and REALM)')
# data
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--ict-one-sent', action='store_true',
help='Whether to use one sentence documents in ICT')
# training
group.add_argument('--report-topk-accuracies', nargs='+', default=[],
help="Which top-k accuracies to report (e.g. '1 5 20')")
return parser
......@@ -128,14 +128,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler):
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
"""Load a model checkpoint and return the iteration."""
args = get_args()
load_dir = getattr(args, load_arg)
if isinstance(model, torchDDP):
model = model.module
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# If no tracker file, return iteration zero.
if not os.path.isfile(tracker_filename):
......@@ -164,7 +165,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(args.load, iteration, release)
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
......
......@@ -22,81 +22,14 @@ import numpy as np
import torch
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import mpu
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron import get_tokenizer, get_args
from megatron import print_rank_0
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup):
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
# Get start and end indices of train/valid/test into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
# easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
print_rank_0(' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index,
end_index - start_index))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
dataset = BertDataset(
name=name,
indexed_dataset=indexed_dataset,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
masked_lm_prob=masked_lm_prob,
max_seq_length=max_seq_length,
short_seq_prob=short_seq_prob,
seed=seed)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
from megatron import mpu
from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments
from megatron.data.dataset_utils import create_tokens_and_tokentypes
from megatron.data.dataset_utils import pad_and_convert_to_numpy
from megatron.data.dataset_utils import create_masked_lm_predictions
class BertDataset(Dataset):
......@@ -137,11 +70,8 @@ class BertDataset(Dataset):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_index, end_index, seq_length = self.samples_mapping[idx]
sample = []
for index in range(start_index, end_index):
sample.append(self.indexed_dataset[index])
start_idx, end_idx, seq_length = self.samples_mapping[idx]
sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
......@@ -154,55 +84,6 @@ class BertDataset(Dataset):
self.masked_lm_prob, np_rng)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' > indexed dataset stats:')
print_rank_0(' number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1))
print_rank_0(' number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
def get_samples_mapping_(indexed_dataset,
data_prefix,
num_epochs,
......@@ -286,3 +167,66 @@ def get_samples_mapping_(indexed_dataset,
samples_mapping.shape[0]))
return samples_mapping
def build_training_sample(sample,
target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng):
"""Biuld training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
"""
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
# Truncate to `target_seq_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
len(tokens_b), max_num_tokens, np_rng)
# Build tokens and tokentypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated)}
return train_sample
......@@ -18,8 +18,17 @@
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import time
import collections
import numpy as np
from megatron import get_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def compile_helper():
......@@ -35,68 +44,6 @@ def compile_helper():
sys.exit(1)
def build_training_sample(sample,
target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng):
"""Biuld training sample.
Arguments:
sample: A list of sentences in which each sentence is a list token ids.
target_seq_length: Desired sequence length.
max_seq_length: Maximum length of the sequence. All values are padded to
this length.
vocab_id_list: List of vocabulary ids. Used to pick a random id.
vocab_id_to_token_dict: A dictionary from vocab ids to text tokens.
cls_id: Start of example id.
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
masked_lm_prob: Probability to mask tokens.
np_rng: Random number generator. Note that this rng state should be
numpy and not python since python randint is inclusive for
the upper bound whereas the numpy one is exclusive.
"""
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
# Truncate to `target_seq_length`.
max_num_tokens = target_seq_length
truncated = truncate_segments(tokens_a, tokens_b, len(tokens_a),
len(tokens_b), max_num_tokens, np_rng)
# Build tokens and tokentypes.
tokens, tokentypes = create_tokens_and_tokentypes(tokens_a, tokens_b,
cls_id, sep_id)
# Masking.
max_predictions_per_seq = masked_lm_prob * max_num_tokens
(tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
# Padding.
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
train_sample = {
'text': tokens_np,
'types': tokentypes_np,
'labels': labels_np,
'is_random': int(is_next_random),
'loss_mask': loss_mask_np,
'padding_mask': padding_mask_np,
'truncated': int(truncated)}
return train_sample
def get_a_and_b_segments(sample, np_rng):
"""Divide sample into a and b segments."""
......@@ -370,6 +317,7 @@ def create_masked_lm_predictions(tokens,
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
......@@ -404,3 +352,152 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
title_dataset = get_indexed_dataset_(args.titles_data_path,
data_impl,
skip_warmup)
# Get start and end indices of train/valid/test into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
# easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
print_rank_0(' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index,
end_index - start_index))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
from megatron.data.bert_dataset import BertDataset
from megatron.data.ict_dataset import ICTDataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
short_seq_prob=short_seq_prob,
seed=seed
)
if dataset_type == DSET_TYPE_ICT:
args = get_args()
dataset = ICTDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
query_in_block_prob=args.query_in_block_prob,
use_one_sent_docs=args.ict_one_sent,
**kwargs
)
else:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' > indexed dataset stats:')
print_rank_0(' number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1))
print_rank_0(' number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
......@@ -21,9 +21,8 @@ import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
from megatron.data.bert_dataset import get_train_valid_test_split_
from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......
......@@ -393,8 +393,251 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
}
}
template<typename DocIdx>
py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const py::array_t<int32_t>& sizes_,
const py::array_t<int32_t>& titles_sizes_,
const int32_t num_epochs,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const int32_t seed,
const bool verbose,
const bool use_one_sent_blocks) {
/* Build a mapping of (start-index, end-index, doc-index, block-id) where
start and end index are the indices of the sentences in the block,
doc-index is the index of the document the block comes from (used to
fetch its title), and block-id is a unique id for the block.
*/
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(seed > 0);
// Remove bound checks.
auto docs = docs_.unchecked<1>();
auto sizes = sizes_.unchecked<1>();
auto titles_sizes = titles_sizes_.unchecked<1>();
if (verbose) {
const auto sent_start_index = docs[0];
const auto sent_end_index = docs[docs_.shape(0) - 1];
const auto num_sentences = sent_end_index - sent_start_index;
cout << " using:" << endl << std::flush;
cout << " number of documents: " << docs_.shape(0) - 1 <<
endl << std::flush;
cout << " sentences range: [" << sent_start_index <<
", " << sent_end_index << ")" << endl << std::flush;
cout << " total number of sentences: " << num_sentences <<
endl << std::flush;
cout << " number of epochs: " << num_epochs <<
endl << std::flush;
cout << " maximum number of samples: " << max_num_samples <<
endl << std::flush;
cout << " maximum sequence length: " << max_seq_length <<
endl << std::flush;
cout << " seed: " << seed << endl <<
std::flush;
}
// Mapping and its length (1D).
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks) {
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
for (int32_t iteration=0; iteration<2; ++iteration) {
// Set the flag on second iteration.
second = (iteration == 1);
// Current map index.
uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id
int32_t block_id = 0;
if (map_index >= max_num_samples) {
if (verbose && (!second)) {
cout << " reached " << max_num_samples << " samples after "
<< epoch << " epochs ..." << endl << std::flush;
}
break;
}
// For each document:
for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) {
// Document sentences are in [sent_index_first, sent_index_last)
const auto sent_index_first = docs[doc];
const auto sent_index_last = docs[doc + 1];
const auto target_seq_len = max_seq_length - titles_sizes[doc];
// At the beginning of the document, the previous index is the
// start index.
auto prev_start_index = sent_index_first;
// Remaining sentences in the document.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent >= min_num_sent) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
auto num_sent = int32_t{0};
// Loop through sentences.
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
// Add the size and number of sentences.
seq_len += sizes[sent_index];
++num_sent;
--num_remain_sent;
// If we have reached the target length.
// and there are an acceptable number of sentences left
// and if we have at least the minimum number of sentences.
// or if we have reached end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent >= min_num_sent) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Populate the map.
if (second) {
const auto map_index_0 = 4 * map_index;
// Each sample has 4 items: the starting sentence index, ending sentence index,
// the index of the document from which the block comes (used for fetching titles)
// and the unique id of the block (used for creating block indexes)
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[4*map_index];
num_samples = static_cast<int64_t>(map_index);
}
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
py::capsule free_when_done(maps, [](void *mem_) {
DocIdx *mem = reinterpret_cast<DocIdx*>(mem_);
delete[] mem;
});
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
}
py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
const py::array_t<int>& sizes_,
const py::array_t<int>& titles_sizes_,
const int num_epochs,
const uint64_t max_num_samples,
const int max_seq_length,
const int seed,
const bool verbose,
const bool use_one_sent_blocks) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
}
PYBIND11_MODULE(helpers, m) {
m.def("build_mapping", &build_mapping);
m.def("build_blocks_mapping", &build_blocks_mapping);
m.def("build_sample_idx", &build_sample_idx);
}
import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import get_block_samples_mapping
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
query_in_block_prob, short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
if self.use_titles:
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.use_one_sent_docs
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the context query_in_block_prob fraction of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
sample = {
'query_tokens': query_tokens,
'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens,
'block_pad_mask': block_pad_mask,
'block_data': block_data,
}
return sample
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def concat_and_pad_tokens(self, tokens, title=None):
"""Concat with special tokens and pad sequence to self.max_seq_length"""
tokens = list(tokens)
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
title = list(title)
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return np.array(tokens), np.array(pad_mask)
import os
import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
def join_str_list(str_list):
"""Join a list of strings, handling spaces appropriately"""
result = ""
for s in str_list:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account."""
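# Each row of the returned mapping is (start_idx, end_idx, doc_idx, block_idx),
# matching what ICTDataset.__getitem__ unpacks from it.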
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert block_dataset.doc_idx.dtype == np.int64
assert block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
samples_mapping = helpers.build_blocks_mapping(
block_dataset.doc_idx,
block_dataset.sizes,
title_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length-3,  # account for the added [CLS] and two [SEP] tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
......@@ -15,5 +15,6 @@
from .distributed import *
from .bert_model import BertModel
from megatron.model.realm_model import ICTBertModel
from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization
......@@ -74,7 +74,7 @@ class BertLMHead(MegatronModule):
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: wether output logits being distributed or not.
parallel_output: whether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
......@@ -83,7 +83,7 @@ class BertLMHead(MegatronModule):
super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
......@@ -131,10 +131,8 @@ class BertModel(MegatronModule):
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
args.hidden_size, init_method, args.layernorm_epsilon,
parallel_output)
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
self._lm_head_key = 'lm_head'
if self.add_binary_head:
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
......@@ -188,10 +186,10 @@ class BertModel(MegatronModule):
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
destination, prefix, keep_vars)
if self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
......@@ -202,8 +200,8 @@ class BertModel(MegatronModule):
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
self.lm_head.load_state_dict(
state_dict[self._lm_head_key], strict=strict)
if self.add_binary_head:
self.binary_head.load_state_dict(state_dict[self._binary_head_key],
strict=strict)
self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict)
......@@ -17,16 +17,13 @@
import torch
from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class Classification(MegatronModule):
......
......@@ -17,16 +17,13 @@
import torch
from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class MultipleChoice(MegatronModule):
......
import os
import torch
from megatron import get_args
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from megatron.module import MegatronModule
from megatron import mpu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
class ICTBertModel(MegatronModule):
"""Bert-based module for Inverse Cloze task."""
def __init__(self,
ict_head_size,
num_tokentypes=1,
parallel_output=True,
only_query_model=False,
only_block_model=False):
super(ICTBertModel, self).__init__()
bert_kwargs = dict(
ict_head_size=ict_head_size,
num_tokentypes=num_tokentypes,
parallel_output=parallel_output
)
assert not (only_block_model and only_query_model)
self.use_block_model = not only_query_model
self.use_query_model = not only_block_model
if self.use_query_model:
# this model embeds (pseudo-)queries - Embed_input in the paper
self.query_model = IREncoderBertModel(**bert_kwargs)
self._query_key = 'question_model'
if self.use_block_model:
# this model embeds evidence blocks - Embed_doc in the paper
self.block_model = IREncoderBertModel(**bert_kwargs)
self._block_key = 'context_model'
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
"""Run a forward pass for each of the models and return the respective embeddings."""
query_logits = self.embed_query(query_tokens, query_attention_mask)
block_logits = self.embed_block(block_tokens, block_attention_mask)
return query_logits, block_logits
def embed_query(self, query_tokens, query_attention_mask):
"""Embed a batch of tokens using the query model"""
if self.use_query_model:
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
return query_ict_logits
else:
raise ValueError("Cannot embed query without query model.")
def embed_block(self, block_tokens, block_attention_mask):
"""Embed a batch of tokens using the block model"""
if self.use_block_model:
block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
return block_ict_logits
else:
raise ValueError("Cannot embed block without block model.")
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""Save dict with state dicts of each of the models."""
state_dict_ = {}
if self.use_query_model:
state_dict_[self._query_key] \
= self.query_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.use_block_model:
state_dict_[self._block_key] \
= self.block_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
if self.use_query_model:
print("Loading ICT query model", flush=True)
self.query_model.load_state_dict(
state_dict[self._query_key], strict=strict)
if self.use_block_model:
print("Loading ICT block model", flush=True)
self.block_model.load_state_dict(
state_dict[self._block_key], strict=strict)
def init_state_dict_from_bert(self):
"""Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining"""
args = get_args()
tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
if not os.path.isfile(tracker_filename):
raise FileNotFoundError("Could not find BERT load for ICT")
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except BaseException:
raise ValueError("Could not load checkpoint")
# load the LM state dict into each model
model_dict = state_dict['model']['language_model']
self.query_model.language_model.load_state_dict(model_dict)
self.block_model.language_model.load_state_dict(model_dict)
# give each model the same ict_head to begin with as well
query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
class IREncoderBertModel(MegatronModule):
"""BERT-based encoder for queries or blocks used for learned information retrieval."""
def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
super(IREncoderBertModel, self).__init__()
args = get_args()
self.ict_head_size = ict_head_size
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
init_method=init_method,
scaled_init_method=scaled_init_method)
self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
self._ict_head_key = 'ict_head'
def forward(self, input_ids, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
lm_output, pooled_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# Output.
ict_logits = self.ict_head(pooled_output)
return ict_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._ict_head_key] \
= self.ict_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.ict_head.load_state_dict(
state_dict[self._ict_head_key], strict=strict)
......@@ -155,6 +155,20 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode_token_ids(self, token_ids):
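# Drop [PAD] and [CLS] tokens, then merge WordPiece continuations ("##...") back into words.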
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
@property
def cls(self):
return self.cls_id
......
......@@ -18,7 +18,6 @@
from datetime import datetime
import math
import sys
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
......@@ -36,6 +35,7 @@ from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
......@@ -70,6 +70,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
args = get_args()
timers = get_timers()
......@@ -92,10 +93,9 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
iteration = 0
if args.do_train and args.train_iters > 0:
if args.do_train:
iteration, _ = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
iteration, _ = train(forward_step_func,
model, optimizer, lr_scheduler,
train_data_iterator, valid_data_iterator)
if args.do_valid:
prefix = 'the end of training for val data'
......@@ -218,6 +218,15 @@ def setup_model_and_optimizer(model_provider_func):
else:
args.iteration = 0
# get model without FP16 and/or TorchDDP wrappers
unwrapped_model = model
while hasattr(unwrapped_model, 'module'):
unwrapped_model = unwrapped_model.module
if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
print("Initializing ICT from pretrained BERT model", flush=True)
unwrapped_model.init_state_dict_from_bert()
return model, optimizer, lr_scheduler
......
......@@ -20,9 +20,9 @@ import sys
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler
from megatron.fp16 import FP16_Optimizer
......@@ -173,3 +173,5 @@ def get_ltor_masks_and_position_ids(data,
attention_mask = (attention_mask < 0.5)
return attention_mask, loss_mask, position_ids
......@@ -19,10 +19,10 @@ import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses
......
......@@ -18,10 +18,10 @@
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model
from megatron.training import pretrain
......