"tests/vscode:/vscode.git/clone" did not exist on "09b7bfce91d70533e0b7becba523a36fdd6549dc"
Commit a6d202a6 authored by Jiezhong Qiu, committed by GitHub

Megatron v3.0.2 Patch (#159)

* support megatron v3.0.2

* keep num_experts for lower version of megatron
parent d56522bc
@@ -59,7 +59,7 @@ multiple experts.
 model = ...
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=<number of experts per worker>)
+model = fmoefy(model, fmoe_num_experts=<number of experts per worker>)
 train(model, ...)
 ```
@@ -50,7 +50,7 @@ turns the Transformer model into an MoE model. It is used as follows.
 model = ...
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=<number of experts per worker>)
+model = fmoefy(model, fmoe_num_experts=<number of experts per worker>)
 train(model, ...)
 ```
@@ -25,7 +25,7 @@ transformer language models.
 ```python
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=4)
+model = fmoefy(model, fmoe_num_experts=4)
 ```
 Note that the `fmoefy` function currently only takes a standard Megatron-LM's
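Taken together, the three documentation hunks above reduce to the following call pattern. This is a sketch in the docs' own placeholder style; the per-worker expert count of 4 is illustrative.

```python
from fmoe.megatron import fmoefy

model = ...  # a standard Megatron-LM model
# The keyword was renamed from num_experts to fmoe_num_experts in this patch.
model = fmoefy(model, fmoe_num_experts=4)
train(model, ...)
```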
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 102e890..c3504bd 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@ import os
import torch
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -43,6 +46,9 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
+ # FastMoE arguments.
+ parser = _add_fmoe_args(parser)
+
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
@@ -316,6 +322,12 @@ def parse_args(extra_args_provider=None, defaults={},
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
+ # If fmoe_num_experts is not specified, we are running against an older
+ # version of Megatron, so copy num_experts to fmoe_num_experts.
+ if not hasattr(args, 'fmoe_num_experts'):
+ args.fmoe_num_experts = args.num_experts
+
_print_args(args)
return args
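The fallback added to `parse_args` can be exercised on its own. A minimal sketch, assuming an older Megatron whose parser only ever registered `--num-experts`:

```python
import argparse

# Emulate the namespace an older Megatron would hand back (no fmoe_num_experts).
args = argparse.Namespace(num_experts=8)

# The shim above: mirror the value so downstream FastMoE code can always read
# args.fmoe_num_experts, whatever the Megatron version.
if not hasattr(args, 'fmoe_num_experts'):
    args.fmoe_num_experts = args.num_experts

print(args.fmoe_num_experts)  # 8
```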
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index ceba352..01754d0 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -124,6 +124,10 @@ def read_metadata(tracker_filename):
sys.exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
+
+ args = get_args()
+ if args.fmoefy:
+ return iteration, release
# Get the max iteration retrieved across the ranks.
iters_cuda = torch.cuda.LongTensor([iteration])
@@ -134,6 +138,7 @@ def read_metadata(tracker_filename):
# If not, print a warning and chose the maximum
# iteration across all ranks.
if iteration != max_iter:
+ rank = torch.distributed.get_rank()
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
@@ -399,7 +404,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
- except KeyError:
+ except KeyError as e:
+ print(e)
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 2f6e1b8..e2483db 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
- 6: np.float,
+ 6: np.float32,
7: np.double,
8: np.uint16
}
@@ -268,7 +268,7 @@ class IndexedDatasetBuilder(object):
np.int16: 2,
np.int32: 4,
np.int64: 8,
- np.float: 4,
+ np.float32: 4,
np.double: 8
}
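`np.float` was only an alias for Python's 64-bit `float` (and is removed in NumPy >= 1.24), while the builder's element-size table records 4 bytes for dtype code 6, so `np.float32` is the dtype that actually matches. A quick check:

```python
import numpy as np

# Dtype code 6 is written with an element size of 4 bytes by IndexedDatasetBuilder,
# which matches float32, not the 8-byte double that np.float used to alias.
assert np.dtype(np.float32).itemsize == 4
assert np.dtype(np.double).itemsize == 8
```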
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index d8bee27..6f4ecfb 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -101,8 +101,9 @@ def get_megatron_optimizer(model,
# Determine whether the params have main-grad field.
params_have_main_grad = False
- if args.DDP_impl == 'local':
- params_have_main_grad = True
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # params_have_main_grad = True
if args.fp16 or args.bf16:
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 36cd915..7b4eaaa 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,6 +54,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
+ # FastMoE
+ grads_in_moe = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
@@ -65,7 +67,11 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grads_for_norm.append(grad)
+ # FastMoE
+ if hasattr(param, 'dp_comm') and param.dp_comm in ('none',):
+ grads_in_moe.append(grad)
+ else:
+ grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
@@ -74,6 +80,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Calculate norm.
if norm_type == inf:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
@@ -98,7 +106,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
+ # FastMoE
+ if len(grads_in_moe) > 0:  # 'cold' experts may not have any grads in one iteration
+ grad_norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads_in_moe],
+ False # no per-parameter norm
+ )
+ grad_norm = grad_norm ** norm_type
+ torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+ total_norm += grad_norm
else:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
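The split above hinges on the `dp_comm` attribute that FastMoE attaches to expert parameters. A stripped-down sketch of the bucketing, with the shared-parameter and tensor-parallel-duplicate filters of the real function omitted:

```python
def split_grads_by_dp_comm(parameters):
    """Bucket gradients the way the patched clip_grad_norm_fp32 does: expert
    parameters (tagged dp_comm == 'none' by FastMoE) have their norm reduced
    separately from ordinary, replicated parameters."""
    grads_for_norm, grads_in_moe = [], []
    for param in parameters:
        if param.grad is None:
            continue
        if getattr(param, 'dp_comm', None) == 'none':
            grads_in_moe.append(param.grad.detach())
        else:
            grads_for_norm.append(param.grad.detach())
    return grads_for_norm, grads_in_moe
```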
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index d6ac42e..7eecff4 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -257,6 +257,9 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ # FastMoE
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
@@ -411,18 +414,27 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
-
+
+ # move to L433-L436
# If we found inf/nan, skip the update.
- if found_inf_flag:
- return False, None, None
+ # if found_inf_flag:
+ # return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
- if self.clip_grad > 0.0:
- grad_norm = self.clip_grad_norm(self.clip_grad)
+
+ # remove if branch to avoid dead-lock in FastMoE
+ # if self.clip_grad > 0.0:
+ # grad_norm = self.clip_grad_norm(self.clip_grad)
+ grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
+ # move early return to here to avoid dead-lock in FastMoE
+ # If we found inf/nan, skip the update.
+ if found_inf_flag:
+ return False, None, None
+
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
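The reordering above follows the usual rule for collectives: every rank has to reach the all-reduce that now lives inside gradient clipping, so the inf/nan early return cannot come before it. A minimal sketch of the resulting control flow; the helper names here are placeholders, not the Megatron API:

```python
def fp16_step_order_sketch(found_inf_flag, clip_grad_norm, clip_grad):
    # Every rank runs the clipping path (it contains collective communication) ...
    grad_norm = clip_grad_norm(clip_grad)
    # ... and only afterwards decides locally whether to skip the update.
    if found_inf_flag:
        return False, None, None
    return True, grad_norm, None
```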
diff --git a/megatron/schedules.py b/megatron/schedules.py
index ac5ba6f..26b717a 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -24,7 +24,10 @@ from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
@@ -66,7 +69,7 @@ def deallocate_output_tensor(out):
dtype = out.dtype,
)
-def custom_backward(output, grad_output):
+def custom_backward(output, grad_output, bal_loss):
'''Directly call C++ autograd engine.
To make the 'deallocate_output_tensor' (above) optimization work, the C++
@@ -89,11 +92,16 @@ def custom_backward(output, grad_output):
output,
memory_format = torch.preserve_format,
)
+ tensors = (output,)
+ grad_tensors = (grad_output,)
+ else:
+ tensors = (output, bal_loss)
+ grad_tensors = (grad_output, None)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
- tensors = (output,),
- grad_tensors = (grad_output,),
+ tensors = tensors,
+ grad_tensors = grad_tensors,
keep_graph = False,
create_graph = False,
inputs = tuple(),
@@ -127,7 +135,8 @@ def forward_step(forward_step_func,
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor)
- output_tensor, loss_func = forward_step_func(data_iterator, model)
+ output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+ bal_loss = bal_loss / get_num_microbatches()
if mpu.is_pipeline_last_stage():
if not collect_non_loss_data:
output_tensor = loss_func(output_tensor)
@@ -145,13 +154,14 @@ def forward_step(forward_step_func,
# downstream as well.
if mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
+ assert False, f"encoder-decoder model is not supported by FastMoE "
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
- return output_tensor
- return [output_tensor]
+ return output_tensor, bal_loss
+ return [output_tensor, bal_loss]
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -185,7 +195,7 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass.
if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0])
- custom_backward(output_tensor[0], output_tensor_grad[0])
+ custom_backward(output_tensor[0], output_tensor_grad[0], bal_loss)
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
@@ -241,20 +251,20 @@ def forward_backward_no_pipelining(forward_step_func,
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
- output_tensor = forward_step(forward_step_func, data_iterator,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
- output_tensor = forward_step(forward_step_func, data_iterator,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
- backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+ backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
return forward_data_store
@@ -269,6 +279,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
+ # FastMoE TODO
+ assert False, "FastMoE not supports pipeline with interleaving"
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
forward_data_store = []
@@ -646,15 +658,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
+ bal_losses = None
if not forward_only:
input_tensors = []
output_tensors = []
+ bal_losses = []
forward_data_store = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
@@ -662,6 +676,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
deallocate_output_tensor(output_tensor[0])
# Before running 1F1B, need to receive first forward tensor.
@@ -674,7 +689,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
if forward_only:
@@ -692,16 +707,18 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
deallocate_output_tensor(output_tensor[0])
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_losses.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
if last_iteration:
input_tensor = None
@@ -716,12 +733,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func,
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_losses.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
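`custom_backward` now feeds two roots into the autograd engine: the stage output (with the gradient received from downstream) and the balance loss (with `None`, i.e. an implicit gradient of 1 for a scalar). The same effect can be reproduced with plain `torch.autograd.backward`; a self-contained sketch with toy tensors:

```python
import torch

x = torch.randn(4, requires_grad=True)
output = (x * 2).sum()        # stands in for the pipeline stage's output
bal_loss = (x ** 2).mean()    # stands in for the accumulated gate balance loss

# Backward through both roots in one call; None means "use 1.0" for the scalar loss.
torch.autograd.backward(
    tensors=(output, bal_loss),
    grad_tensors=(torch.ones_like(output), None),
)
print(x.grad)
```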
diff --git a/megatron/training.py b/megatron/training.py
index 023bdf1..caefb88 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -36,8 +36,13 @@ from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
@@ -45,7 +50,11 @@ from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
@@ -119,6 +128,13 @@ def pretrain(train_valid_test_dataset_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v3.0.2")
+ model_provider = patch_model_provider(model_provider, Megatron_Version='v3.0.2')
+
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
@@ -466,10 +482,12 @@ def train_step(forward_step_func, data_iterator,
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
- if args.DDP_impl == 'local':
- grad = word_embeddings_weight.main_grad
- else:
- grad = word_embeddings_weight.grad
+ grad = word_embeddings_weight.grad
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # grad = word_embeddings_weight.main_grad
+ # else:
+ # grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
@@ -568,26 +586,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Logging.
timers_to_log = []
- def add_to_logging(name):
- if name in timers.timers:
- timers_to_log.append(name)
- add_to_logging('forward-compute')
- add_to_logging('forward-recv')
- add_to_logging('forward-send')
- add_to_logging('forward-backward-send-forward-backward-recv')
- add_to_logging('backward-compute')
- add_to_logging('backward-recv')
- add_to_logging('backward-send')
- add_to_logging('backward-send-forward-recv')
- add_to_logging('backward-send-backward-recv')
- add_to_logging('backward-params-all-reduce')
- add_to_logging('backward-embedding-all-reduce')
- add_to_logging('optimizer-copy-to-main-grad')
- add_to_logging('optimizer-unscale-and-check-inf')
- add_to_logging('optimizer-clip-main-grad')
- add_to_logging('optimizer-copy-main-to-model-params')
- add_to_logging('optimizer')
- add_to_logging('batch-generator')
+ # FastMoE add several timers.
+ # For simplicity, add all timers to log.
+ def add_all():
+ for name in timers.timers:
+ timers_to_log.append(name)
+
+ add_all()
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
--
2.25.1
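The hunks below this patch touch FastMoE's own `fmoe.megatron` helpers (checkpointing, `fmoefy`, the version-specific patches, and argument parsing). Before that, a sketch of how the patched `training.py` wires FastMoE in; it mirrors the `pretrain` hunk above, and only `--fmoefy` and the `Megatron_Version` string come from this commit:

```python
from fmoe.megatron import patch_forward_step, patch_model_provider

def enable_fastmoe(forward_step_func, model_provider, args):
    """Apply the same wrapping the patched pretrain() does when --fmoefy is set."""
    if args.fmoefy:
        forward_step_func = patch_forward_step(
            forward_step_func, Megatron_Version="v3.0.2")
        model_provider = patch_model_provider(
            model_provider, Megatron_Version="v3.0.2")
    return forward_step_func, model_provider
```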
@@ -54,6 +54,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
+    from megatron import utils

     expert_dp_comm = "none"
@@ -67,8 +68,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if hasattr(model, 'module'):
-        model = model.module
+    try:
+        model = utils.unwrap_model(model)
+    except AttributeError:
+        # fallback to the old way of unwrapping a model
+        if hasattr(model, 'module'):
+            model = model.module
+        model = [model,]

     print_rank_last(
         "saving checkpoint at iteration {:7d} to {}".format(iteration, args.save)
@@ -76,7 +82,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):

     # Arguments, iteration, and model.
     state_dict = {}
-    state_dict["model"] = model.state_dict_for_save_checkpoint(
+    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
+    state_dict["model"] = model[0].state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0)
     )
@@ -215,6 +222,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
+    from megatron import utils
     from megatron.checkpointing import get_checkpoint_tracker_filename
     from megatron.checkpointing import set_checkpoint_version
     from megatron.checkpointing import check_checkpoint_args
@@ -229,8 +237,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if hasattr(model, 'module'):
-        model = model.module
+    # Only rank zero of the data parallel writes to the disk.
+    try:
+        model = utils.unwrap_model(model)
+    except AttributeError:
+        # fallback to the old way of unwrapping a model
+        if hasattr(model, 'module'):
+            model = model.module
+        model = [model,]

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -341,7 +356,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
             print_rank_last("could not find arguments in the checkpoint ...")

     # Model.
-    model.load_state_dict(state_dict["model"])
+    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
+    model[0].load_state_dict(state_dict["model"])

     # Optimizer.
     if not release and not args.finetune and not args.no_load_optim:
@@ -350,9 +366,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
             optimizer.load_state_dict(state_dict["optimizer"])
             if lr_scheduler is not None:
                 lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
-        except KeyError:
+        except KeyError as e:
             print_rank_last(
-                "Unable to load optimizer from checkpoint {}. "
+                "FMoE is unable to load optimizer from checkpoint {}. "
                 "Specify --no-load-optim or --finetune to prevent "
                 "attempting to load the optimizer state, "
                 "exiting ...".format(checkpoint_name_local)
@@ -367,9 +383,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
             torch.set_rng_state(state_dict["torch_rng_state"])
             torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
             mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
-        except KeyError:
+        except KeyError as e:
+            print_rank_last(e)
             print_rank_last(
-                "Unable to load optimizer from checkpoint {}. "
+                "FMoE is unable to load rng state from checkpoint {}. "
                 "Specify --no-load-rng or --finetune to prevent "
                 "attempting to load the optimizer state, "
                 "exiting ...".format(checkpoint_name_local)
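The try/except around `utils.unwrap_model` is what lets the same checkpoint code run both on newer Megatron, where the helper returns a list of bare modules, and on older releases without it. A compact sketch of that fallback (the function name here is illustrative):

```python
def unwrap_for_checkpoint(model, utils=None):
    """Return a list of unwrapped modules, as the patched save/load code expects.

    `utils` stands for megatron.utils when it provides unwrap_model; passing an
    older module (or None) triggers the legacy single-module fallback.
    """
    try:
        return utils.unwrap_model(model)
    except AttributeError:
        if hasattr(model, 'module'):
            model = model.module
        return [model]
```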
@@ -102,7 +102,7 @@ class MegatronMLP(FMoETransformerMLP):
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)

         super().__init__(
-            args.num_experts,
+            args.fmoe_num_experts,
             top_k=args.top_k,
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
@@ -110,7 +110,7 @@ class MegatronMLP(FMoETransformerMLP):
             moe_group=moe_group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
             gate_hook=generate_megatron_gate_hook(
-                layer_idx, args.num_experts * world_size
+                layer_idx, args.fmoe_num_experts * world_size
             ),
             gate=gate,
         )
@@ -157,7 +157,7 @@ class MegatronMLP(FMoETransformerMLP):

 def fmoefy(
     model,
-    num_experts=None,
+    fmoe_num_experts=None,
     distributed_experts=True,
     hidden_hidden_size=None,
     top_k=None,
@@ -183,11 +183,11 @@ def fmoefy(
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
-    if num_experts is not None:
-        args.num_experts = num_experts
+    if fmoe_num_experts is not None:
+        args.fmoe_num_experts = fmoe_num_experts
     assert (
-        "num_experts" in args
-    ), "num_experts should be specified in arguments or fmoefy function"
+        "fmoe_num_experts" in args
+    ), "fmoe_num_experts should be specified in arguments or fmoefy function"
     if top_k is not None:
         args.top_k = top_k
@@ -203,19 +203,20 @@ def fmoefy(
         # initialize gate hook
         num_layers = len(model.language_model.transformer.layers)
-    elif megatron_version == "v2.5":
+    elif megatron_version in ["v2.5", "v3.0.2"]:
         for idx, l in enumerate(model.language_model.encoder.layers):
             l.mlp = MegatronMLP(args, idx, gate=gate)
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
             for idx, l in enumerate(model.language_model.decoder.layers):
                 l.mlp = MegatronMLP(args, idx, gate=gate)

         # initialize gate hook
         num_layers = len(model.language_model.encoder.layers)
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
             num_layers += len(model.language_model.decoder.layers)
     else:
+        print(model.language_model)
         assert False, f"megatron_version {megatron_version} not known."

     reset_gate_hook(num_layers)
@@ -30,7 +30,7 @@ def patch_loss_func_v2_5(loss_func):
                      for l in model.language_model.encoder.layers
                      if l.mlp.gate.has_loss]
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
             loss_list_decoder = [l.mlp.gate.get_loss(clear=False).view(1)
                                  for l in model.language_model.decoder.layers
                                  if l.mlp.gate.has_loss]
@@ -125,6 +125,8 @@ def patch_forward_step(forward_step_func, Megatron_Version="v2.2"):
         return forward_step_with_balance_loss_v2_2
     elif Megatron_Version == "v2.5":
         return forward_step_with_balance_loss_v2_5
+    elif Megatron_Version == "v3.0.2":
+        return forward_step_with_balance_loss_v2_5
     else:
         assert False, f"megatron version {Megatron_Version} not known."
@@ -143,7 +145,7 @@ def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
         hhs = hhs // args.tensor_model_parallel_size
         return fmoefy(
             model_provider(),
-            num_experts=args.num_experts,
+            fmoe_num_experts=args.fmoe_num_experts,
             hidden_hidden_size=hhs,
             top_k=args.top_k,
             gate=gate,
@@ -160,16 +162,35 @@ def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
         hhs = hhs // args.tensor_model_parallel_size
         return fmoefy(
             model_provider(pre_process=pre_process, post_process=post_process),
-            num_experts=args.num_experts,
+            fmoe_num_experts=args.fmoe_num_experts,
             hidden_hidden_size=hhs,
             top_k=args.top_k,
             gate=gate,
             megatron_version="v2.5"
         )

+    def fmoefied_model_provider_v3_0_2(pre_process, post_process):
+        from .layers import fmoefy
+        args = get_args()
+        hhs = args.hidden_size * 4
+        assert hhs % args.top_k == 0
+        hhs = hhs // args.top_k
+        assert hhs % args.tensor_model_parallel_size == 0
+        hhs = hhs // args.tensor_model_parallel_size
+        return fmoefy(
+            model_provider(pre_process=pre_process, post_process=post_process),
+            fmoe_num_experts=args.fmoe_num_experts,
+            hidden_hidden_size=hhs,
+            top_k=args.top_k,
+            gate=gate,
+            megatron_version="v3.0.2"
+        )
+
     if Megatron_Version == 'v2.2':
         return fmoefied_model_provider_v2_2
     elif Megatron_Version == 'v2.5':
         return fmoefied_model_provider_v2_5
+    elif Megatron_Version == 'v3.0.2':
+        return fmoefied_model_provider_v3_0_2
     else:
         assert False, f"Megatron Version {Megatron_Version} unknown."
 r"""
 Utility in Megatron
 """
+import argparse

 def add_fmoe_args(parser):
     group = parser.add_argument_group(title="fastmoe")
     group.add_argument("--fmoefy", action="store_true")
-    group.add_argument("--num-experts", type=int, default=None)
+    try:
+        group.add_argument("--num-experts", type=int, default=None)
+    except argparse.ArgumentError:
+        group.add_argument("--fmoe-num-experts", type=int, default=None)
     group.add_argument("--top-k", type=int, default=2)
     group.add_argument("--balance-loss-weight", type=float, default=1)
     group.add_argument("--balance-strategy", type=str, default=None)
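The try/except in `add_fmoe_args` relies on `argparse` raising `ArgumentError` for a duplicate option string. That behaviour can be demonstrated without FastMoE; here the host parser stands in for a newer Megatron that already owns `--num-experts`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-experts", type=int, default=None)  # pretend Megatron defined it

group = parser.add_argument_group(title="fastmoe")
try:
    group.add_argument("--num-experts", type=int, default=None)
except argparse.ArgumentError:
    # Duplicate flag: fall back to FastMoE's own option, as in the hunk above.
    group.add_argument("--fmoe-num-experts", type=int, default=None)

args = parser.parse_args(["--num-experts", "8"])
print(args.num_experts, args.fmoe_num_experts)  # 8 None; arguments.py then copies it over
```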