Commit 5da3bb92 authored by Jared Casper

Merge branch 'core-merge-main' into 'core'

Core merge main

See merge request ADLR/megatron-lm!464
parents 8806ba73 2a86fa20
@@ -217,6 +217,15 @@ def validate_args(args, defaults={}):
            'can only specify one of lr-warmup-fraction ' \
            'and lr-warmup-samples'
    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'cannot have both num-layers and encoder-num-layers specified'
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, \
            'either num-layers or encoder-num-layers should be specified'
        args.num_layers = args.encoder_num_layers
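For illustration, a minimal sketch (not part of the patch, hypothetical values) of the invariant this block establishes: after validation the two attributes always agree, and supplying both flags trips the first assertion.

    # Sketch assuming a bare argparse Namespace:
    from argparse import Namespace
    args = Namespace(num_layers=24, encoder_num_layers=None)
    if args.num_layers is not None:
        args.encoder_num_layers = args.num_layers
    else:
        args.num_layers = args.encoder_num_layers
    assert args.num_layers == args.encoder_num_layers == 24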
    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
@@ -344,6 +353,10 @@ def _add_network_size_args(parser):
    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--encoder-num-layers', type=int, default=None,
                       help='Number of encoder transformer layers.')
    group.add_argument('--decoder-num-layers', type=int, default=None,
                       help='Number of decoder transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
@@ -817,12 +830,31 @@ def _add_data_args(parser):
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ... It is used with --split when a '
                       'single dataset is used for all three: train, valid '
                       'and test. It is mutually exclusive with the other '
                       '--*-data-path args.')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--train-data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--valid-data-path', nargs='*', default=None,
                       help='Path to the validation dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--test-data-path', nargs='*', default=None,
                       help='Path to the test dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
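For illustration, two hypothetical invocations (paths made up): a single blended corpus sliced by --split, versus separate, independently blended corpora per split.

    # Option 1: one blended corpus, sliced 969/30/1 into train/valid/test.
    #   --data-path 0.7 /data/webtext 0.3 /data/books --split 969,30,1
    # Option 2: separate corpora per split; --split is ignored.
    #   --train-data-path 0.7 /data/webtext 0.3 /data/books \
    #   --valid-data-path 1.0 /data/valid_set \
    #   --test-data-path 1.0 /data/test_set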
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
@@ -855,8 +887,11 @@ def _add_data_args(parser):
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer',
                                'SentencePieceTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='Sentencepiece tokenizer model.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
......
@@ -63,12 +63,18 @@ def get_datasets_weights_and_num_samples(data_prefix,
    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
    # not uniformly distribute the number of samples, we still have
    # samples left to feed to the network.
    if isinstance(train_valid_test_num_samples, list):
        datasets_train_valid_test_num_samples = []
        for weight in weights:
            datasets_train_valid_test_num_samples.append(
                [int(math.ceil(val * weight * 1.005))
                 for val in train_valid_test_num_samples])
    else:
        # Used when separate dataset files are provided for train,
        # valid and test.
        datasets_train_valid_test_num_samples = [
            int(math.ceil(train_valid_test_num_samples * weight * 1.005))
            for weight in weights]
    return prefixes, weights, datasets_train_valid_test_num_samples
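A worked example of both branches (made-up weights and sample counts) showing the 0.5% oversampling:

    import math
    weights = [0.7, 0.3]

    # List input: one blended corpus, per-split sample counts.
    tvt = [1000, 100, 10]
    per_dataset = [[int(math.ceil(v * w * 1.005)) for v in tvt] for w in weights]
    # per_dataset == [[704, 71, 8], [302, 31, 4]]

    # Scalar input: separate corpora per split, a single sample count.
    scalar = [int(math.ceil(1000 * w * 1.005)) for w in weights]
    # scalar == [704, 302]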
......
@@ -16,11 +16,15 @@ from megatron.data.dataset_utils import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
def build_train_valid_test_datasets(data_prefix, data_impl,
                                    splits_string, train_valid_test_num_samples,
                                    seq_length, seed, skip_warmup,
                                    train_data_prefix=None,
                                    valid_data_prefix=None,
                                    test_data_prefix=None):
    """Build train, valid, and test datasets."""

    if data_prefix:
        print_rank_0("Single data path provided for train, valid & test")

        # Single dataset.
        if len(data_prefix) == 1:
            return _build_train_valid_test_datasets(data_prefix[0],
@@ -63,6 +67,83 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
        return (blending_train_dataset, blending_valid_dataset,
                blending_test_dataset)
    else:
        print_rank_0("Separate data paths provided for train, valid & test. "
                     "Split string will be ignored.")

        train_dataset, valid_dataset, test_dataset = None, None, None

        # Single dataset.
        if train_data_prefix is not None:
            train_dataset = build_dataset("train", train_data_prefix, data_impl,
                                          train_valid_test_num_samples[0],
                                          seq_length, seed, skip_warmup)

        if valid_data_prefix is not None:
            valid_dataset = build_dataset("valid", valid_data_prefix, data_impl,
                                          train_valid_test_num_samples[1],
                                          seq_length, seed, False)

        if test_data_prefix is not None:
            test_dataset = build_dataset("test", test_data_prefix, data_impl,
                                         train_valid_test_num_samples[2],
                                         seq_length, seed, False)

        return (train_dataset, valid_dataset, test_dataset)
def build_dataset(dataset_name, data_prefix, data_impl,
                  num_samples, seq_length, seed, skip_warmup):
    dataset = None
    if len(data_prefix) == 1:
        dataset = _build_dataset(dataset_name,
                                 data_prefix[0], data_impl,
                                 num_samples, seq_length,
                                 seed, skip_warmup)
    else:
        # Blending dataset.
        # Parse the values.
        output = get_datasets_weights_and_num_samples(data_prefix, num_samples)
        prefixes, weights, dataset_num_samples = output

        # Build individual datasets.
        datasets = []
        for i in range(len(prefixes)):
            ds = _build_dataset(dataset_name, prefixes[i],
                                data_impl, dataset_num_samples[i],
                                seq_length, seed, skip_warmup)
            if ds:
                datasets.append(ds)

        if datasets:
            dataset = BlendableDataset(datasets, weights)

    return dataset
def _build_dataset(dataset_name, data_prefix, data_impl,
                   num_samples, seq_length, seed, skip_warmup):
    """
    Build dataset. This method is called when separate
    train, valid, and test datasets are provided.
    """

    # Indexed dataset.
    indexed_dataset = get_indexed_dataset_(data_prefix,
                                           data_impl,
                                           skip_warmup)

    total_num_of_documents = indexed_dataset.sizes.shape[0]

    print_rank_0('    {}:'.format(dataset_name))
    print_rank_0('     document indices in [0, {}) total of {} '
                 'documents'.format(total_num_of_documents, total_num_of_documents))

    documents = np.arange(start=0, stop=total_num_of_documents,
                          step=1, dtype=np.int32)

    dataset = GPTDataset(dataset_name, data_prefix,
                         documents, indexed_dataset,
                         num_samples, seq_length, seed)

    return dataset
def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
......
@@ -736,9 +736,9 @@ class NoopTransformerLayer(MegatronModule):
        return hidden_states.clone()


def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank."""
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None
@@ -752,20 +752,21 @@ def _get_num_layers(args, is_encoder_and_decoder_model):
                args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
                'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
            assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
                'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
            if is_pipeline_stage_before_split():
                num_layers = (
                    0
                    if args.standalone_embedding_stage
                    and get_pipeline_model_parallel_rank() == 0 else
                    args.encoder_num_layers // num_ranks_in_encoder
                )
            else:
                num_layers = args.decoder_num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers == args.encoder_num_layers
            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by transformer_pipeline_model_parallel_size'
@@ -776,11 +777,14 @@ def _get_num_layers(args, is_encoder_and_decoder_model):
            num_layers = (
                0
                if args.standalone_embedding_stage
                and get_pipeline_model_parallel_rank() == 0 else
                args.num_layers // args.transformer_pipeline_model_parallel_size
            )
    else:
        if not is_decoder:
            num_layers = args.encoder_num_layers
        else:
            num_layers = args.decoder_num_layers
    return num_layers
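A worked example of the new per-rank arithmetic (hypothetical T5-style configuration, ignoring the standalone embedding stage):

    encoder_num_layers, decoder_num_layers = 12, 12
    pipeline_world_size, split_rank = 4, 2

    num_ranks_in_encoder = split_rank                        # 2 ranks
    num_ranks_in_decoder = pipeline_world_size - split_rank  # 2 ranks
    assert encoder_num_layers % num_ranks_in_encoder == 0
    assert decoder_num_layers % num_ranks_in_decoder == 0
    # Encoder ranks each hold 12 // 2 = 6 layers; decoder ranks likewise 6.
    # Without pipelining, is_decoder selects the full 12-layer stack directly.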
@@ -817,7 +821,9 @@ class ParallelTransformer(MegatronModule):

        # Number of layers.
        self.num_layers = _get_num_layers(
            args,
            args.model_type == ModelType.encoder_and_decoder,
            layer_type == LayerType.decoder)

        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
......
@@ -132,6 +132,7 @@ def get_megatron_optimizer(model,
                               args.use_contiguous_buffers_in_local_ddp,
                               args.fp16,
                               args.bf16,
                               args.params_dtype,
                               grad_scaler,
                               model)
......
@@ -337,7 +337,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler, models):
        """
        See top of class definition for argument descriptions.
@@ -351,7 +351,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, params_dtype, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
@@ -380,6 +380,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            self.model_param_gbuf_map,
            self.opt_group_ranges)
        # Initialize param buffers.
        # - These are views on the DDP model's grad buffers, that share
        #   storage & have their own dtype. This is safe because the param
        #   dtype size is always <= grad dtype size.
        self.param_buffers = []
        for model_index, model in enumerate(self.models):
            current_param_buffers = {}
            for dtype, grad_buffer in model._grad_buffers.items():
                param_buffer = torch.tensor(grad_buffer.data.storage()._untyped(),
                                            dtype=params_dtype,
                                            device=grad_buffer.data.device)
                param_buffer = param_buffer[:grad_buffer.numel_padded]
                current_param_buffers[dtype] = param_buffer
            self.param_buffers.append(current_param_buffers)
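The aliasing can be seen in isolation; a minimal sketch, assuming the same private `_untyped()` storage API the patch uses (renamed in later PyTorch releases):

    import torch
    grad_buffer = torch.zeros(4, dtype=torch.float32)         # 16 bytes of storage
    param_buffer = torch.tensor(grad_buffer.storage()._untyped(),
                                dtype=torch.float16,
                                device=grad_buffer.device)    # 8 fp16 slots, same bytes
    param_buffer = param_buffer[:grad_buffer.numel()]         # keep only numel slots
    assert param_buffer.data_ptr() == grad_buffer.data_ptr()  # shared storage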
        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
@@ -474,36 +489,48 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            _zero_grad_group_helper(group, set_to_none)

    @staticmethod
    def get_model_buffer_dp_views(model_buffers):
        """
        Get shard views of each of the DDP's param/grad buffers.

        In this nested list, the top level is grouped by the virtual model
        index and the buffer's data type. The sub-level is a list of
        shards of that buffer, where each shard in the list represents
        a contiguous view of the buffer, that is owned by a data-parallel
        rank. The shard boundary does not respect parameter boundaries, and
        so the elements of some parameters are split across data parallel
        ranks.

        Additionally, return references to the entire buffers, for use
        in _reduce_scatter_base and _all_gather_base.
        """

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Buffer views.
        view_items = []
        for model_index, buffers in enumerate(model_buffers):
            for dtype, buf in buffers.items():
                assert buf.numel() % data_parallel_world_size == 0
                shard_size = int(buf.numel() / data_parallel_world_size)
                buf_views = [buf[(r*shard_size):((r+1)*shard_size)]
                             for r in range(data_parallel_world_size)]
                view_items.append((model_index, dtype, buf, buf_views))

        return view_items

    def get_model_grad_buffer_dp_views(self):
        return self.get_model_buffer_dp_views([
            {dtype: mem_buffer.data}
            for model in self.models
            for dtype, mem_buffer in model._grad_buffers.items()])

    def get_model_param_buffer_dp_views(self):
        return self.get_model_buffer_dp_views(self.param_buffers)
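The shard math is easiest to see standalone (made-up sizes):

    import torch
    data_parallel_world_size = 4
    buf = torch.arange(16.)                  # padded so numel % world size == 0
    shard_size = buf.numel() // data_parallel_world_size
    views = [buf[r * shard_size:(r + 1) * shard_size]
             for r in range(data_parallel_world_size)]
    # views[1] aliases elements [4., 5., 6., 7.]; an _all_gather_base into buf
    # therefore lands each rank's contribution directly in its shard.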
    def reduce_model_grads(self, args, timers):
@@ -560,9 +587,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        """
        All-gather updated model params.

        The DDP's param buffer is used for the all-gather, and thus no
        tensors are dynamically allocated. After the all-gather, the params
        can be copied from the param buffer to the param.
        """
        timers('params-all-gather', log_level=1).start(
@@ -572,26 +599,28 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All param buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, due to grad buffer padding that is
        #   done in distributed.py, and extended to the param buffers. Thus,
        #   all sub-views will have consistent start/end indexes across data
        #   parallel ranks.
        pbuf_view_items = self.get_model_param_buffer_dp_views()
        for index, (model_index, dtype, pbuf, pbuf_views) \
            in enumerate(pbuf_view_items):
            torch.distributed._all_gather_base(
                pbuf,
                pbuf_views[data_parallel_rank],
                group=data_parallel_group,
            )

        # Copy from param buffer to each param.
        for model_id, model in enumerate(self.models):
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param, buf_range in param_map.items():
                    param_buf = self.param_buffers[model_id][dtype]
                    param_buf_shard = param_buf[buf_range[0]:buf_range[1]]
                    param.view(-1).detach().copy_(param_buf_shard)

        timers('params-all-gather').stop()
@@ -671,14 +700,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                                  model_group):

                param_range_map = self.get_model_param_range_map(model_param)
                world_range = param_range_map["gbuf_world"]

                assert world_range.size == shard_main_param.nelement()

                model_id, dtype = self.model_param_gbuf_map[model_param]
                model_param_buffer = self.param_buffers[model_id][dtype]

                shard_model_param = model_param_buffer.view(-1) \
                    [world_range.start:world_range.end]

                shard_model_param.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups,
......
@@ -321,6 +321,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        is using a contiguous buffer to hold the model grads.
    fp16: if true, the model is running in fp16.
    bf16: if true, the model is running in bfloat16.
    params_dtype: used by distributed optimizer.
    grad_scaler: used for scaling gradients. Note that this can be
        None. This case happens when `bf16 = True` and we don't
        use any loss scale. Note that for `bf16 = True`, we can have
@@ -332,7 +333,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler,
                 models):

        super().__init__(
@@ -342,6 +343,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

        self.fp16 = fp16
        self.bf16 = bf16
        self.params_dtype = params_dtype
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
@@ -491,12 +493,12 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, params_dtype, grad_scaler, models)

        # ======================
        # main parameter stuff
......
@@ -237,7 +237,7 @@ def forward_backward_no_pipelining(forward_step_func,
                                         timers, collect_non_loss_data)
            if not forward_only:
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad, timers)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
......
@@ -28,6 +28,9 @@ def build_tokenizer(args):
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    elif args.tokenizer_type == 'SentencePieceTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _SentencePieceTokenizer(args.tokenizer_model,
                                            vocab_extra_ids=args.vocab_extra_ids)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(args.tokenizer_type))
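For illustration, a hypothetical invocation selecting the new tokenizer (model path made up; --vocab-extra-ids is the pre-existing flag that feeds vocab_extra_ids):

    #   --tokenizer-type SentencePieceTokenizer \
    #   --tokenizer-model /path/to/spm.model \
    #   --vocab-extra-ids 100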
@@ -276,3 +279,169 @@ class _GPT2BPETokenizer(AbstractTokenizer):
    @property
    def eod(self):
        return self.eod_id
class _SentencePieceTokenizer(AbstractTokenizer):
    """SentencePieceTokenizer-Megatron wrapper"""

    def __init__(self, model_file, vocab_extra_ids=0):
        name = 'SentencePieceTokenizer'
        super().__init__(name)

        import sentencepiece
        self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
        self._initialize(vocab_extra_ids)

    def _initialize(self, vocab_extra_ids):
        self._vocab = {}
        self._inv_vocab = {}

        self._special_tokens = {}
        self._inv_special_tokens = {}

        self._t5_tokens = []

        for i in range(len(self._tokenizer)):
            t = self._tokenizer.id_to_piece(i)
            self._inv_vocab[i] = t
            self._vocab[t] = i

        def _add_special_token(t):
            if t not in self._vocab:
                next_id = len(self._vocab)
                self._vocab[t] = next_id
                self._inv_vocab[next_id] = t
            self._special_tokens[t] = self._vocab[t]
            self._inv_special_tokens[self._vocab[t]] = t

        _add_special_token('<CLS>')
        self._cls_id = self._vocab['<CLS>']
        _add_special_token('<SEP>')
        self._sep_id = self._vocab['<SEP>']
        _add_special_token('<EOD>')
        self._eod_id = self._vocab['<EOD>']
        _add_special_token('<MASK>')
        self._mask_id = self._vocab['<MASK>']

        pad_id = self._tokenizer.pad_id()
        try:
            pad_token = self._tokenizer.id_to_piece(pad_id)
        except IndexError:
            pad_token = '<PAD>'
        _add_special_token(pad_token)
        self._pad_id = self._vocab[pad_token]

        bos_id = self._tokenizer.bos_id()
        try:
            bos_token = self._tokenizer.id_to_piece(bos_id)
        except IndexError:
            bos_token = '<BOS>'
        _add_special_token(bos_token)
        self._bos_id = self._vocab[bos_token]

        eos_id = self._tokenizer.eos_id()
        try:
            eos_token = self._tokenizer.id_to_piece(eos_id)
        except IndexError:
            eos_token = '<EOS>'
        _add_special_token(eos_token)
        self._eos_id = self._vocab[eos_token]

        for i in range(vocab_extra_ids):
            t = "<extra_id_{}>".format(i)
            _add_special_token(t)
            self._t5_tokens += [t]
    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab
    # From:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
    def tokenize(self, text):
        ids = []
        idx = 0

        while True:
            indices = {}
            for token in self._special_tokens:
                try:
                    indices[token] = text[idx:].index(token)
                except ValueError:
                    continue
            if len(indices) == 0:
                break

            next_token = min(indices, key=indices.get)
            next_idx = idx + indices[next_token]

            ids.extend(self._tokenizer.encode_as_ids(text[idx:next_idx]))
            ids.append(self._special_tokens[next_token])
            idx = next_idx + len(next_token)

        ids.extend(self._tokenizer.encode_as_ids(text[idx:]))
        return ids
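The loop can be traced with a toy example; a sketch with a stubbed encoder standing in for the SentencePiece model (reserved id made up):

    special_tokens = {'<SEP>': 1001}               # hypothetical reserved id
    encode_as_ids = lambda s: [ord(c) for c in s]  # toy stand-in for SentencePiece

    text, idx, ids = 'ab<SEP>cd', 0, []
    while True:
        indices = {t: text[idx:].index(t)
                   for t in special_tokens if t in text[idx:]}
        if not indices:
            break
        next_token = min(indices, key=indices.get)    # earliest special token
        next_idx = idx + indices[next_token]
        ids.extend(encode_as_ids(text[idx:next_idx])) # plain text before it
        ids.append(special_tokens[next_token])        # its reserved id
        idx = next_idx + len(next_token)
    ids.extend(encode_as_ids(text[idx:]))
    assert ids == [97, 98, 1001, 99, 100]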
    # From:
    # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
    def detokenize(self, ids):
        text = ""
        last_i = 0

        for i, id in enumerate(ids):
            if id in self._inv_special_tokens:
                text += self._tokenizer.decode_ids(ids[last_i:i]) + " "
                text += self._inv_special_tokens[id] + " "
                last_i = i + 1

        text += self._tokenizer.decode_ids(ids[last_i:])
        return text.strip()
    @property
    def cls(self):
        return self._cls_id

    @property
    def sep(self):
        return self._sep_id

    @property
    def pad(self):
        return self._pad_id

    @property
    def bos_token_id(self):
        return self._bos_id

    @property
    def bos(self):
        return self._bos_id

    @property
    def eod(self):
        return self._eod_id

    @property
    def eos_token_id(self):
        return self._eos_id

    @property
    def eos(self):
        return self._eos_id

    @property
    def mask(self):
        return self._mask_id

    @property
    def additional_special_tokens_ids(self):
        return [self.vocab[k] for k in self._t5_tokens]
@@ -100,7 +100,10 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path)
    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
......