diff --git a/megatron/arguments.py b/megatron/arguments.py
index b35af1d..2a55699 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@
 import os
 import torch
 
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
     """Parse all arguments."""
@@ -42,6 +45,9 @@
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
 
+    # FastMoE arguments.
+    parser = _add_fmoe_args(parser)
+
     # Custom arguments.
     if extra_args_provider is not None:
         parser = extra_args_provider(parser)
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 1251066..32afb2f 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@ dtypes = {
     3: np.int16,
     4: np.int32,
     5: np.int64,
-    6: np.float,
+    6: np.float32,
     7: np.double,
     8: np.uint16
 }
@@ -268,7 +268,7 @@ class IndexedDatasetBuilder(object):
         np.int16: 2,
         np.int32: 4,
         np.int64: 8,
-        np.float: 4,
+        np.float32: 4,
         np.double: 8
     }
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index 823a51f..32f4b2e 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -69,8 +69,10 @@ def get_megatron_optimizer(model):
 
     # Determine whether the params have main-grad field.
     params_have_main_grad = False
-    if args.DDP_impl == 'local':
-        params_have_main_grad = True
+
+    # FastMoE does not have main_grad field
+    # if args.DDP_impl == 'local':
+    #     params_have_main_grad = True
 
     if args.fp16 or args.bf16:
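
# ---------------------------------------------------------------------------
# Note (not part of the patch): a minimal sketch of what the main_grad changes
# here and in megatron/training.py below amount to. `read_model_grad` is a
# hypothetical helper, not a Megatron or FastMoE API: Megatron's local DDP
# accumulates fp32 gradients in param.main_grad, whereas FastMoE's
# DistributedDataParallel leaves gradients in the ordinary param.grad, so the
# patched code always falls back to .grad.
# ---------------------------------------------------------------------------
def read_model_grad(param, params_have_main_grad=False):
    # With FastMoE, params_have_main_grad stays False and .grad is used.
    return param.main_grad if params_have_main_grad else param.grad
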
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 036a1d4..81d5bd9 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,17 +54,23 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     #   - should not be a replica due to tensor model parallelism
     grads = []
     grads_for_norm = []
+    # FastMoE
+    grads_in_moe = []
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = param_is_not_shared(param)
         is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-        grad = param.grad.detach()
         if grad_not_none:
+            grad = param.grad.detach()
             # Make sure the grads are in fp32
             assert param.grad.type() == 'torch.cuda.FloatTensor'
             grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            # FastMoE: expert parameters are tagged with dp_comm == 'none'
+            if hasattr(param, 'dp_comm') and param.dp_comm in ('none',):
+                grads_in_moe.append(grad)
+            else:
+                grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -73,6 +79,8 @@
     # Calculate norm.
     if norm_type == inf:
+        # FastMoE TODO
+        assert False, f"norm_type {norm_type} is not supported by FastMoE"
         total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
@@ -97,7 +105,20 @@
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
+            # FastMoE
+            if len(grads_in_moe) > 0:  # 'cold' experts may not have any grads in one iteration
+                grad_norm, _ = multi_tensor_applier(
+                    amp_C.multi_tensor_l2norm,
+                    dummy_overflow_buf,
+                    [grads_in_moe],
+                    False  # no per-parameter norm
+                )
+                grad_norm = grad_norm ** norm_type
+                torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+                total_norm += grad_norm
         else:
+            # FastMoE TODO
+            assert False, f"norm_type {norm_type} is not supported by FastMoE"
             for grad in grads_for_norm:
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
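
# ---------------------------------------------------------------------------
# Note (not part of the patch): the extra branch above leans on the usual
# p-norm decomposition -- dense and expert gradients can be normed per bucket,
# raised to the p-th power, summed (with an all-reduce where a bucket is
# sharded across ranks), and rooted once at the end. A self-contained,
# single-process illustration of that identity:
# ---------------------------------------------------------------------------
import torch

p = 2.0
buckets = [torch.randn(4), torch.randn(3), torch.randn(5)]  # e.g. dense vs. expert grads
full_norm = torch.cat(buckets).norm(p)
bucketed_norm = sum(t.norm(p) ** p for t in buckets) ** (1.0 / p)
assert torch.allclose(full_norm, bucketed_norm)
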
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 368f587..080b06f 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -250,6 +250,9 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                                                               param)
                     if hasattr(param, 'shared'):
                         main_param.shared = param.shared
+                    # FastMoE: carry the dp_comm tag over to the fp32 main param.
+                    if hasattr(param, 'dp_comm'):
+                        main_param.dp_comm = param.dp_comm
                     # Replace the optimizer params with the new fp32 copy.
                     param_group['params'][i] = main_param
                     fp32_from_float16_params_this_group.append(main_param)
@@ -396,17 +399,26 @@
         # so we can update the loss scale.
         self.grad_scaler.update(found_inf_flag)
 
-        # If we found inf/nan, skip the update.
-        if found_inf_flag:
-            return False, None, None
+        # FastMoE: the early return on overflow is moved below, after gradient clipping.
+        # if found_inf_flag:
+        #     return False, None, None
 
         # Clip the main gradients.
         timers('optimizer-clip-main-grad').start()
         grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
+
+        # Remove the if branch to avoid a dead-lock in FastMoE (clip_grad_norm is a collective call).
+        # if self.clip_grad > 0.0:
+        #     grad_norm = self.clip_grad_norm(self.clip_grad)
+        grad_norm = self.clip_grad_norm(self.clip_grad)
+
         timers('optimizer-clip-main-grad').stop()
 
+        # Early return moved here to avoid the dead-lock in FastMoE.
+        # If we found inf/nan, skip the update.
+        if found_inf_flag:
+            return False, None, None
+
         # count the zeros in the grads
         num_zeros_in_grad = self.count_zeros() if \
             self.log_num_zeros_in_grad else None
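
# ---------------------------------------------------------------------------
# Note (not part of the patch): a sketch of the control flow the reordering in
# optimizer.py aims for. clip_grad_norm now always performs collective
# communication, so every rank has to call it before any rank may bail out on
# an overflow; only the rank-local parameter update is skipped afterwards.
# The function below is illustrative only, not FastMoE's implementation.
# ---------------------------------------------------------------------------
def fp16_step_sketch(found_inf_flag, clip_grad_norm, clip_grad, apply_update):
    # Collective work first: all ranks reach the all_reduce inside clip_grad_norm.
    grad_norm = clip_grad_norm(clip_grad)
    # Rank-local decision afterwards: skip the update if an inf/nan was found.
    if found_inf_flag:
        return False, None, None
    apply_update()
    return True, grad_norm, None
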
diff --git a/megatron/schedules.py b/megatron/schedules.py
index d346c30..8eef46c 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -23,7 +23,11 @@
 from megatron import get_timers
 from megatron import mpu
 from megatron import p2p_communication
 from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
 from megatron.model import Float16Module
 
 def get_forward_backward_func():
@@ -54,7 +58,8 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
     unwrapped_model = unwrap_model(
         model, (torchDDP, LocalDDP, Float16Module))
     unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor, loss_func = forward_step_func(data_iterator, model)
+    output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+    bal_loss = bal_loss / get_num_microbatches()
     if mpu.is_pipeline_last_stage():
         output_tensor = loss_func(output_tensor)
         loss, loss_reduced = output_tensor
@@ -62,10 +67,10 @@
         losses_reduced.append(loss_reduced)
     timers('forward-compute').stop()
 
-    return output_tensor
+    return output_tensor, bal_loss
 
 
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
     """Backward step through passed-in output tensor.
 
     If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -85,7 +90,9 @@
     # Backward pass.
     if output_tensor_grad is None:
         output_tensor = optimizer.scale_loss(output_tensor)
-    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+    else:
+        torch.autograd.backward([output_tensor, bal_loss], grad_tensors=[output_tensor_grad, None])
 
     # Collect the grad of the input_tensor.
     input_tensor_grad = None
@@ -122,18 +129,18 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     input_tensor, output_tensor_grad = None, None
     with context_handler():
         for i in range(get_num_microbatches() - 1):
-            output_tensor = forward_step(forward_step_func, data_iterator, model,
+            output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                          input_tensor, losses_reduced)
             if not forward_only:
                 backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
-    output_tensor = forward_step(forward_step_func, data_iterator, model,
+    output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                  input_tensor, losses_reduced)
     if not forward_only:
-        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+        backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
 
     return losses_reduced
@@ -144,6 +151,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     communication between pipeline stages as needed.
 
     Returns dictionary with losses if the last stage, empty dict otherwise."""
+    # FastMoE TODO
+    assert False, "FastMoE does not support pipeline with interleaving"
+
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     losses_reduced = []
@@ -385,17 +395,19 @@
     input_tensors = []
     output_tensors = []
+    bal_losses = []
     losses_reduced = []
 
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
         input_tensor = p2p_communication.recv_forward(timers=timers)
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
+        output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         p2p_communication.send_forward(output_tensor, timers=timers)
 
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
+        bal_losses.append(bal_loss)
 
     # Before running 1F1B, need to receive first forward tensor.
     # If all microbatches are run in warmup / cooldown phase, then no need to
@@ -407,7 +419,7 @@
     for i in range(num_microbatches_remaining):
         last_iteration = (i == (num_microbatches_remaining - 1))
 
-        output_tensor = forward_step(forward_step_func, data_iterator, model,
+        output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
             p2p_communication.send_forward(output_tensor, timers=timers)
@@ -420,16 +432,17 @@
        # start of the list for backward pass.
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
+        bal_losses.append(bal_loss)
 
        if forward_only:
            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(timers=timers)
        else:
-            input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+            input_tensor, output_tensor, bal_loss = input_tensors.pop(0), output_tensors.pop(0), bal_losses.pop(0)
 
            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
            if last_iteration:
                input_tensor = None
@@ -444,12 +457,13 @@
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)
+            bal_loss = bal_losses.pop(0)
 
            output_tensor_grad = p2p_communication.recv_backward(timers=timers)
 
            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
-                              output_tensor_grad)
+                              output_tensor_grad, bal_loss)
 
            p2p_communication.send_backward(input_tensor_grad, timers=timers)
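
# ---------------------------------------------------------------------------
# Note (not part of the patch): a small, single-process illustration of the
# two-tensor backward used in backward_step above. The tensors are stand-ins,
# not Megatron objects: a non-scalar stage output gets an explicit incoming
# gradient, while the scalar balance loss passes None for its gradient.
# ---------------------------------------------------------------------------
import torch

x = torch.randn(4, requires_grad=True)
output_tensor = x * 2.0                              # stand-in for a stage output
bal_loss = (x ** 2).sum()                            # stand-in for the balance loss
output_tensor_grad = torch.ones_like(output_tensor)  # gradient from the next stage

torch.autograd.backward([output_tensor, bal_loss],
                        grad_tensors=[output_tensor_grad, None])
print(x.grad)  # 2.0 from the output term plus 2*x from the balance term
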
diff --git a/megatron/training.py b/megatron/training.py
index 1ab57e9..fbe2fe8 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,14 +35,23 @@
 from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
 from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
 from megatron.model import Float16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
@@ -107,6 +116,13 @@ def pretrain(train_valid_test_dataset_provider,
     args = get_args()
     timers = get_timers()
 
+    # Initialize FastMoE
+    if args.fmoefy:
+        from fmoe.megatron import patch_forward_step, patch_model_provider
+
+        forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v2.5")
+        model_provider = patch_model_provider(model_provider, Megatron_Version="v2.5")
+
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -386,10 +402,12 @@
        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            if args.DDP_impl == 'local':
-                grad = word_embeddings_weight.main_grad
-            else:
-                grad = word_embeddings_weight.grad
+            grad = word_embeddings_weight.grad
+            # FastMoE does not have main_grad field
+            # if args.DDP_impl == 'local':
+            #     grad = word_embeddings_weight.main_grad
+            # else:
+            #     grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()
@@ -458,26 +476,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
 
     # Logging.
     timers_to_log = []
 
-    def add_to_logging(name):
-        if name in timers.timers:
+    # FastMoE adds several timers of its own.
+    # For simplicity, log all registered timers.
+    def add_all():
+        for name in timers.timers:
             timers_to_log.append(name)
-    add_to_logging('forward-compute')
-    add_to_logging('forward-recv')
-    add_to_logging('forward-send')
-    add_to_logging('forward-backward-send-forward-backward-recv')
-    add_to_logging('backward-compute')
-    add_to_logging('backward-recv')
-    add_to_logging('backward-send')
-    add_to_logging('backward-send-forward-recv')
-    add_to_logging('backward-send-backward-recv')
-    add_to_logging('backward-params-all-reduce')
-    add_to_logging('backward-embedding-all-reduce')
-    add_to_logging('optimizer-copy-to-main-grad')
-    add_to_logging('optimizer-unscale-and-check-inf')
-    add_to_logging('optimizer-clip-main-grad')
-    add_to_logging('optimizer-copy-main-to-model-params')
-    add_to_logging('optimizer')
-    add_to_logging('batch-generator')
+
+    add_all()
 
     # Calculate batch size.
     batch_size = args.micro_batch_size * args.data_parallel_size * \
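
# ---------------------------------------------------------------------------
# Note (not part of the patch): with the patched schedules.py, forward_step_func
# is expected to return a third value, the load-balancing loss; when --fmoefy is
# set, fmoe.megatron.patch_forward_step takes care of this. As an illustration
# only, a plain (non-MoE) forward_step could be adapted by hand as below.
# `wrap_forward_step` is a hypothetical helper, not a FastMoE API.
# ---------------------------------------------------------------------------
import torch

def wrap_forward_step(orig_forward_step):
    """Adapt a (output_tensor, loss_func) forward_step to the patched
    (output_tensor, loss_func, bal_loss) contract."""
    def wrapped(data_iterator, model):
        output_tensor, loss_func = orig_forward_step(data_iterator, model)
        # No MoE layers here, so the balance loss is a zero scalar.
        bal_loss = torch.zeros([], device=output_tensor.device, requires_grad=True)
        return output_tensor, loss_func, bal_loss
    return wrapped
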