Commit dcff1acd authored by Mohammad Shoeybi, committed by Jared Casper

Adding option to remove the binary head for BERT

parent 1aa2e08a
@@ -262,6 +262,9 @@ def _add_network_size_args(parser):
                        'reasons.')
     group.add_argument('--onnx-safe', type=bool, required=False,
                        help='Use workarounds for known problems with Torch ONNX exporter')
+    group.add_argument('--bert-no-binary-head', action='store_false',
+                       help='Disable BERT binary head.',
+                       dest='bert_binary_head')

     return parser
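
Note on the argparse pattern above: because the flag uses action='store_false' with dest='bert_binary_head', the attribute defaults to True and only flips to False when --bert-no-binary-head is passed. A minimal self-contained sketch:

import argparse

# The flag is spelled "--bert-no-binary-head", but it controls
# args.bert_binary_head: True by default, False when the flag is given.
parser = argparse.ArgumentParser()
parser.add_argument('--bert-no-binary-head', action='store_false',
                    help='Disable BERT binary head.',
                    dest='bert_binary_head')

print(parser.parse_args([]).bert_binary_head)                         # True
print(parser.parse_args(['--bert-no-binary-head']).bert_binary_head)  # False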
...
@@ -36,13 +36,14 @@ class BertDataset(Dataset):

     def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
-                 max_seq_length, short_seq_prob, seed):
+                 max_seq_length, short_seq_prob, seed, binary_head):

         # Params to store.
         self.name = name
         self.seed = seed
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
+        self.binary_head = binary_head

         # Dataset.
         self.indexed_dataset = indexed_dataset
@@ -55,7 +56,8 @@ class BertDataset(Dataset):
                                                  self.max_seq_length,
                                                  short_seq_prob,
                                                  self.seed,
-                                                 self.name)
+                                                 self.name,
+                                                 self.binary_head)

         # Vocab stuff.
         tokenizer = get_tokenizer()
@@ -81,7 +83,8 @@ class BertDataset(Dataset):
                                      self.vocab_id_to_token_dict,
                                      self.cls_id, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.masked_lm_prob, np_rng)
+                                     self.masked_lm_prob, np_rng,
+                                     self.binary_head)


 def get_samples_mapping_(indexed_dataset,
@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
                          max_seq_length,
                          short_seq_prob,
                          seed,
-                         name):
+                         name,
+                         binary_head):
     if not num_epochs:
         if not max_num_samples:
             raise ValueError("Need to specify either max_num_samples "
@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
             max_seq_length - 3, # account for added tokens
             short_seq_prob,
             seed,
-            verbose)
+            verbose,
+            2 if binary_head else 1)
         print_rank_0(' > done building sapmles index maping')
         np.save(indexmap_filename, samples_mapping, allow_pickle=True)
         print_rank_0(' > saved the index mapping in {}'.format(
@@ -173,7 +178,7 @@ def build_training_sample(sample,
                           target_seq_length, max_seq_length,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_id, sep_id, mask_id, pad_id,
-                          masked_lm_prob, np_rng):
+                          masked_lm_prob, np_rng, binary_head):
     """Biuld training sample.

     Arguments:
@@ -193,12 +198,21 @@ def build_training_sample(sample,
         the opper bound whereas the numpy one is exclusive.
     """

-    # We assume that we have at least two sentences in the sample
-    assert len(sample) > 1
+    if binary_head:
+        # We assume that we have at least two sentences in the sample
+        assert len(sample) > 1
     assert target_seq_length <= max_seq_length

     # Divide sample into two segments (A and B).
-    tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng)
+    if binary_head:
+        tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
+                                                                  np_rng)
+    else:
+        tokens_a = []
+        for j in range(len(sample)):
+            tokens_a.extend(sample[j])
+        tokens_b = []
+        is_next_random = False

     # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
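
The branch above is the core behavioral change: with the binary head, a sample must be split into segments A and B (so the head can predict whether their order was swapped); without it, every sentence is concatenated into a single segment A. A simplified, self-contained sketch of the two paths, using a stand-in splitter rather than the real get_a_and_b_segments (which draws the split point and the swap decision from np_rng):

import random

def split_a_and_b(sample, rng):
    # Stand-in for get_a_and_b_segments: pick a split point, and with
    # probability 0.5 swap the segments so the binary head gets a
    # "wrong order" negative example (is_next_random=True).
    n = len(sample)
    split = rng.randrange(1, n)
    a = [tok for sent in sample[:split] for tok in sent]
    b = [tok for sent in sample[split:] for tok in sent]
    if rng.random() < 0.5:
        return b, a, True
    return a, b, False

def make_segments(sample, binary_head, rng):
    if binary_head:
        # Two segments plus an order label for the binary head.
        return split_a_and_b(sample, rng)
    # No binary head: flatten the whole sample into a single segment.
    tokens_a = [tok for sent in sample for tok in sent]
    return tokens_a, [], False

rng = random.Random(0)
sample = [[1, 2, 3], [4, 5], [6, 7, 8]]   # three tokenized sentences
print(make_segments(sample, binary_head=True, rng=rng))
print(make_segments(sample, binary_head=False, rng=rng))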
...
@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
     """Truncates a pair of sequences to a maximum sequence length."""
     #print(len_a, len_b, max_num_tokens)
     assert len_a > 0
-    assert len_b > 0
     if len_a + len_b <= max_num_tokens:
         return False
     while len_a + len_b > max_num_tokens:
@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
     for token in tokens_b:
         tokens.append(token)
         tokentypes.append(1)
-    # [SEP].
-    tokens.append(sep_id)
-    tokentypes.append(1)
+    if tokens_b:
+        # [SEP].
+        tokens.append(sep_id)
+        tokentypes.append(1)

     return tokens, tokentypes
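
With that guard, an empty tokens_b produces a single-segment layout with no trailing [SEP] for segment B, which is also why the assert len_b > 0 in truncate_segments had to go. A simplified re-implementation (not the library function itself) showing both layouts, with token type 0 for segment A and 1 for segment B:

def tokens_and_types(tokens_a, tokens_b, cls_id, sep_id):
    # [CLS] A... [SEP] always; B... [SEP] (type 1) only if B is non-empty.
    tokens = [cls_id] + tokens_a + [sep_id]
    types = [0] * len(tokens)
    if tokens_b:
        tokens += tokens_b + [sep_id]
        types += [1] * (len(tokens_b) + 1)
    return tokens, types

CLS, SEP = 101, 102
print(tokens_and_types([7, 8], [9], CLS, SEP))
# -> ([101, 7, 8, 102, 9, 102], [0, 0, 0, 0, 1, 1])
print(tokens_and_types([7, 8], [], CLS, SEP))
# -> ([101, 7, 8, 102], [0, 0, 0, 0])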
@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                     train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup,
+                                    binary_head,
                                     dataset_type='standard_bert'):

     if len(data_prefix) == 1:
@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                             max_seq_length, masked_lm_prob,
                                             short_seq_prob, seed,
                                             skip_warmup,
+                                            binary_head,
                                             dataset_type=dataset_type)

     # Blending dataset.
     # Parse the values.
@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
-            seed, skip_warmup, dataset_type=dataset_type)
+            seed, skip_warmup, binary_head, dataset_type=dataset_type)
         if train_ds:
             train_datasets.append(train_ds)
         if valid_ds:
@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                      train_valid_test_num_samples,
                                      max_seq_length, masked_lm_prob,
                                      short_seq_prob, seed, skip_warmup,
+                                     binary_head,
                                      dataset_type='standard_bert'):

     if dataset_type not in DSET_TYPES:
@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                 num_epochs=None,
                 max_num_samples=train_valid_test_num_samples[index],
                 max_seq_length=max_seq_length,
-                seed=seed
+                seed=seed,
+                binary_head=binary_head
             )

             if dataset_type == DSET_TYPE_ICT:
...
@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
                                      const int32_t max_length,
                                      std::mt19937& rand32_gen) {
     /* Training sample length. */
+    if (short_seq_ratio == 0) {
+      return max_length;
+    }
     const auto random_number = rand32_gen();
     if ((random_number % short_seq_ratio) == 0) {
         return 2 + random_number % (max_length - 1);
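
The early return gives short_seq_ratio == 0 the meaning "never shorten", which is what a short_seq_prob of 0 now maps to further down. A Python rendering of the same sampling logic, under the assumption that rand32_gen() behaves like a uniform 32-bit draw:

import random

def target_sample_len(short_seq_ratio, max_length, rng):
    # Mirror of the C++ logic above: ratio 0 (short_seq_prob == 0) disables
    # short sequences; otherwise roughly 1 in short_seq_ratio samples gets
    # a random target length of at least 2.
    if short_seq_ratio == 0:
        return max_length
    r = rng.getrandbits(32)
    if r % short_seq_ratio == 0:
        return 2 + r % (max_length - 1)
    return max_length

rng = random.Random(1234)
lengths = [target_sample_len(10, 512, rng) for _ in range(10000)]
print(sum(l < 512 for l in lengths) / len(lengths))  # roughly 0.1
print(target_sample_len(0, 512, rng))                # always 512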
@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
                              const int32_t max_seq_length,
                              const double short_seq_prob,
                              const int32_t seed,
-                             const bool verbose) {
+                             const bool verbose,
+                             const int32_t min_num_sent) {
     /* Build a mapping of (start-index, end-index, sequence-length) where
        start and end index are the indices of the sentences in the sample
        and sequence-length is the target sequence length.
@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     // Consistency checks.
     assert(num_epochs > 0);
     assert(max_seq_length > 1);
-    assert(short_seq_prob > 0.0);
+    assert(short_seq_prob >= 0.0);
     assert(short_seq_prob <= 1.0);
     assert(seed > 0);
@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
     auto sizes = sizes_.unchecked<1>();

     // For efficiency, convert probability to ratio. Note: rand() generates int.
-    const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+    int32_t short_seq_ratio = 0;
+    if (short_seq_prob > 0) {
+      short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
+    }

     if (verbose) {
         const auto sent_start_index = docs[0];
} }
// If we have more than two sentences. // If we have more than two sentences.
if ((num_remain_sent > 1) && (!contains_long_sentence)) { if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values. // Set values.
auto seq_len = int32_t{0}; auto seq_len = int32_t{0};
...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// and if we have reached end of the document. // and if we have reached end of the document.
if (((seq_len >= target_seq_len) && if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) && (num_remain_sent > 1) &&
(num_sent > 1) ) || (num_remain_sent == 0)) { (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow. // Check for overflow.
if ((3 * map_index + 2) > if ((3 * map_index + 2) >
@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
                         const int max_seq_length,
                         const double short_seq_prob,
                         const int seed,
-                        const bool verbose) {
+                        const bool verbose,
+                        const int32_t min_num_sent) {

     if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
         if (verbose) {
@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
         }
         return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
                                             max_num_samples, max_seq_length,
-                                            short_seq_prob, seed, verbose);
+                                            short_seq_prob, seed, verbose,
+                                            min_num_sent);
     } else {
         if (verbose) {
             cout << " using uint32 for data mapping..." << endl << std::flush;
         }
         return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
                                             max_num_samples, max_seq_length,
-                                            short_seq_prob, seed, verbose);
+                                            short_seq_prob, seed, verbose,
+                                            min_num_sent);
     }
 }
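
Combined with the `2 if binary_head else 1` argument at the Python call site, min_num_sent means that single-sentence documents, previously skipped, now yield samples when the binary head is off. A toy sketch of just that filtering rule (the real builder additionally tracks epochs, target lengths, and shuffling):

def usable_documents(doc_sentence_counts, binary_head):
    # A document must contain at least min_num_sent sentences to yield a
    # sample: 2 when the binary head needs an A/B pair, else 1.
    min_num_sent = 2 if binary_head else 1
    return [i for i, n in enumerate(doc_sentence_counts) if n >= min_num_sent]

counts = [1, 3, 1, 2]          # sentences per document
print(usable_documents(counts, binary_head=True))   # [1, 3]
print(usable_documents(counts, binary_head=False))  # [0, 1, 2, 3]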
...
@@ -23,7 +23,10 @@ from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
+from megatron.model import (BertModel,
+                            BertModelFirstStage,
+                            BertModelIntermediateStage,
+                            BertModelLastStage)
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
@@ -34,23 +37,24 @@ def model_provider():
     print_rank_0('building BERT model ...')

     args = get_args()
+    num_tokentypes = 2 if args.bert_binary_head else 0
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = BertModelFirstStage(
-                num_tokentypes=2)
+                num_tokentypes=num_tokentypes)
         elif mpu.is_pipeline_last_stage():
             model = BertModelLastStage(
-                num_tokentypes=2,
-                add_binary_head=True,
+                num_tokentypes=num_tokentypes,
+                add_binary_head=args.bert_binary_head,
                 parallel_output=True)
         else:
             model = BertModelIntermediateStage(
-                num_tokentypes=2)
+                num_tokentypes=num_tokentypes)
     else:
         model = BertModel(
-            num_tokentypes=2,
-            add_binary_head=True,
+            num_tokentypes=num_tokentypes,
+            add_binary_head=args.bert_binary_head,
             parallel_output=True)

     return model
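
num_tokentypes drops to 0 because token-type (segment) embeddings only exist to tell segments A and B apart for the binary head; with a single segment there is nothing to distinguish, and forward_step below passes types=None to match. A toy sketch of one plausible gating scheme (Megatron's real Embedding module is more involved; this is illustrative only):

import torch
import torch.nn as nn

class ToyEmbedding(nn.Module):
    # Sketch: only allocate token-type embeddings when num_tokentypes > 0,
    # mirroring how num_tokentypes=0 removes the segment embedding entirely.
    def __init__(self, vocab_size, hidden, num_tokentypes):
        super().__init__()
        self.word = nn.Embedding(vocab_size, hidden)
        self.tokentype = (nn.Embedding(num_tokentypes, hidden)
                          if num_tokentypes > 0 else None)

    def forward(self, ids, types=None):
        out = self.word(ids)
        if types is not None:
            assert self.tokentype is not None
            out = out + self.tokentype(types)
        return out

emb = ToyEmbedding(vocab_size=100, hidden=8, num_tokentypes=0)
print(emb(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 8])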
@@ -92,6 +96,9 @@ def forward_step(data_iterator, model, input_tensor):
         = get_batch(data_iterator)
     timers('batch-generator').stop()

+    if not args.bert_binary_head:
+        types = None
+
     # Forward pass through the model.
     if mpu.is_pipeline_first_stage():
         assert input_tensor is None
@@ -109,22 +116,29 @@ def forward_step(data_iterator, model, input_tensor):
     if mpu.is_pipeline_last_stage():
         lm_loss_, sop_logits = output_tensor

-        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
-                                   sentence_order.view(-1),
-                                   ignore_index=-1)
-        sop_loss = sop_loss.float()
-
         lm_loss_ = lm_loss_.float()
         loss_mask = loss_mask.float()
         lm_loss = torch.sum(
             lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

-        loss = lm_loss + sop_loss
-
-        averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
-
-        return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
+        if sop_logits is not None:
+            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                                       sentence_order.view(-1),
+                                       ignore_index=-1)
+            sop_loss = sop_loss.float()
+            loss = lm_loss + sop_loss
+            averaged_losses = average_losses_across_data_parallel_group(
+                [lm_loss, sop_loss])
+            return loss, {'lm loss': averaged_losses[0],
+                          'sop loss': averaged_losses[1]}
+        else:
+            loss = lm_loss
+            averaged_losses = average_losses_across_data_parallel_group(
+                [lm_loss])
+            return loss, {'lm loss': averaged_losses[0]}

     return output_tensor
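
A toy run of the new branching logic with dummy tensors, with average_losses_across_data_parallel_group replaced by a pass-through since the real helper requires an initialized process group:

import torch
import torch.nn.functional as F

def losses(lm_loss_, loss_mask, sop_logits, sentence_order):
    # Masked-LM loss is always computed; the SOP loss only when the model
    # actually returned binary-head logits.
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
    if sop_logits is not None:
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        return lm_loss + sop_loss, {'lm loss': lm_loss, 'sop loss': sop_loss}
    return lm_loss, {'lm loss': lm_loss}

lm_loss_ = torch.rand(2, 4)                # per-token LM losses
loss_mask = torch.ones(2, 4)               # all tokens contribute
sop_logits = torch.randn(2, 2)             # binary-head logits
sentence_order = torch.tensor([0, 1])      # order labels

print(losses(lm_loss_, loss_mask, sop_logits, sentence_order)[1])  # both losses
print(losses(lm_loss_, loss_mask, None, sentence_order)[1])        # lm loss only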
@@ -143,7 +157,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         masked_lm_prob=args.mask_prob,
         short_seq_prob=args.short_seq_prob,
         seed=args.seed,
-        skip_warmup=(not args.mmap_warmup))
+        skip_warmup=(not args.mmap_warmup),
+        binary_head=args.bert_binary_head)

     print_rank_0("> finished creating BERT datasets ...")
     return train_ds, valid_ds, test_ds
...