You need to sign in or sign up before continuing.
Commit da8dccd6 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

Merge branch 'main' into vision_transformer

parents a7169297 c0400306
...@@ -370,8 +370,7 @@ python tools/create_doc_index.py \ ...@@ -370,8 +370,7 @@ python tools/create_doc_index.py \
We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning. We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning.
<!-- Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this. Currently only tensor model parallelism is supported (not pipeline model parallelism).
Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on a single GPU in downstream tasks. The following script accomplishes this.
<pre> <pre>
TENSOR_MODEL_PARALLEL_SIZE=2 TENSOR_MODEL_PARALLEL_SIZE=2
...@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ ...@@ -390,9 +389,10 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \
--seq-length 512 \ --seq-length 512 \
--max-position-embeddings 512 \ --max-position-embeddings 512 \
--load $CHECKPOINT_PATH --load $CHECKPOINT_PATH
--save $CHECKPOINT_PATH/merged
</pre> </pre>
-->
Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts. Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts.
## GPT Text Generation ## GPT Text Generation
......
...@@ -92,6 +92,20 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -92,6 +92,20 @@ def parse_args(extra_args_provider=None, defaults={},
'longer valid, use --tensor-model-parallel-size instead' 'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size del args.model_parallel_size
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Batch size. # Batch size.
assert args.micro_batch_size is not None assert args.micro_batch_size is not None
assert args.micro_batch_size > 0 assert args.micro_batch_size > 0
...@@ -114,20 +128,6 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -114,20 +128,6 @@ def parse_args(extra_args_provider=None, defaults={},
args.consumed_train_samples = 0 args.consumed_train_samples = 0
args.consumed_valid_samples = 0 args.consumed_valid_samples = 0
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Iteration-based training. # Iteration-based training.
if args.train_iters: if args.train_iters:
# If we use iteration-based training, make sure the # If we use iteration-based training, make sure the
...@@ -263,6 +263,9 @@ def _add_network_size_args(parser): ...@@ -263,6 +263,9 @@ def _add_network_size_args(parser):
'reasons.') 'reasons.')
group.add_argument('--onnx-safe', type=bool, required=False, group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with Torch ONNX exporter') help='Use workarounds for known problems with Torch ONNX exporter')
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
return parser return parser
...@@ -442,9 +445,9 @@ def _add_checkpointing_args(parser): ...@@ -442,9 +445,9 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.') help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None, group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.') help='Directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true', group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.') help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.') help='Do not load rng state when loading checkpoint.')
group.add_argument('--finetune', action='store_true', group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer ' help='Load model for finetuning. Do not load optimizer '
...@@ -513,7 +516,7 @@ def _add_distributed_args(parser): ...@@ -513,7 +516,7 @@ def _add_distributed_args(parser):
' and returns function to complete it instead.' ' and returns function to complete it instead.'
'Also turns on --use-cpu-initialization flag.' 'Also turns on --use-cpu-initialization flag.'
'This is for external DDP manager.' ) 'This is for external DDP manager.' )
group.add_argument('--use-cpu-initialization', action='store_true', group.add_argument('--use-cpu-initialization', action='store_true', default=None,
help='If set, affine parallel weights initialization uses CPU' ) help='If set, affine parallel weights initialization uses CPU' )
return parser return parser
......
...@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None ...@@ -31,8 +31,9 @@ _CHECKPOINT_VERSION = None
def set_checkpoint_version(value): def set_checkpoint_version(value):
global _CHECKPOINT_VERSION global _CHECKPOINT_VERSION
assert _CHECKPOINT_VERSION is None, \ if _CHECKPOINT_VERSION is not None:
"checkpoint version already set" assert _CHECKPOINT_VERSION == value, \
"checkpoint versions do not match"
_CHECKPOINT_VERSION = value _CHECKPOINT_VERSION = value
def get_checkpoint_version(): def get_checkpoint_version():
...@@ -112,11 +113,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -112,11 +113,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
if isinstance(model, torchDDP): if isinstance(model, torchDDP):
model = model.module model = model.module
if torch.distributed.get_rank() == 0: print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
print('saving checkpoint at iteration {:7d} to {}'.format( iteration, args.save))
iteration, args.save), flush=True)
if mpu.get_data_parallel_rank() == 0: if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
...@@ -147,17 +147,21 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -147,17 +147,21 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
torch.save(state_dict, checkpoint_name) torch.save(state_dict, checkpoint_name)
# Wait so everyone is done (necessary) # Wait so everyone is done (necessary)
torch.distributed.barrier() if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0: torch.distributed.barrier()
print(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True) print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# And update the latest iteration # And update the latest iteration
if torch.distributed.get_rank() == 0: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save) tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f: with open(tracker_filename, 'w') as f:
f.write(str(iteration)) f.write(str(iteration))
# Wait so everyone is done (not necessary) # Wait so everyone is done (not necessary)
torch.distributed.barrier() if torch.distributed.is_initialized():
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True): def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
...@@ -198,9 +202,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True ...@@ -198,9 +202,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
# Checkpoint. # Checkpoint.
checkpoint_name = get_checkpoint_name(load_dir, iteration, release) checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if torch.distributed.get_rank() == 0: print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')
print(' loading checkpoint from {} at iteration {}'.format(
args.load, iteration), flush=True)
# Load the checkpoint. # Load the checkpoint.
try: try:
...@@ -285,10 +287,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True ...@@ -285,10 +287,12 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
'exiting ...'.format(checkpoint_name)) 'exiting ...'.format(checkpoint_name))
sys.exit() sys.exit()
torch.distributed.barrier() # Some utilities want to load a checkpoint without distributed being initialized
if torch.distributed.get_rank() == 0: if torch.distributed.is_initialized():
print(' successfully loaded checkpoint from {} at iteration {}'.format( torch.distributed.barrier()
args.load, iteration), flush=True)
print_rank_0(f' successfully loaded checkpoint from {args.load} '
f'at iteration {iteration}')
return iteration return iteration
......
...@@ -36,13 +36,14 @@ class BertDataset(Dataset): ...@@ -36,13 +36,14 @@ class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, data_prefix, def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob, num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed): max_seq_length, short_seq_prob, seed, binary_head):
# Params to store. # Params to store.
self.name = name self.name = name
self.seed = seed self.seed = seed
self.masked_lm_prob = masked_lm_prob self.masked_lm_prob = masked_lm_prob
self.max_seq_length = max_seq_length self.max_seq_length = max_seq_length
self.binary_head = binary_head
# Dataset. # Dataset.
self.indexed_dataset = indexed_dataset self.indexed_dataset = indexed_dataset
...@@ -55,7 +56,8 @@ class BertDataset(Dataset): ...@@ -55,7 +56,8 @@ class BertDataset(Dataset):
self.max_seq_length, self.max_seq_length,
short_seq_prob, short_seq_prob,
self.seed, self.seed,
self.name) self.name,
self.binary_head)
# Vocab stuff. # Vocab stuff.
tokenizer = get_tokenizer() tokenizer = get_tokenizer()
...@@ -81,7 +83,8 @@ class BertDataset(Dataset): ...@@ -81,7 +83,8 @@ class BertDataset(Dataset):
self.vocab_id_to_token_dict, self.vocab_id_to_token_dict,
self.cls_id, self.sep_id, self.cls_id, self.sep_id,
self.mask_id, self.pad_id, self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng) self.masked_lm_prob, np_rng,
self.binary_head)
def get_samples_mapping_(indexed_dataset, def get_samples_mapping_(indexed_dataset,
...@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset, ...@@ -91,7 +94,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length, max_seq_length,
short_seq_prob, short_seq_prob,
seed, seed,
name): name,
binary_head):
if not num_epochs: if not num_epochs:
if not max_num_samples: if not max_num_samples:
raise ValueError("Need to specify either max_num_samples " raise ValueError("Need to specify either max_num_samples "
...@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset, ...@@ -137,7 +141,8 @@ def get_samples_mapping_(indexed_dataset,
max_seq_length - 3, # account for added tokens max_seq_length - 3, # account for added tokens
short_seq_prob, short_seq_prob,
seed, seed,
verbose) verbose,
2 if binary_head else 1)
print_rank_0(' > done building sapmles index maping') print_rank_0(' > done building sapmles index maping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True) np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format( print_rank_0(' > saved the index mapping in {}'.format(
...@@ -173,7 +178,7 @@ def build_training_sample(sample, ...@@ -173,7 +178,7 @@ def build_training_sample(sample,
target_seq_length, max_seq_length, target_seq_length, max_seq_length,
vocab_id_list, vocab_id_to_token_dict, vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id, cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng): masked_lm_prob, np_rng, binary_head):
"""Biuld training sample. """Biuld training sample.
Arguments: Arguments:
...@@ -193,12 +198,21 @@ def build_training_sample(sample, ...@@ -193,12 +198,21 @@ def build_training_sample(sample,
the opper bound whereas the numpy one is exclusive. the opper bound whereas the numpy one is exclusive.
""" """
# We assume that we have at least two sentences in the sample if binary_head:
assert len(sample) > 1 # We assume that we have at least two sentences in the sample
assert len(sample) > 1
assert target_seq_length <= max_seq_length assert target_seq_length <= max_seq_length
# Divide sample into two segments (A and B). # Divide sample into two segments (A and B).
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample, np_rng) if binary_head:
tokens_a, tokens_b, is_next_random = get_a_and_b_segments(sample,
np_rng)
else:
tokens_a = []
for j in range(len(sample)):
tokens_a.extend(sample[j])
tokens_b = []
is_next_random = False
# Truncate to `target_sequence_length`. # Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length max_num_tokens = target_seq_length
......
...@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng): ...@@ -114,7 +114,6 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
"""Truncates a pair of sequences to a maximum sequence length.""" """Truncates a pair of sequences to a maximum sequence length."""
#print(len_a, len_b, max_num_tokens) #print(len_a, len_b, max_num_tokens)
assert len_a > 0 assert len_a > 0
assert len_b > 0
if len_a + len_b <= max_num_tokens: if len_a + len_b <= max_num_tokens:
return False return False
while len_a + len_b > max_num_tokens: while len_a + len_b > max_num_tokens:
...@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): ...@@ -150,10 +149,11 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
for token in tokens_b: for token in tokens_b:
tokens.append(token) tokens.append(token)
tokentypes.append(1) tokentypes.append(1)
# [SEP]. if tokens_b:
tokens.append(sep_id) # [SEP].
tokentypes.append(1) tokens.append(sep_id)
tokentypes.append(1)
return tokens, tokentypes return tokens, tokentypes
...@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -392,6 +392,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'): dataset_type='standard_bert'):
if len(data_prefix) == 1: if len(data_prefix) == 1:
...@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -401,6 +402,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, short_seq_prob, seed,
skip_warmup, skip_warmup,
binary_head,
dataset_type=dataset_type) dataset_type=dataset_type)
# Blending dataset. # Blending dataset.
# Parse the values. # Parse the values.
...@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -417,7 +419,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
prefixes[i], data_impl, splits_string, prefixes[i], data_impl, splits_string,
datasets_train_valid_test_num_samples[i], datasets_train_valid_test_num_samples[i],
max_seq_length, masked_lm_prob, short_seq_prob, max_seq_length, masked_lm_prob, short_seq_prob,
seed, skip_warmup, dataset_type=dataset_type) seed, skip_warmup, binary_head, dataset_type=dataset_type)
if train_ds: if train_ds:
train_datasets.append(train_ds) train_datasets.append(train_ds)
if valid_ds: if valid_ds:
...@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -444,6 +446,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
binary_head,
dataset_type='standard_bert'): dataset_type='standard_bert'):
if dataset_type not in DSET_TYPES: if dataset_type not in DSET_TYPES:
...@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -503,7 +506,8 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
num_epochs=None, num_epochs=None,
max_num_samples=train_valid_test_num_samples[index], max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length, max_seq_length=max_seq_length,
seed=seed seed=seed,
binary_head=binary_head
) )
if dataset_type == DSET_TYPE_ICT: if dataset_type == DSET_TYPE_ICT:
......
...@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio, ...@@ -189,6 +189,9 @@ inline int32_t get_target_sample_len(const int32_t short_seq_ratio,
const int32_t max_length, const int32_t max_length,
std::mt19937& rand32_gen) { std::mt19937& rand32_gen) {
/* Training sample length. */ /* Training sample length. */
if (short_seq_ratio == 0) {
return max_length;
}
const auto random_number = rand32_gen(); const auto random_number = rand32_gen();
if ((random_number % short_seq_ratio) == 0) { if ((random_number % short_seq_ratio) == 0) {
return 2 + random_number % (max_length - 1); return 2 + random_number % (max_length - 1);
...@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -205,7 +208,8 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
const int32_t max_seq_length, const int32_t max_seq_length,
const double short_seq_prob, const double short_seq_prob,
const int32_t seed, const int32_t seed,
const bool verbose) { const bool verbose,
const int32_t min_num_sent) {
/* Build a mapping of (start-index, end-index, sequence-length) where /* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length. and sequence-length is the target sequence length.
...@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -214,7 +218,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// Consistency checks. // Consistency checks.
assert(num_epochs > 0); assert(num_epochs > 0);
assert(max_seq_length > 1); assert(max_seq_length > 1);
assert(short_seq_prob > 0.0); assert(short_seq_prob >= 0.0);
assert(short_seq_prob <= 1.0); assert(short_seq_prob <= 1.0);
assert(seed > 0); assert(seed > 0);
...@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -223,7 +227,10 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
auto sizes = sizes_.unchecked<1>(); auto sizes = sizes_.unchecked<1>();
// For efficiency, convert probability to ratio. Note: rand() generates int. // For efficiency, convert probability to ratio. Note: rand() generates int.
const auto short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob)); int32_t short_seq_ratio = 0;
if (short_seq_prob > 0) {
short_seq_ratio = static_cast<int32_t>(round(1.0 / short_seq_prob));
}
if (verbose) { if (verbose) {
const auto sent_start_index = docs[0]; const auto sent_start_index = docs[0];
...@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -322,7 +329,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
} }
// If we have more than two sentences. // If we have more than two sentences.
if ((num_remain_sent > 1) && (!contains_long_sentence)) { if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values. // Set values.
auto seq_len = int32_t{0}; auto seq_len = int32_t{0};
...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_, ...@@ -346,7 +353,7 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
// and if we have reached end of the document. // and if we have reached end of the document.
if (((seq_len >= target_seq_len) && if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) && (num_remain_sent > 1) &&
(num_sent > 1) ) || (num_remain_sent == 0)) { (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Check for overflow. // Check for overflow.
if ((3 * map_index + 2) > if ((3 * map_index + 2) >
...@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_, ...@@ -437,7 +444,8 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
const int max_seq_length, const int max_seq_length,
const double short_seq_prob, const double short_seq_prob,
const int seed, const int seed,
const bool verbose) { const bool verbose,
const int32_t min_num_sent) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) { if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) { if (verbose) {
...@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_, ...@@ -445,14 +453,16 @@ py::array build_mapping(const py::array_t<int64_t>& docs_,
} }
return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs, return build_mapping_impl<uint64_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length, max_num_samples, max_seq_length,
short_seq_prob, seed, verbose); short_seq_prob, seed, verbose,
min_num_sent);
} else { } else {
if (verbose) { if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush; cout << " using uint32 for data mapping..." << endl << std::flush;
} }
return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs, return build_mapping_impl<uint32_t>(docs_, sizes_, num_epochs,
max_num_samples, max_seq_length, max_num_samples, max_seq_length,
short_seq_prob, seed, verbose); short_seq_prob, seed, verbose,
min_num_sent);
} }
} }
......
...@@ -29,7 +29,6 @@ from megatron.model.utils import init_method_normal ...@@ -29,7 +29,6 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
def bert_extended_attention_mask(attention_mask): def bert_extended_attention_mask(attention_mask):
# We create a 3D attention mask from a 2D tensor mask. # We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s] # [b, 1, s]
...@@ -75,9 +74,7 @@ class BertLMHead(MegatronModule): ...@@ -75,9 +74,7 @@ class BertLMHead(MegatronModule):
args = get_args() args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.tensor_model_parallel = True mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.bias.partition_dim = 0
self.bias.stride = 1
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
......
...@@ -60,6 +60,13 @@ class MegatronModule(torch.nn.Module): ...@@ -60,6 +60,13 @@ class MegatronModule(torch.nn.Module):
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but ' raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false') 'share_word_embeddings is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the # Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than # heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different # one stage, the initial embedding layer and the head are on different
...@@ -73,16 +80,16 @@ class MegatronModule(torch.nn.Module): ...@@ -73,16 +80,16 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight # the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages. # update is the same on both stages.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage(): assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings # set word_embeddings weights to 0 here, then copy first
# weights to 0 here, then copy first stage's weights using # stage's weights using all_reduce below.
# all_reduce below. self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = mpu.VocabParallelEmbedding( args.padded_vocab_size, args.hidden_size,
args.padded_vocab_size, args.hidden_size, init_method=init_method_normal(args.init_method_std))
init_method=init_method_normal(args.init_method_std)) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
# values. # values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
......
...@@ -44,7 +44,8 @@ from .initialize import model_parallel_is_initialized ...@@ -44,7 +44,8 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear from .layers import ColumnParallelLinear
from .layers import RowParallelLinear from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding from .layers import VocabParallelEmbedding
from .layers import (set_defaults_if_not_set_tensor_model_parallel_attributes, from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes) copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
......
...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -109,7 +109,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim) dim=partition_dim)
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
...@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -260,9 +260,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=args.params_dtype))
self.bias.tensor_model_parallel = True set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
......
...@@ -23,7 +23,10 @@ from megatron import print_rank_0 ...@@ -23,7 +23,10 @@ from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage from megatron.model import (BertModel,
BertModelFirstStage,
BertModelIntermediateStage,
BertModelLastStage)
from megatron.training import pretrain from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group from megatron.utils import average_losses_across_data_parallel_group
...@@ -34,23 +37,24 @@ def model_provider(): ...@@ -34,23 +37,24 @@ def model_provider():
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
args = get_args() args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline. # Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
model = BertModelFirstStage( model = BertModelFirstStage(
num_tokentypes=2) num_tokentypes=num_tokentypes)
elif mpu.is_pipeline_last_stage(): elif mpu.is_pipeline_last_stage():
model = BertModelLastStage( model = BertModelLastStage(
num_tokentypes=2, num_tokentypes=num_tokentypes,
add_binary_head=True, add_binary_head=args.bert_binary_head,
parallel_output=True) parallel_output=True)
else: else:
model = BertModelIntermediateStage( model = BertModelIntermediateStage(
num_tokentypes=2) num_tokentypes=num_tokentypes)
else: else:
model = BertModel( model = BertModel(
num_tokentypes=2, num_tokentypes=num_tokentypes,
add_binary_head=True, add_binary_head=args.bert_binary_head,
parallel_output=True) parallel_output=True)
return model return model
...@@ -92,6 +96,9 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -92,6 +96,9 @@ def forward_step(data_iterator, model, input_tensor):
= get_batch(data_iterator) = get_batch(data_iterator)
timers('batch-generator').stop() timers('batch-generator').stop()
if not args.bert_binary_head:
types = None
# Forward pass through the model. # Forward pass through the model.
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
assert input_tensor is None assert input_tensor is None
...@@ -109,22 +116,29 @@ def forward_step(data_iterator, model, input_tensor): ...@@ -109,22 +116,29 @@ def forward_step(data_iterator, model, input_tensor):
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
lm_loss_, sop_logits = output_tensor lm_loss_, sop_logits = output_tensor
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
lm_loss_ = lm_loss_.float() lm_loss_ = lm_loss_.float()
loss_mask = loss_mask.float() loss_mask = loss_mask.float()
lm_loss = torch.sum( lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
if sop_logits is not None:
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
sop_loss = sop_loss.float()
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0],
'sop loss': averaged_losses[1]}
else:
loss = lm_loss
averaged_losses = average_losses_across_data_parallel_group(
[lm_loss])
return loss, {'lm loss': averaged_losses[0]}
loss = lm_loss + sop_loss
averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss])
return loss, {'lm loss': averaged_losses[0], 'sop loss': averaged_losses[1]}
return output_tensor return output_tensor
...@@ -143,7 +157,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -143,7 +157,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob=args.mask_prob, masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob, short_seq_prob=args.short_seq_prob,
seed=args.seed, seed=args.seed,
skip_warmup=(not args.mmap_warmup)) skip_warmup=(not args.mmap_warmup),
binary_head=args.bert_binary_head)
print_rank_0("> finished creating BERT datasets ...") print_rank_0("> finished creating BERT datasets ...")
return train_ds, valid_ds, test_ds return train_ds, valid_ds, test_ds
......
...@@ -61,7 +61,7 @@ if __name__ == '__main__': ...@@ -61,7 +61,7 @@ if __name__ == '__main__':
elif args.task in ['MNLI', 'QQP']: elif args.task in ['MNLI', 'QQP']:
from glue.finetune import main from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']: elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt2.evaluate import main from zeroshot_gpt.evaluate import main
else: else:
raise NotImplementedError('Task {} is not implemented.'.format( raise NotImplementedError('Task {} is not implemented.'.format(
args.task)) args.task))
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""GPT2 zero-shot evaluation.""" """GPT zero-shot evaluation."""
import math import math
...@@ -24,7 +24,7 @@ from megatron import print_rank_0, is_last_rank ...@@ -24,7 +24,7 @@ from megatron import print_rank_0, is_last_rank
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage from megatron.model import GPTModel, GPTModelFirstStage, GPTModelLastStage, GPTModelIntermediateStage
from megatron.training import get_model, communicate from megatron.training import get_model, communicate
from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import get_ltor_masks_and_position_ids
from tasks.finetune_utils import build_data_loader from tasks.finetune_utils import build_data_loader
...@@ -47,18 +47,18 @@ def get_model_provider(eval_metric): ...@@ -47,18 +47,18 @@ def get_model_provider(eval_metric):
raise NotImplementedError('output type for {} evaluation metric ' raise NotImplementedError('output type for {} evaluation metric '
'is not supported.'.format(eval_metric)) 'is not supported.'.format(eval_metric))
print_rank_0('building GPT2 model ...') print_rank_0('building GPT model ...')
if mpu.get_pipeline_model_parallel_world_size() > 1: if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline. # Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0) model = GPTModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage(): elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage( model = GPTModelLastStage(
parallel_output=parallel_output, num_tokentypes=0) parallel_output=parallel_output, num_tokentypes=0)
else: else:
model = GPT2ModelIntermediateStage(num_tokentypes=0) model = GPTModelIntermediateStage(num_tokentypes=0)
else: else:
model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output) model = GPTModel(num_tokentypes=0, parallel_output=parallel_output)
return model return model
......
...@@ -23,11 +23,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ...@@ -23,11 +23,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
import torch import torch
from megatron import mpu from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args
def split_into_partitions(tensor, num_partitions, partition_dim, stride): def split_into_partitions(tensor, num_partitions, partition_dim, stride):
...@@ -185,8 +187,23 @@ def get_mp_merge_args(parser): ...@@ -185,8 +187,23 @@ def get_mp_merge_args(parser):
def main(): def main():
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
os.environ["WORLD_SIZE"] = f'{2**31}'
# Args # Args
args = _parse_args(extra_args_provider=get_mp_merge_args) set_global_variables(extra_args_provider=get_mp_merge_args,
args_defaults = {'use_cpu_initialization': True,
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'save_interval': 1})
args = get_args()
if args.pipeline_model_parallel_size > 1:
print("Checkpoints with pipeline model parallelism are not currently supported.")
exit()
model_type = args.model_type model_type = args.model_type
orig_tensor_model_parallel_size = args.tensor_model_parallel_size orig_tensor_model_parallel_size = args.tensor_model_parallel_size
args.tensor_model_parallel_size = 1 args.tensor_model_parallel_size = 1
...@@ -209,6 +226,8 @@ def main(): ...@@ -209,6 +226,8 @@ def main():
print('> building the full model ...') print('> building the full model ...')
mpu.initialize.set_tensor_model_parallel_world_size(1) mpu.initialize.set_tensor_model_parallel_world_size(1)
mpu.initialize.set_tensor_model_parallel_rank(0) mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(1)
mpu.initialize.set_pipeline_model_parallel_rank(0)
merged_model = get_model(model_type) merged_model = get_model(model_type)
# Build and load partitions. # Build and load partitions.
...@@ -220,13 +239,16 @@ def main(): ...@@ -220,13 +239,16 @@ def main():
for rank in range(args.tensor_model_parallel_size): for rank in range(args.tensor_model_parallel_size):
mpu.initialize.set_tensor_model_parallel_rank(rank) mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print('> loading {} ...'.format(checkpoint_name))
model_ = get_model(model_type) model_ = get_model(model_type)
sd = torch.load(checkpoint_name, map_location='cpu') print(f'> loading {checkpoint_name} ...')
model_.load_state_dict(sd['model']) load_checkpoint(model_, None, None)
print(f'> checkpoint version {get_checkpoint_version()}')
if get_checkpoint_version() < 2.0:
# Need to deal with the qkv matrix order of old versions
print("Checkpoints less than version 2.0 are not currently supported.")
exit()
partitions.append(model_) partitions.append(model_)
# Parameter generators so we can loop through them semiltaneouly. # Parameter generators so we can loop through them semiltaneouly.
merged_params_gen = merged_model.named_parameters() merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters() partitions_params_gen = [partition.named_parameters()
...@@ -254,29 +276,26 @@ def main(): ...@@ -254,29 +276,26 @@ def main():
merged_param.data.copy_(partitions_param[0].data) merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values # For parallel parameters, merge the values
else: else:
print(' parallel parameter merge with stride {} along ' dim = merged_param.partition_dim
'dimention {}'.format(merged_param.stride, stride = merged_param.partition_stride
merged_param.partition_dim)) print(f' parallel parameter merge with stride {stride} along '
f'dimention {dim}')
merge_partitions(merged_param, merge_partitions(merged_param,
partitions_param, partitions_param,
merged_param.partition_dim, dim,
merged_param.stride) stride)
except StopIteration: except StopIteration:
break break
# Save the model. # Save the model.
args.tensor_model_parallel_size = 1 args.tensor_model_parallel_size = 1
args.pipeline_model_parallel_size = 1
# And now one last time so proper arguments are set in saved checkpoint
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_rank(0) mpu.initialize.set_tensor_model_parallel_rank(0)
sd = {} print('> saving merged model')
sd['model'] = merged_model.state_dict_for_save_checkpoint() save_checkpoint(iteration, merged_model, None, None)
sd['iteration'] = iteration
merged_path = os.path.join(args.load, 'merged')
checkpoint_name = get_checkpoint_name(merged_path, iteration)
ensure_directory_exists(checkpoint_name)
print('> saving merged model to {}'.format(checkpoint_name))
torch.save(sd, checkpoint_name)
print('done :-)') print('done :-)')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment