Commit d31d7e49 authored by mohammad

Merge branch 'main' into log_grad_norm

parents 00764415 c0400306
...@@ -370,8 +370,7 @@ python tools/create_doc_index.py \ ...@@ -370,8 +370,7 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
<!-- Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported (not pipeline model parallelism).
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre> <pre>
TENSOR_MODEL_PARALLEL_SIZE=2 TENSOR_MODEL_PARALLEL_SIZE=2
...@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ ...@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--seq-length 512 \ --seq-length 512 \
--max-position-embeddings 512 \ --max-position-embeddings 512 \
--load $CHECKPOINT_PATH --load $CHECKPOINT_PATH
--save $CHECKPOINT_PATH/merged
</pre> </pre>
-->
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
## GPT Text Generation ## GPT Text Generation
......
...@@ -91,6 +91,20 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -91,6 +91,20 @@ def parse_args(extra_args_provider=None, defaults={},
'longer valid, use --tensor-model-parallel-size instead' 'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size del args.model_parallel_size
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
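This moved block applies a caller-supplied default only when the corresponding argument is still None, so explicit command-line values always win and are merely warned about. A minimal standalone sketch of that pattern (the names below are illustrative, not Megatron's internals):
<pre>
import argparse

def apply_defaults(args, defaults, rank=0):
    # Only fill in a default if the user did not pass the option explicitly;
    # argparse leaves options declared with default=None at None in that case.
    for key, value in defaults.items():
        current = getattr(args, key, None)
        if current is not None:
            if rank == 0:
                print(f'WARNING: overriding default {key}={value} '
                      f'with {key}={current}', flush=True)
        else:
            setattr(args, key, value)
    return args

parser = argparse.ArgumentParser()
parser.add_argument('--micro-batch-size', type=int, default=None)
args = apply_defaults(parser.parse_args([]), {'micro_batch_size': 1})
print(args.micro_batch_size)  # -> 1 when the flag was not given
</pre>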
# Batch size. # Batch size.
assert args.micro_batch_size is not None assert args.micro_batch_size is not None
assert args.micro_batch_size > 0 assert args.micro_batch_size > 0
...@@ -113,20 +127,6 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -113,20 +127,6 @@ def parse_args(extra_args_provider=None, defaults={},
args.consumed_train_samples = 0 args.consumed_train_samples = 0
args.consumed_valid_samples = 0 args.consumed_valid_samples = 0
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Iteration-based training. # Iteration-based training.
if args.train_iters: if args.train_iters:
# If we use iteration-based training, make sure the # If we use iteration-based training, make sure the
...@@ -262,6 +262,9 @@ def _add_network_size_args(parser): ...@@ -262,6 +262,9 @@ def _add_network_size_args(parser):
'reasons.') 'reasons.')
group.add_argument('--onnx-safe', type=bool, required=False, group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with Torch ONNX exporter') help='Use workarounds for known problems with Torch ONNX exporter')
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
return parser return parser
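The new `--bert-no-binary-head` switch relies on argparse's `store_false` action together with an explicit `dest`, so `args.bert_binary_head` defaults to True and passing the flag turns it off. A small self-contained illustration of that idiom:
<pre>
import argparse

parser = argparse.ArgumentParser()
# store_false makes the destination default to True; the flag flips it to False.
parser.add_argument('--bert-no-binary-head', action='store_false',
                    dest='bert_binary_head',
                    help='Disable BERT binary head.')

print(parser.parse_args([]).bert_binary_head)                         # True
print(parser.parse_args(['--bert-no-binary-head']).bert_binary_head)  # False
</pre>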
...@@ -432,9 +435,9 @@ def _add_checkpointing_args(parser): ...@@ -432,9 +435,9 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.') help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None, group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.') help='Directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true', group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.') help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.') help='Do not load rng state when loading checkpoint.')
group.add_argument('--finetune', action='store_true', group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer ' help='Load model for finetuning. Do not load optimizer '
...@@ -503,7 +506,7 @@ def _add_distributed_args(parser): ...@@ -503,7 +506,7 @@ def _add_distributed_args(parser):
' and returns function to complete it instead.' ' and returns function to complete it instead.'
'Also turns on --use-cpu-initialization flag.' 'Also turns on --use-cpu-initialization flag.'
'This is for external DDP manager.' ) 'This is for external DDP manager.' )
group.add_argument('--use-cpu-initialization', action='store_true', group.add_argument('--use-cpu-initialization', action='store_true', default=None,
help='If set, affine parallel weights initialization uses CPU' ) help='If set, affine parallel weights initialization uses CPU' )
return parser return parser
......
...@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None ...@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None
def set_checkpoint_version(value): def set_checkpoint_version(value):
global _CHECKPOINT_VERSION global _CHECKPOINT_VERSION
assert _CHECKPOINT_VERSION is None, \ if _CHECKPOINT_VERSION is not None:
"checkpoint version already set" assert _CHECKPOINT_VERSION == value, \
"checkpoint versions do not match"
_CHECKPOINT_VERSION = value _CHECKPOINT_VERSION = value
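With this change `set_checkpoint_version` tolerates repeated calls as long as they agree, instead of asserting that it is called at most once; that is what lets a single process load several partitions in a row, as the updated merge script later in this commit does. In isolation the new behaviour is simply:
<pre>
_CHECKPOINT_VERSION = None

def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        # A second call is only allowed if it agrees with the first one.
        assert _CHECKPOINT_VERSION == value, "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value
</pre>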
def get_checkpoint_version(): def get_checkpoint_version():
...@@ -111,11 +112,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -111,11 +112,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if isinstance(model, torchDDP): if isinstance(model, torchDDP):
model = model.module model = model.module
if torch.distributed.get_rank() == 0: print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
print('saving checkpoint at iteration {:7d} to {}'.format( iteration, args.save))
iteration, args.save), flush=True)
if mpu.get_data_parallel_rank() == 0: if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
...@@ -146,16 +146,20 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -146,16 +146,20 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.save(state_dict, checkpoint_name) torch.save(state_dict, checkpoint_name)
# Wait so everyone is done (necessary) # Wait so everyone is done (necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier() torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' successfully saved checkpoint at iteration {:7d} to {}'.format( print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True) iteration, args.save))
# And update the latest iteration # And update the latest iteration
if torch.distributed.get_rank() == 0: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save) tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f: with open(tracker_filename, 'w') as f:
f.write(str(iteration)) f.write(str(iteration))
# Wait so everyone is done (not necessary) # Wait so everyone is done (not necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier() torch.distributed.barrier()
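The added `torch.distributed.is_initialized()` checks let the same save path run both inside a distributed job and from single-process tools such as the merge script further down in this commit. The guard pattern, pulled out into two small helpers (the helper names are hypothetical):
<pre>
import torch

def barrier_if_distributed():
    # A no-op when torch.distributed was never initialized (e.g. in offline
    # checkpoint-manipulation tools); a synchronization point otherwise.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

def is_rank_zero():
    # Treat a non-distributed process as rank 0 so it still writes the
    # tracker file and prints progress messages.
    return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
</pre>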
...@@ -197,9 +201,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -197,9 +201,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
# Checkpoint. # Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release) checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if torch.distributed.get_rank() == 0: print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
print(' loading checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
# Load the checkpoint. # Load the checkpoint.
try: try:
...@@ -284,10 +286,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -284,10 +286,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
'exiting ...'.format(checkpoint_name)) 'exiting ...'.format(checkpoint_name))
sys.exit() sys.exit()
# Some utilities want to load a checkpoint without distributed being initialized
if torch.distributed.is_initialized():
torch.distributed.barrier() torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(' successfully loaded checkpoint from {} at iteration {}'.format( print_rank_0(f' successfully loaded checkpoint from {args.load} '
args.load, iteration), flush=True) f'at iteration {iteration}')
return iteration return iteration
......
...@@ -36,13 +36,14 @@ class BertDataset(Dataset): ...@@ -36,13 +36,14 @@ class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, data_prefix, def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob, num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed): max_seq_length, short_seq_prob, seed, binary_head):
# Params to store. # Params to store.
self.name = name self.name = name
self.seed = seed self.seed = seed
self.masked_lm_prob = masked_lm_prob self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset. # Dataset.
self.indexed_dataset = indexed_dataset self.indexed_dataset = indexed_dataset
...@@ -55,7 +56,8 @@ class BertDataset(Dataset): ...@@ -55,7 +56,8 @@ class BertDataset(Dataset):
self.max_seq_length, self.max_seq_length,
short_seq_prob, short_seq_prob,
self.seed, self.seed,
self.name) self.name,
self.binary_head)
# Vocab stuff. # Vocab stuff.
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
...@@ -81,7 +83,8 @@ class BertDataset(Dataset): ...@@ -81,7 +83,8 @@ class BertDataset(Dataset):
self.vocab_id_to_token_dict, self.vocab_id_to_token_dict,
self.cls_id, self.sep_id, self.cls_id, self.sep_id,
self.mask_id, self.pad_id, self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng) self.masked_lm_prob, np_rng,
self.binary_head)
def get_samples_mapping_(indexed_dataset, def get_samples_mapping_(indexed_dataset,
...@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset, ...@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length, max_seq_length,
short_seq_prob, short_seq_prob,
seed, seed,
name): name,
binary_head):
if not num_epochs: if not num_epochs:
if not max_num_samples: if not max_num_samples:
raise ValueError("Need to specify either max_num_samples " raise ValueError("Need to specify either max_num_samples "
...@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset, ...@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length - 3, # account for added tokens max_seq_length - 3, # account for added tokens
short_seq_prob, short_seq_prob,
seed, seed,
verbose) verbose,
2 if binary_head else 1)
print_rank_0(' > done building samples index mapping') print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True) np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format( print_rank_0(' > saved the index mapping in {}'.format(
...@@ -173,7 +178,7 @@ def build_training_sample(sample, ...@@ -173,7 +178,7 @@ def build_training_sample(sample,
target_seq_length, max_seq_length, target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict, vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id, cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng): masked_lm_prob, np_rng, binary_head):
"""Biuld training sample. """Biuld training sample.
Arguments: Arguments:
...@@ -193,12 +198,21 @@ def build_training_sample(sample, ...@@ -193,12 +198,21 @@ def build_training_sample(sample,
the upper bound whereas the numpy one is exclusive. the upper bound whereas the numpy one is exclusive.
""" """
if binary_head:
# We assume that we have at least two sentences in the sample # We assume that we have at least two sentences in the sample
assert len(sample) > 1 assert len(sample) > 1
assert target_seq_length <= max_seq_length assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B). # Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`. # Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length max_num_tokens = target_seq_length
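With the binary (sentence-order) head disabled, the sample is no longer split into an A/B pair: all sentences are concatenated into one segment and `is_next_random` is pinned to False. A condensed sketch of that branch, with the real `get_a_and_b_segments` replaced by an illustrative pivot-and-maybe-swap stub:
<pre>
import random

def split_segments(sample, binary_head, rng=random):
    # sample: a list of sentences, each sentence a list of token ids.
    if binary_head:
        # Stand-in for get_a_and_b_segments(): split the document in two and
        # occasionally swap the halves to create a "random next" example.
        pivot = max(1, len(sample) // 2)
        tokens_a = [t for s in sample[:pivot] for t in s]
        tokens_b = [t for s in sample[pivot:] for t in s]
        is_next_random = rng.random() < 0.5
        if is_next_random:
            tokens_a, tokens_b = tokens_b, tokens_a
    else:
        # No sentence-order prediction: one flat segment, no B side.
        tokens_a = [t for s in sample for t in s]
        tokens_b = []
        is_next_random = False
    return tokens_a, tokens_b, is_next_random
</pre>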
......
...@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): ...@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length.""" """Truncates a pair of sequences to a maximum sequence length."""
#print(len_a, len_b, max_num_tokens) #print(len_a, len_b, max_num_tokens)
assert len_a > 0 assert len_a > 0
assert len_b > 0
if len_a + len_b <= max_num_tokens: if len_a + len_b <= max_num_tokens:
return False return False
while len_a + len_b > max_num_tokens: while len_a + len_b > max_num_tokens:
...@@ -150,6 +149,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): ...@@ -150,6 +149,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
for token in tokens_b: for token in tokens_b:
tokens.append(token) tokens.append(token)
tokentypes.append(1) tokentypes.append(1)
if tokens_b:
# [SEP]. # [SEP].
tokens.append(sep_id) tokens.append(sep_id)
tokentypes.append(1) tokentypes.append(1)
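With an empty B segment, the trailing `[SEP]` and its token type of 1 are now skipped as well. A compact restatement of the resulting layout (condensed from the function above, not the verbatim source):
<pre>
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    # [CLS] A... [SEP]   plus   B... [SEP]   only when a B segment exists.
    tokens = [cls_id] + list(tokens_a) + [sep_id]
    tokentypes = [0] * len(tokens)
    if tokens_b:
        tokens += list(tokens_b) + [sep_id]
        tokentypes += [1] * (len(tokens_b) + 1)
    return tokens, tokentypes

# create_tokens_and_tokentypes([5, 6], [], cls_id=1, sep_id=2)
# -> ([1, 5, 6, 2], [0, 0, 0, 0])
</pre>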
...@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'): dataset_type='standard_bert'):
if len(data_prefix) == 1: if len(data_prefix) == 1:
...@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, short_seq_prob, seed,
skip_warmup, skip_warmup,
binary_head,
dataset_type=dataset_type) dataset_type=dataset_type)
# Blending dataset. # Blending dataset.
# Parse the values. # Parse the values.
...@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string, prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i], datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob, max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type) seed, skip_warmup, binary_head, dataset_type=dataset_type)
if train_ds: if train_ds:
train_datasets.append(train_ds) train_datasets.append(train_ds)
if valid_ds: if valid_ds:
...@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'): dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES: if dataset_type not in DSET_TYPES:
...@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
num_epochs=None, num_epochs=None,
max_num_samples=train_valid_test_num_samples[index], max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length, max_seq_length=max_seq_length,
seed=seed seed=seed,
binary_head=binary_head
) )
if dataset_type == DSET_TYPE_ICT: if dataset_type == DSET_TYPE_ICT:
......
...@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio, ...@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length, const int32_t max_length,
std::mt19937& rand32_gen) { std::mt19937& rand32_gen) {
/* Training sample length. */ /* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen(); const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) { if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1); return 2 + random_number % (max_length - 1);
...@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const int32_t max_seq_length, const int32_t max_seq_length,
const double short_seq_prob, const double short_seq_prob,
const int32_t seed, const int32_t seed,
const bool verbose) { const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where /* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length. and sequence-length is the target sequence length.
...@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// Consistency checks. // Consistency checks.
assert(num_epochs > 0); assert(num_epochs > 0);
assert(max_seq_length > 1); assert(max_seq_length > 1);
assert(short_seq_prob > 0.0); assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0); assert(short_seq_prob <= 1.0);
assert(seed > 0); assert(seed > 0);
...@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
auto sizes = sizes_.unchecked<1>(); auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int. // For efficiency, convert probability to ratio. Note: rand() generates int.
const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob)); int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) { if (verbose) {
const auto sent_start_index = docs[0]; const auto sent_start_index = docs[0];
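Together these C++ changes make `short_seq_prob == 0` mean "never shorten": the ratio stays 0 and `get_target_sample_len` returns `max_length` unconditionally. The same logic written out in Python for readability (an illustration of the behaviour, not the compiled helper; the non-short branch is assumed from the surrounding code):
<pre>
import random

def get_target_sample_len(short_seq_prob, max_length, rng=random):
    # Convert the probability to an integer ratio; 0 means "never shorten".
    short_seq_ratio = 0 if short_seq_prob <= 0 else int(round(1.0 / short_seq_prob))
    if short_seq_ratio == 0:
        return max_length
    r = rng.randrange(2 ** 32)
    if r % short_seq_ratio == 0:
        # Roughly one sample in `short_seq_ratio` gets a shorter target length.
        return 2 + r % (max_length - 1)
    return max_length
</pre>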
...@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
} }
// If we have more than two sentences. // If we have more than two sentences.
if ((num_remain_sent > 1) && (!contains_long_sentence)) { if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values. // Set values.
auto seq_len = int32_t{0}; auto seq_len = int32_t{0};
...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// and if we have reached end of the document. // and if we have reached end of the document.
if (((seq_len >= target_seq_len) && if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) && (num_remain_sent > 1) &&
(num_sent > 1) ) || (num_remain_sent == 0)) { (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow. // Check for overflow.
if ((3 * map_index + 2) > if ((3 * map_index + 2) >
...@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_, ...@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
const int max_seq_length, const int max_seq_length,
const double short_seq_prob, const double short_seq_prob,
const int seed, const int seed,
const bool verbose) { const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) { if (verbose) {
...@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_, ...@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
} }
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs, return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length, max_num_samples, max_seq_length,
short_seq_prob, seed, verbose); short_seq_prob, seed, verbose,
min_num_sent);
} else { } else {
if (verbose) { if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush; cout << " using uint32 for data mapping..." << endl << std::flush;
} }
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs, return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length, max_num_samples, max_seq_length,
short_seq_prob, seed, verbose); short_seq_prob, seed, verbose,
min_num_sent);
} }
} }
......
...@@ -29,10 +29,6 @@ from megatron.model.utils import init_method_normal ...@@ -29,10 +29,6 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def bert_extended_attention_mask(attention_mask): def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s] # [b, 1, s]
...@@ -78,9 +74,7 @@ class BertLMHead(MegatronModule): ...@@ -78,9 +74,7 @@ class BertLMHead(MegatronModule):
args = get_args() args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.tensor_model_parallel = True mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.bias.partition_dim = 0
self.bias.stride = 1
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
...@@ -145,7 +139,6 @@ class BertModelBase(MegatronModule): ...@@ -145,7 +139,6 @@ class BertModelBase(MegatronModule):
args.num_layers) args.num_layers)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head, add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
from megatron import get_args, print_rank_last from megatron import get_args, print_rank_last
from megatron import mpu from megatron import mpu
from megatron.model.enums import AttnMaskType from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
...@@ -38,7 +38,6 @@ class ClassificationBase(MegatronModule): ...@@ -38,7 +38,6 @@ class ClassificationBase(MegatronModule):
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
......
...@@ -28,11 +28,6 @@ from .utils import init_method_normal ...@@ -28,11 +28,6 @@ from .utils import init_method_normal
from .utils import scaled_init_method_normal from .utils import scaled_init_method_normal
def gpt_attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
def post_language_model_processing(lm_output, labels, logit_weights, def post_language_model_processing(lm_output, labels, logit_weights,
get_key_value, parallel_output, get_key_value, parallel_output,
forward_method_parallel_output, forward_method_parallel_output,
...@@ -73,7 +68,6 @@ class GPTModelBase(MegatronModule): ...@@ -73,7 +68,6 @@ class GPTModelBase(MegatronModule):
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=gpt_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal, encoder_attn_mask_type=AttnMaskType.causal,
......
...@@ -43,7 +43,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -43,7 +43,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
return mpu.gather_from_tensor_model_parallel_region(logits_parallel) return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler, def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None, encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False, scaled_init_method=None, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal): decoder_attn_mask_type=AttnMaskType.causal):
...@@ -58,8 +58,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler, ...@@ -58,8 +58,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
args.num_layers) args.num_layers)
# Language model. # Language model.
args = [attention_mask_func, init_method, args = [init_method, scaled_init_method, encoder_attn_mask_type]
scaled_init_method, encoder_attn_mask_type]
kwargs = {} kwargs = {}
cls = None cls = None
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage(): if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
...@@ -269,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -269,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule):
Arguments: Arguments:
transformer_hparams: transformer hyperparameters transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
vocab_size: vocabulary size vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This max_sequence_length: maximum size of sequence. This
is used for positional embedding is used for positional embedding
...@@ -284,7 +277,6 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -284,7 +277,6 @@ class TransformerLanguageModelBase(MegatronModule):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
...@@ -315,7 +307,6 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -315,7 +307,6 @@ class TransformerLanguageModelBase(MegatronModule):
# Transformer. # Transformer.
self.encoder = ParallelTransformer( self.encoder = ParallelTransformer(
attention_mask_func,
self.init_method, self.init_method,
output_layer_init_method, output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type) self_attn_mask_type=self.encoder_attn_mask_type)
...@@ -326,7 +317,6 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -326,7 +317,6 @@ class TransformerLanguageModelBase(MegatronModule):
assert args.pipeline_model_parallel_size == 1, \ assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder' 'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer( self.decoder = ParallelTransformer(
attention_mask_func,
self.init_method, self.init_method,
output_layer_init_method, output_layer_init_method,
layer_type=LayerType.decoder, layer_type=LayerType.decoder,
...@@ -479,7 +469,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase): ...@@ -479,7 +469,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
...@@ -488,7 +477,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase): ...@@ -488,7 +477,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
add_decoder=False, add_decoder=False,
add_pooler=False): add_pooler=False):
super(TransformerLanguageModel, self).__init__( super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
...@@ -523,13 +511,11 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase): ...@@ -523,13 +511,11 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
num_tokentypes=0): num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__( super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
...@@ -552,12 +538,10 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase): ...@@ -552,12 +538,10 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type): encoder_attn_mask_type):
super(TransformerLanguageModelIntermediateStage, self).__init__( super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type) encoder_attn_mask_type)
...@@ -578,13 +562,11 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase): ...@@ -578,13 +562,11 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
add_pooler=False): add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__( super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
......
...@@ -60,6 +60,13 @@ class MegatronModule(torch.nn.Module): ...@@ -60,6 +60,13 @@ class MegatronModule(torch.nn.Module):
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but ' raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false') 'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the # Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than # heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different # one stage, the initial embedding layer and the head are on different
...@@ -73,16 +80,16 @@ class MegatronModule(torch.nn.Module): ...@@ -73,16 +80,16 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight # the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages. # update is the same on both stages.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage(): assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings # set word_embeddings weights to 0 here, then copy first
# weights to 0 here, then copy first stage's weights using # stage's weights using all_reduce below.
# all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding( self.word_embeddings = mpu.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size, args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std)) init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
# values. # values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
from megatron import get_args, print_rank_last from megatron import get_args, print_rank_last
from megatron import mpu from megatron import mpu
from megatron.model.enums import AttnMaskType from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
...@@ -37,7 +37,6 @@ class MultipleChoiceBase(MegatronModule): ...@@ -37,7 +37,6 @@ class MultipleChoiceBase(MegatronModule):
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
......
...@@ -11,7 +11,7 @@ from megatron.model.utils import get_linear_layer ...@@ -11,7 +11,7 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False): def general_ict_model_provider(only_query_model=False, only_block_model=False):
...@@ -157,7 +157,6 @@ class IREncoderBertModel(MegatronModule): ...@@ -157,7 +157,6 @@ class IREncoderBertModel(MegatronModule):
args.num_layers) args.num_layers)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
......
...@@ -26,7 +26,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType ...@@ -26,7 +26,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
...@@ -47,12 +47,6 @@ torch._C._jit_override_can_fuse_on_gpu(True) ...@@ -47,12 +47,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
Transformer takes input of size [s, b, h] and returns a Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments: tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
""" """
class ParallelMLP(MegatronModule): class ParallelMLP(MegatronModule):
...@@ -115,7 +109,7 @@ class ParallelAttention(MegatronModule): ...@@ -115,7 +109,7 @@ class ParallelAttention(MegatronModule):
and returns output of the same size. and returns output of the same size.
""" """
def __init__(self, attention_mask_func, init_method, def __init__(self, init_method,
output_layer_init_method, layer_number, output_layer_init_method, layer_number,
attention_type=AttnType.self_attn, attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding): attn_mask_type=AttnMaskType.padding):
...@@ -123,7 +117,6 @@ class ParallelAttention(MegatronModule): ...@@ -123,7 +117,6 @@ class ParallelAttention(MegatronModule):
args = get_args() args = get_args()
self.fp16 = args.fp16 self.fp16 = args.fp16
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling: if self.apply_query_key_layer_scaling:
...@@ -174,7 +167,7 @@ class ParallelAttention(MegatronModule): ...@@ -174,7 +167,7 @@ class ParallelAttention(MegatronModule):
self.fp16, self.fp16,
self.attn_mask_type, self.attn_mask_type,
args.masked_softmax_fusion, args.masked_softmax_fusion,
self.attention_mask_func, attention_mask_func,
self.attention_softmax_in_fp32, self.attention_softmax_in_fp32,
coeff) coeff)
...@@ -440,9 +433,8 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -440,9 +433,8 @@ class ParallelTransformerLayer(MegatronModule):
output of the same size. output of the same size.
""" """
def __init__(self, attention_mask_func, init_method, def __init__(self, init_method, output_layer_init_method,
output_layer_init_method, layer_number, layer_number, layer_type=LayerType.encoder,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding): self_attn_mask_type=AttnMaskType.padding):
args = get_args() args = get_args()
...@@ -461,7 +453,6 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -461,7 +453,6 @@ class ParallelTransformerLayer(MegatronModule):
# Self attention. # Self attention.
self.self_attention = ParallelAttention( self.self_attention = ParallelAttention(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
layer_number, layer_number,
...@@ -477,7 +468,6 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -477,7 +468,6 @@ class ParallelTransformerLayer(MegatronModule):
if self.layer_type == LayerType.decoder: if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention( self.inter_attention = ParallelAttention(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
layer_number, layer_number,
...@@ -585,8 +575,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -585,8 +575,7 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule): class ParallelTransformer(MegatronModule):
"""Transformer class.""" """Transformer class."""
def __init__(self, attention_mask_func, def __init__(self, init_method, output_layer_init_method,
init_method, output_layer_init_method,
layer_type=LayerType.encoder, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding): self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
...@@ -606,8 +595,9 @@ class ParallelTransformer(MegatronModule): ...@@ -606,8 +595,9 @@ class ParallelTransformer(MegatronModule):
# Transformer layers. # Transformer layers.
def build_layer(layer_number): def build_layer(layer_number):
return ParallelTransformerLayer( return ParallelTransformerLayer(
attention_mask_func, init_method, init_method,
output_layer_init_method, layer_number, output_layer_init_method,
layer_number,
layer_type=layer_type, layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type) self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
......
...@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers): ...@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
return init_ return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
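`attention_mask_func` now lives in `megatron/model/utils.py` and is imported where needed instead of being threaded through every model constructor. Its effect is a single in-place masked fill; a tiny self-contained check:
<pre>
import torch

def attention_mask_func(attention_scores, attention_mask):
    # Positions where the mask is True are pushed to a large negative value,
    # so the subsequent softmax gives them (almost) zero probability.
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores

scores = torch.zeros(1, 1, 2, 2)                       # [b, np, s, s]
mask = torch.tensor([[[[False, True], [False, False]]]])
print(attention_mask_func(scores, mask)[0, 0, 0])      # tensor([0., -10000.])
</pre>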
def get_linear_layer(rows, columns, init_method): def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization.""" """Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns) layer = torch.nn.Linear(rows, columns)
......
...@@ -44,7 +44,8 @@ from .initialize import model_parallel_is_initialized ...@@ -44,7 +44,8 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear from .layers import ColumnParallelLinear
from .layers import RowParallelLinear from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding from .layers import VocabParallelEmbedding
from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes, from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes) copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
......
...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim) dim=partition_dim)
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
...@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=args.params_dtype))
self.bias.tensor_model_parallel = True set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
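The manual attribute assignments on `self.bias` are folded into one helper call. Judging from how it is invoked here and from the `partition_dim`/`partition_stride` reads in the merge script below, a plausible sketch of what such a helper records (an inference from this diff, not the verbatim Megatron implementation):
<pre>
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Record how a parameter is partitioned across tensor-model-parallel ranks
    # so later code (weight init, checkpoint merging) can reason about it.
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride
</pre>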
......
...@@ -23,7 +23,10 @@ from megatron import print_rank_0 ...@@ -23,7 +23,10 @@ from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage from megatron.model import (BertModel,
BertModelFirstStage,
BertModelIntermediateStage,
BertModelLastStage)
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
...@@ -34,23 +37,24 @@ def model_provider(): ...@@ -34,23 +37,24 @@ def model_provider():
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
args = get_args() args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline. # Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
model = BertModelFirstStage( model = BertModelFirstStage(
num_tokentypes=2) num_tokentypes=num_tokentypes)
elif mpu.is_pipeline_last_stage(): elif mpu.is_pipeline_last_stage():
model = BertModelLastStage( model = BertModelLastStage(
num_tokentypes=2, num_tokentypes=num_tokentypes,
add_binary_head=True, add_binary_head=args.bert_binary_head,
parallel_output=True) parallel_output=True)
else: else:
model = BertModelIntermediateStage( model = BertModelIntermediateStage(
num_tokentypes=2) num_tokentypes=num_tokentypes)
else: else:
model = BertModel( model = BertModel(
num_tokentypes=2, num_tokentypes=num_tokentypes,
add_binary_head=True, add_binary_head=args.bert_binary_head,
parallel_output=True) parallel_output=True)
return model return model
...@@ -92,6 +96,9 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -92,6 +96,9 @@ def forward_step(data_iterator, model, input_tensor):
= get_batch(data_iterator) = get_batch(data_iterator)
timers('batch-generator').stop() timers('batch-generator').stop()
if not args.bert_binary_head:
types = None
# Forward pass through the model. # Forward pass through the model.
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
assert input_tensor is None assert input_tensor is None
...@@ -110,21 +117,28 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -110,21 +117,28 @@ def forward_step(data_iterator, model, input_tensor):
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor lm_loss_, sop_logits = output_tensor
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
lm_loss_ = lm_loss_.float() lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float() loss_mask = loss_mask.float()
lm_loss = torch.sum( lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss]) else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor return output_tensor
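When the binary head is off, `sop_logits` comes back as None and only the masked-LM loss is computed and reported. A condensed sketch of the branching above (the averaging of losses across the data-parallel group is omitted here for brevity):
<pre>
import torch
import torch.nn.functional as F

def bert_losses(lm_loss_, loss_mask, sop_logits, sentence_order):
    # The masked-LM loss is always computed; the sentence-order loss only when
    # the binary head produced logits (i.e. --bert-no-binary-head was not set).
    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
    if sop_logits is not None:
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1).float()
        return lm_loss + sop_loss, {'lm loss': lm_loss, 'sop loss': sop_loss}
    return lm_loss, {'lm loss': lm_loss}
</pre>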
...@@ -143,7 +157,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -143,7 +157,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob=args.mask_prob, masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob, short_seq_prob=args.short_seq_prob,
seed=args.seed, seed=args.seed,
skip_warmup=(not args.mmap_warmup)) skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...") print_rank_0("> finished creating BERT datasets ...")
return train_ds, valid_ds, test_ds return train_ds, valid_ds, test_ds
......
...@@ -23,11 +23,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ...@@ -23,11 +23,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
import torch import torch
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args
def split_into_partitions(tensor, num_partitions, partition_dim, stride): def split_into_partitions(tensor, num_partitions, partition_dim, stride):
...@@ -185,8 +187,23 @@ def get_mp_merge_args(parser): ...@@ -185,8 +187,23 @@ def get_mp_merge_args(parser):
def main(): def main():
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
os.environ["WORLD_SIZE"] = f'{2**31}'
# Args # Args
args = _parse_args(extra_args_provider=get_mp_merge_args) set_global_variables(extra_args_provider=get_mp_merge_args,
args_defaults = {'use_cpu_initialization': True,
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'save_interval': 1})
args = get_args()
if args.pipeline_model_parallel_size > 1:
print("Checkpoints with pipeline model parallelism are not currently supported.")
exit()
model_type = args.model_type model_type = args.model_type
orig_tensor_model_parallel_size = args.tensor_model_parallel_size orig_tensor_model_parallel_size = args.tensor_model_parallel_size
args.tensor_model_parallel_size = 1 args.tensor_model_parallel_size = 1
...@@ -209,6 +226,8 @@ def main(): ...@@ -209,6 +226,8 @@ def main():
print('> building the full model ...') print('> building the full model ...')
mpu.initialize.set_tensor_model_parallel_world_size(1) mpu.initialize.set_tensor_model_parallel_world_size(1)
mpu.initialize.set_tensor_model_parallel_rank(0) mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(1)
mpu.initialize.set_pipeline_model_parallel_rank(0)
merged_model = get_model(model_type) merged_model = get_model(model_type)
# Build and load partitions. # Build and load partitions.
...@@ -220,13 +239,16 @@ def main(): ...@@ -220,13 +239,16 @@ def main():
for rank in range(args.tensor_model_parallel_size): for rank in range(args.tensor_model_parallel_size):
mpu.initialize.set_tensor_model_parallel_rank(rank) mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print('> loading {} ...'.format(checkpoint_name))
model_ = get_model(model_type) model_ = get_model(model_type)
sd = torch.load(checkpoint_name, map_location='cpu') print(f'> loading {checkpoint_name} ...')
model_.load_state_dict(sd['model']) load_checkpoint(model_, None, None)
print(f'> checkpoint version {get_checkpoint_version()}')
if get_checkpoint_version() < 2.0:
# Need to deal with the qkv matrix order of old versions
print("Checkpoints less than version 2.0 are not currently supported.")
exit()
partitions.append(model_) partitions.append(model_)
# Parameter generators so we can loop through them simultaneously. # Parameter generators so we can loop through them simultaneously.
merged_params_gen = merged_model.named_parameters() merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters() partitions_params_gen = [partition.named_parameters()
...@@ -254,29 +276,26 @@ def main(): ...@@ -254,29 +276,26 @@ def main():
merged_param.data.copy_(partitions_param[0].data) merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values # For parallel parameters, merge the values
else: else:
print(' parallel parameter merge with stride {} along ' dim = merged_param.partition_dim
'dimension {}'.format(merged_param.stride, stride = merged_param.partition_stride
merged_param.partition_dim)) print(f' parallel parameter merge with stride {stride} along '
f'dimension {dim}')
merge_partitions(merged_param, merge_partitions(merged_param,
partitions_param, partitions_param,
merged_param.partition_dim, dim,
merged_param.stride) stride)
except StopIteration: except StopIteration:
break break
# Save the model. # Save the model.
args.tensor_model_parallel_size = 1 args.tensor_model_parallel_size = 1
args.pipeline_model_parallel_size = 1
# And now one last time so proper arguments are set in saved checkpoint
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_rank(0) mpu.initialize.set_tensor_model_parallel_rank(0)
sd = {} print('> saving merged model')
sd['model'] = merged_model.state_dict_for_save_checkpoint() save_checkpoint(iteration, merged_model, None, None)
sd['iteration'] = iteration
merged_path = os.path.join(args.load, 'merged')
checkpoint_name = get_checkpoint_name(merged_path, iteration)
ensure_directory_exists(checkpoint_name)
print('> saving merged model to {}'.format(checkpoint_name))
torch.save(sd, checkpoint_name)
print('done :-)') print('done :-)')
......