Commit 6afaffa5 authored by Jared Casper

Merge branch 'bert_not_remove_single_sent_option' into 'main'

Adding option to remove the binary head for BERT

See merge request ADLR/megatron-lm!211
parents 1aa2e08a dcff1acd
......@@ -262,6 +262,9 @@ def _add_network_size_args(parser):
'reasons.')
group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with Torch ONNX exporter')
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
return parser
......
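For context, action='store_false' together with dest='bert_binary_head' means the stored value defaults to True and flips to False only when the flag is passed, so existing training scripts keep the binary head unless they opt out. A minimal sketch of that argparse pattern, reusing the option names from the diff in a standalone script (not Megatron's full parser):

    import argparse

    parser = argparse.ArgumentParser()
    # The option *disables* the head, so the destination defaults to True
    # and becomes False only when --bert-no-binary-head is given.
    parser.add_argument('--bert-no-binary-head', action='store_false',
                        help='Disable BERT binary head.',
                        dest='bert_binary_head')

    print(parser.parse_args([]).bert_binary_head)                         # True
    print(parser.parse_args(['--bert-no-binary-head']).bert_binary_head)  # False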
......@@ -36,13 +36,14 @@ class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed):
max_seq_length, short_seq_prob, seed, binary_head):
# Params to store.
self.name = name
self.seed = seed
self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset.
self.indexed_dataset = indexed_dataset
......@@ -55,7 +56,8 @@ class BertDataset(Dataset):
self.max_seq_length,
short_seq_prob,
self.seed,
self.name)
self.name,
self.binary_head)
# Vocab stuff.
tokenizer = get_tokenizer()
......@@ -81,7 +83,8 @@ class BertDataset(Dataset):
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng)
self.masked_lm_prob, np_rng,
self.binary_head)
def get_samples_mapping_(indexed_dataset,
......@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length,
short_seq_prob,
seed,
name):
name,
binary_head):
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
......@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length - 3, # account for added tokens
short_seq_prob,
seed,
verbose)
verbose,
2 if binary_head else 1)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
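The new trailing argument to the C++ mapping builder is a minimum sentence count per mapped sample: two when the binary (sentence-order) head is on, since it needs an A and a B segment, and one otherwise, since plain masked-LM only needs a single sentence. A rough Python paraphrase of the eligibility check this enables on the C++ side (hypothetical helper, mirroring the condition changed further down in helpers.cpp):

    def document_is_eligible(num_remain_sent, contains_long_sentence, binary_head):
        # SOP/NSP needs at least an A and a B segment; plain MLM needs one sentence.
        min_num_sent = 2 if binary_head else 1
        return num_remain_sent >= min_num_sent and not contains_long_sentence

    assert document_is_eligible(1, False, binary_head=False)      # single-sentence doc now usable
    assert not document_is_eligible(1, False, binary_head=True)   # still skipped when the head is on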
......@@ -173,7 +178,7 @@ def build_training_sample(sample,
target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng):
masked_lm_prob, np_rng, binary_head):
"""Biuld training sample.
Arguments:
......@@ -193,12 +198,21 @@ def build_training_sample(sample,
the upper bound whereas the numpy one is exclusive.
"""
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
if binary_head:
# We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
......
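In short, with the binary head disabled the sample is no longer split into A and B segments: every sentence is concatenated into a single A segment, B stays empty, and the is_next_random label is fixed to False. A hedged, self-contained sketch of just that branching (a sample is a list of sentences, each a list of token ids; the binary-head branch is simplified and omits the random swap that get_a_and_b_segments performs):

    def make_segments(sample, binary_head):
        if binary_head:
            # Simplified stand-in for get_a_and_b_segments: split at the midpoint,
            # no random A/B swap.
            mid = max(1, len(sample) // 2)
            tokens_a = [t for sent in sample[:mid] for t in sent]
            tokens_b = [t for sent in sample[mid:] for t in sent]
            return tokens_a, tokens_b, False
        # No binary head: one segment holding every sentence, no B segment.
        tokens_a = [t for sent in sample for t in sent]
        return tokens_a, [], False

    print(make_segments([[1, 2], [3, 4]], binary_head=False))  # ([1, 2, 3, 4], [], False)
    print(make_segments([[1, 2], [3, 4]], binary_head=True))   # ([1, 2], [3, 4], False)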
......@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length."""
#print(len_a, len_b, max_num_tokens)
assert len_a > 0
assert len_b > 0
if len_a + len_b <= max_num_tokens:
return False
while len_a + len_b > max_num_tokens:
......@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
for token in tokens_b:
tokens.append(token)
tokentypes.append(1)
# [SEP].
tokens.append(sep_id)
tokentypes.append(1)
if tokens_b:
# [SEP].
tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes
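The visible effect is on the emitted layout: with a non-empty B segment the sequence remains [CLS] A [SEP] B [SEP] with token types 0/1, while an empty tokens_b now yields just [CLS] A [SEP] with all token types 0 instead of appending a dangling [SEP] of type 1. A small self-contained restatement (the token ids 101/102 for CLS/SEP are illustrative only):

    def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
        # Condensed version of the patched helper above.
        tokens = [cls_id] + list(tokens_a) + [sep_id]
        tokentypes = [0] * len(tokens)
        if tokens_b:
            tokens += list(tokens_b) + [sep_id]
            tokentypes += [1] * (len(tokens_b) + 1)
        return tokens, tokentypes

    print(create_tokens_and_tokentypes([7, 8], [9], 101, 102))  # ([101, 7, 8, 102, 9, 102], [0, 0, 0, 0, 1, 1])
    print(create_tokens_and_tokentypes([7, 8], [],  101, 102))  # ([101, 7, 8, 102], [0, 0, 0, 0])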
......@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
if len(data_prefix) == 1:
......@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
max_seq_length, masked_lm_prob,
short_seq_prob, seed,
skip_warmup,
binary_head,
dataset_type=dataset_type)
# Blending dataset.
# Parse the values.
......@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type)
seed, skip_warmup, binary_head, dataset_type=dataset_type)
if train_ds:
train_datasets.append(train_ds)
if valid_ds:
......@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES:
......@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
seed=seed
seed=seed,
binary_head=binary_head
)
if dataset_type == DSET_TYPE_ICT:
......
......@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length,
std::mt19937& rand32_gen) {
/* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1);
......@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const int32_t max_seq_length,
const double short_seq_prob,
const int32_t seed,
const bool verbose) {
const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
......@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// Consistency checks.
assert(num_epochs > 0);
assert(max_seq_length > 1);
assert(short_seq_prob > 0.0);
assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0);
assert(seed > 0);
......@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int.
const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) {
const auto sent_start_index = docs[0];
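The guard matters because the previous code computed round(1.0 / short_seq_prob) unconditionally and then took random_number % short_seq_ratio; with short_seq_prob == 0, which the relaxed assert now permits, that would divide by zero and then perform a modulo by zero. A hedged Python rendering of the combined new behavior (Megatron keeps this in C++; the function name here is made up):

    def target_sample_len(short_seq_prob, max_length, randint):
        # short_seq_prob == 0 now means "never emit a short sequence".
        if short_seq_prob == 0:
            return max_length
        short_seq_ratio = round(1.0 / short_seq_prob)
        if randint % short_seq_ratio == 0:
            return 2 + randint % (max_length - 1)
        return max_length

    print(target_sample_len(0.0, 512, randint=7))   # 512, no division or modulo by zero
    print(target_sample_len(0.1, 512, randint=10))  # 12, a short target length (10 % 10 == 0)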
......@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}
// If we have enough sentences and no overly long sentence.
if ((num_remain_sent > 1) && (!contains_long_sentence)) {
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
......@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// and if we have reached end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) &&
(num_sent > 1) ) || (num_remain_sent == 0)) {
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow.
if ((3 * map_index + 2) >
......@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
const int max_seq_length,
const double short_seq_prob,
const int seed,
const bool verbose) {
const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
......@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
}
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose);
short_seq_prob, seed, verbose,
min_num_sent);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length,
short_seq_prob, seed, verbose);
short_seq_prob, seed, verbose,
min_num_sent);
}
}
......
......@@ -23,7 +23,10 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
from megatron.model import (BertModel,
BertModelFirstStage,
BertModelIntermediateStage,
BertModelLastStage)
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
......@@ -34,23 +37,24 @@ def model_provider():
print_rank_0('building BERT model ...')
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
num_tokentypes=2)
num_tokentypes=num_tokentypes)
elif mpu.is_pipeline_last_stage():
model = BertModelLastStage(
num_tokentypes=2,
add_binary_head=True,
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True)
else:
model = BertModelIntermediateStage(
num_tokentypes=2)
num_tokentypes=num_tokentypes)
else:
model = BertModel(
num_tokentypes=2,
add_binary_head=True,
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True)
return model
......@@ -92,6 +96,9 @@ def forward_step(data_iterator, model, input_tensor):
= get_batch(data_iterator)
timers('batch-generator').stop()
if not args.bert_binary_head:
types = None
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
......@@ -109,22 +116,29 @@ def forward_step(data_iterator, model, input_tensor):
if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float()
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor
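On the last pipeline stage the loss computation now branches: when the binary head is off, the model returns sop_logits as None, so only the masked-LM loss is formed and reported; otherwise the LM and sentence-order losses are summed as before. A minimal sketch of that branching with dummy tensors, omitting Megatron's averaging across the data-parallel group (shapes and values are illustrative):

    import torch
    import torch.nn.functional as F

    def combine_losses(lm_loss, sop_logits, sentence_order):
        if sop_logits is None:  # binary head disabled: LM loss only
            return lm_loss, {'lm loss': lm_loss.item()}
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1), ignore_index=-1)
        loss = lm_loss + sop_loss
        return loss, {'lm loss': lm_loss.item(), 'sop loss': sop_loss.item()}

    lm_loss = torch.tensor(2.5)
    print(combine_losses(lm_loss, None, None))                    # LM-only path
    print(combine_losses(lm_loss,
                         torch.randn(4, 2),                       # fake SOP logits
                         torch.randint(0, 2, (4,))))              # fake order labels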
......@@ -143,7 +157,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...")
return train_ds, valid_ds, test_ds
......