Commit b31b7e3f authored by zms1999

Patch for Megatron v2.5

parent b8d0e81f
......@@ -235,6 +235,7 @@ void fmoe_cuda_fused_forward_impl(
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
smgr->sync(1);
delete [] local_ptr;
delete [] global_ptr;
......@@ -381,6 +382,7 @@ void fmoe_cuda_fused_backward_impl(
}
}
smgr->sync(1);
checkCudaErrors(cudaGetLastError());
delete [] local_ptr;
......
diff --git a/megatron/arguments.py b/megatron/arguments.py
index b35af1d..2a55699 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@ import os
import torch
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -42,6 +45,9 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_vit_args(parser)
parser = _add_logging_args(parser)
+ # FastMoE arguments.
+ parser = _add_fmoe_args(parser)
+
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 1251066..32afb2f 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@ dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
- 6: np.float,
+ 6: np.float32,
7: np.double,
8: np.uint16
}
@@ -268,7 +268,7 @@ class IndexedDatasetBuilder(object):
np.int16: 2,
np.int32: 4,
np.int64: 8,
- np.float: 4,
+ np.float32: 4,
np.double: 8
}
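
Note on the dtype fix above: `np.float` was only an alias for Python's built-in `float` (64-bit) and has since been removed from NumPy, while the builder's element-size table records 4 bytes for code 6, so `np.float32` is the dtype that actually matches. A quick sanity check, assuming any recent NumPy:

```python
import numpy as np

# Code 6 is written with a 4-byte element size in IndexedDatasetBuilder,
# so it must decode as a 32-bit float; np.double (code 7) stays at 8 bytes.
assert np.dtype(np.float32).itemsize == 4
assert np.dtype(np.double).itemsize == 8
```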
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index 823a51f..32f4b2e 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -69,8 +69,10 @@ def get_megatron_optimizer(model):
# Determine whether the params have main-grad field.
params_have_main_grad = False
- if args.DDP_impl == 'local':
- params_have_main_grad = True
+
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # params_have_main_grad = True
if args.fp16 or args.bf16:
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 036a1d4..81d5bd9 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,17 +54,23 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
+ # FastMoE
+ grads_in_moe = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
- grad = param.grad.detach()
if grad_not_none:
+ grad = param.grad.detach()
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grads_for_norm.append(grad)
+ # FastMoE
+ if hasattr(param, 'dp_comm') and param.dp_comm in ('none',):
+ grads_in_moe.append(grad)
+ else:
+ grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
@@ -73,6 +79,8 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Calculate norm.
if norm_type == inf:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
@@ -97,7 +105,20 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
+ # FastMoE
+ if len(grads_in_moe) > 0 : # 'cold' experts may not have any grads in one iteration
+ grad_norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads_in_moe],
+ False # no per-parameter norm
+ )
+ grad_norm = grad_norm ** norm_type
+ torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+ total_norm += grad_norm
else:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
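
The hunk above splits gradients into regular ones and expert-parallel ones (flagged via `dp_comm`): expert gradients differ per rank, so their squared norms get an extra all-reduce over the model-parallel group before joining the total. A rough standalone sketch of that combination, with illustrative names (not the patched Megatron API):

```python
import torch
import torch.distributed as dist

def combined_sq_grad_norm(grads_for_norm, grads_in_moe, moe_group, norm_type=2.0):
    """Sketch: sum of |g|^norm_type over regular and expert gradients."""
    device = (grads_for_norm or grads_in_moe)[0].device
    total = torch.zeros(1, device=device)

    # Regular (non-expert) gradients contribute locally, as in stock Megatron.
    for g in grads_for_norm:
        total += torch.norm(g, norm_type) ** norm_type

    # Expert gradients differ per rank, so their contribution is summed
    # across the group that shards the experts before being folded in.
    if grads_in_moe:
        moe = torch.zeros(1, device=device)
        for g in grads_in_moe:
            moe += torch.norm(g, norm_type) ** norm_type
        dist.all_reduce(moe, op=dist.ReduceOp.SUM, group=moe_group)
        total += moe

    # As in stock Megatron, the caller takes the norm_type-th root afterwards.
    return total
```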
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 368f587..080b06f 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -250,6 +250,9 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ # FastMoE
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
@@ -396,17 +399,26 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
- # If we found inf/nan, skip the update.
- if found_inf_flag:
- return False, None, None
+ # move to L417.
+ # if found_inf_flag:
+ # return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
- if self.clip_grad > 0.0:
- grad_norm = self.clip_grad_norm(self.clip_grad)
+
+ # remove if branch to avoid dead-lock in FastMoE
+ # if self.clip_grad > 0.0:
+ # grad_norm = self.clip_grad_norm(self.clip_grad)
+ grad_norm = self.clip_grad_norm(self.clip_grad)
+
timers('optimizer-clip-main-grad').stop()
+ # Moved the early return here to avoid a dead-lock in FastMoE.
+ # If we found inf/nan, skip the update.
+ if found_inf_flag:
+ return False, None, None
+
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
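
Per the comments above, the clipping path now contains collectives (the all-reduce over expert gradients), so every rank has to enter it even when the loss scaler saw an inf/nan; the early return is therefore postponed until after clipping. A schematic of the reordered control flow, with placeholder names:

```python
# Schematic only; names and the return convention loosely mirror the hunk above.
def step_after_unscale(found_inf_flag, clip_grad_norm, clip_grad):
    # Always run grad clipping first: it contains an all_reduce over the
    # MoE gradients, so skipping it on some ranks would hang the others.
    grad_norm = clip_grad_norm(clip_grad)

    # Only once every collective has completed is it safe to bail out.
    if found_inf_flag:
        return False, None, None

    return True, grad_norm, None
```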
diff --git a/megatron/schedules.py b/megatron/schedules.py
index d346c30..8eef46c 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -23,7 +23,11 @@ from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
from megatron.model import Float16Module
def get_forward_backward_func():
@@ -54,7 +58,8 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
- output_tensor, loss_func = forward_step_func(data_iterator, model)
+ output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+ bal_loss = bal_loss / get_num_microbatches()
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
@@ -62,10 +67,10 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
- return output_tensor
+ return output_tensor, bal_loss
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -85,7 +90,9 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
- torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+ torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+ else:
+ torch.autograd.backward([output_tensor,bal_loss], grad_tensors=[output_tensor_grad, None])
# Collect the grad of the input_tensor.
input_tensor_grad = None
@@ -122,18 +129,18 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
- backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+ backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
return losses_reduced
@@ -144,6 +151,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
+ # FastMoE TODO
+ assert False, "FastMoE not supports pipeline with interleaving"
+
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
@@ -385,17 +395,19 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensors = []
output_tensors = []
+ bal_losses = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers=timers)
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers=timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
@@ -407,7 +419,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers=timers)
@@ -420,16 +432,17 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers=timers)
else:
- input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+ input_tensor, output_tensor, bal_loss = input_tensors.pop(0), output_tensors.pop(0), bal_losses.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
if last_iteration:
input_tensor = None
@@ -444,12 +457,13 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_losses.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
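
On intermediate pipeline stages the gate balance loss has no path into `output_tensor`, so the patched `backward_step` backpropagates it as a second autograd root. A minimal standalone illustration of the multi-root `torch.autograd.backward` call used above:

```python
import torch

w = torch.randn(4, requires_grad=True)

output = w * 2.0                    # stands in for the stage's output activations
bal_loss = (w ** 2).sum() * 1e-2    # stands in for the gate balance loss

# Gradient received from the next pipeline stage for `output`.
output_grad = torch.ones_like(output)

# One backward pass over both roots: the activation root uses the received
# gradient, while the scalar balance loss may pass None for its entry.
torch.autograd.backward([output, bal_loss], grad_tensors=[output_grad, None])
print(w.grad)   # contains contributions from both roots
```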
diff --git a/megatron/training.py b/megatron/training.py
index 1ab57e9..fbe2fe8 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,14 +35,23 @@ from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
@@ -107,6 +116,13 @@ def pretrain(train_valid_test_dataset_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v2.5")
+ model_provider = patch_model_provider(model_provider, Megatron_Version='v2.5')
+
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -386,10 +402,12 @@ def train_step(forward_step_func, data_iterator,
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
- if args.DDP_impl == 'local':
- grad = word_embeddings_weight.main_grad
- else:
- grad = word_embeddings_weight.grad
+ grad = word_embeddings_weight.grad
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # grad = word_embeddings_weight.main_grad
+ # else:
+ # grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
@@ -458,26 +476,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Logging.
timers_to_log = []
- def add_to_logging(name):
- if name in timers.timers:
+ # FastMoE adds several timers.
+ # For simplicity, add all timers to log.
+ def add_all():
+ for name in timers.timers:
timers_to_log.append(name)
- add_to_logging('forward-compute')
- add_to_logging('forward-recv')
- add_to_logging('forward-send')
- add_to_logging('forward-backward-send-forward-backward-recv')
- add_to_logging('backward-compute')
- add_to_logging('backward-recv')
- add_to_logging('backward-send')
- add_to_logging('backward-send-forward-recv')
- add_to_logging('backward-send-backward-recv')
- add_to_logging('backward-params-all-reduce')
- add_to_logging('backward-embedding-all-reduce')
- add_to_logging('optimizer-copy-to-main-grad')
- add_to_logging('optimizer-unscale-and-check-inf')
- add_to_logging('optimizer-clip-main-grad')
- add_to_logging('optimizer-copy-main-to-model-params')
- add_to_logging('optimizer')
- add_to_logging('batch-generator')
+
+ add_all()
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
......@@ -42,7 +42,7 @@ class DistributedGroupedDataParallel(nn.Module):
if k not in self.comms:
self.comms[k] = get_torch_default_comm()
- def allreduce_params(no_scale=False,
+ def allreduce_gradients(no_scale=False,
reduce_after=False, fp32_allreduce=False):
groups = dict()
for p in self.module.parameters():
......@@ -74,6 +74,10 @@ class DistributedGroupedDataParallel(nn.Module):
for g, s in zip(grads, synced):
g.copy_(s)
def allreduce_params(*args, **kwargs):
return allreduce_gradients(*args, **kwargs)
self.allreduce_gradients = allreduce_gradients
self.allreduce_params = allreduce_params
if need_sync:
self._sync_params()
......
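
The rename above keeps the reduction logic under `allreduce_gradients` while re-exposing it as `allreduce_params`, presumably so both the old and the new Megatron call sites keep working. A condensed sketch of the aliasing pattern, outside the real class:

```python
class GradSyncWrapperSketch:
    """Toy stand-in for DistributedGroupedDataParallel's aliasing."""

    def __init__(self, module):
        self.module = module

        def allreduce_gradients(no_scale=False, reduce_after=False,
                                fp32_allreduce=False):
            # ... group parameters by their comm group and all-reduce grads ...
            pass

        def allreduce_params(*args, **kwargs):
            # Backward-compatible alias for callers using the old name.
            return allreduce_gradients(*args, **kwargs)

        self.allreduce_gradients = allreduce_gradients
        self.allreduce_params = allreduce_params
```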
......@@ -37,6 +37,9 @@ class NaiveGate(BaseGate):
# (BxL) x 1 x top_k
gate_score = F.softmax(gate_top_k_val, dim=-1)
# dummy loss
self.set_loss(torch.zeros(1, requires_grad=True).cuda())
if return_all_scores:
return gate_top_k_idx, gate_score, gate
return gate_top_k_idx, gate_score
......@@ -14,6 +14,10 @@ def _set_groups(**kwargs):
_groups = kwargs
def get_moe_group():
return _groups["moe_group"]
def _init():
from megatron import get_args
from megatron import mpu
......@@ -39,12 +43,21 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
is adapted to enable the sophisticated parallel and reduction strategies in
Fast MoE.
"""
- def __init__(self, module):
+ def __init__(self, module, accumulate_allreduce_grads_in_fp32=False, use_contiguous_buffers_in_ddp=False):
assert not accumulate_allreduce_grads_in_fp32, "FastMoE does not support accumulate_allreduce_grads_in_fp32"
assert not use_contiguous_buffers_in_ddp, "FastMoE does not support use_contiguous_buffers_in_ddp"
if _groups is None:
_init()
super().__init__(module, **_groups)
def set_input_tensor(self, *args, **kwargs):
r"""
Keep consistency with Megatron
"""
return self.module.set_input_tensor(*args, **kwargs)
def state_dict(self, *args, **kwargs):
r"""
Keep consistency with Megatron
......
......@@ -151,6 +151,7 @@ def fmoefy(
hidden_hidden_size=None,
top_k=None,
gate=None,
megatron_version=None
):
r"""
Replace MLP layers in a transformer-based model in Megatron by MoE.
......@@ -184,11 +185,28 @@ def fmoefy(
args.hidden_hidden_size = hidden_hidden_size
- for idx, l in enumerate(model.language_model.transformer.layers):
- l.mlp = MegatronMLP(args, idx, gate=gate)
if megatron_version == "v2.2":
for idx, l in enumerate(model.language_model.transformer.layers):
l.mlp = MegatronMLP(args, idx, gate=gate)
# initialize gate hook
num_layers = len(model.language_model.transformer.layers)
elif megatron_version == "v2.5":
for idx, l in enumerate(model.language_model.encoder.layers):
l.mlp = MegatronMLP(args, idx, gate=gate)
if hasattr(model.language_model, "decoder"):
for idx, l in enumerate(model.language_model.decoder.layers):
l.mlp = MegatronMLP(args, idx, gate=gate)
# initialize gate hook
num_layers = len(model.language_model.encoder.layers)
if hasattr(model.language_model, "decoder"):
num_layers += len(model.language_model.decoder.layers)
else:
assert False, f"megatron_version {megatron_version} not known."
- # initialize gate hook
- num_layers = len(model.language_model.transformer.layers)
reset_gate_hook(num_layers)
return model
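
With the new `megatron_version` switch, callers have to say which model layout `fmoefy` should walk. A hedged usage sketch (the argument values are placeholders, and `model` is whatever Megatron's model provider returned):

```python
from fmoe.megatron.layers import fmoefy

# Hypothetical values; in the real flow, patch_model_provider (below) derives
# num_experts, hidden_hidden_size and top_k from Megatron's parsed args.
model = fmoefy(
    model,
    num_experts=8,
    hidden_hidden_size=1024,
    top_k=2,
    gate=None,
    megatron_version="v2.5",  # "v2.2" walks .transformer, "v2.5" walks .encoder/.decoder
)
```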
......@@ -3,8 +3,65 @@ Patching some of Megatron-LM's functions to create an MoE model
"""
import torch
def patch_loss_func_v2_5(loss_func):
r"""
Patch model's loss_func to support balance loss
"""
from megatron.mpu import is_pipeline_last_stage
from megatron.mpu import get_tensor_model_parallel_group
from megatron import get_args
from megatron import get_num_microbatches
if not get_args().balance_strategy:
return loss_func
def loss_func_with_balance_loss(model, output_tensor):
args = get_args()
assert args.balance_strategy, "Only use the patched loss_func when a balance_strategy is set."
assert is_pipeline_last_stage(), "Only call loss_func at pipeline last stage."
output = loss_func(output_tensor)
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.encoder.layers
if l.mlp.gate.has_loss]
- def patch_forward_step(forward_step_func):
if hasattr(model.language_model, "decoder"):
loss_list_decoder = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.decoder.layers
if l.mlp.gate.has_loss]
loss_list.extend(loss_list_decoder)
if len(loss_list) == 0:
return output
loss_name = args.balance_strategy + "_loss"
(loss, state_dict), bal_loss = (
output,
torch.cat(loss_list).mean() * args.balance_loss_weight / args.pipeline_model_parallel_size
)
bal_loss = bal_loss / get_num_microbatches()
# average across the moe group
moe_group = get_tensor_model_parallel_group()
world_size = torch.distributed.get_world_size(group=moe_group)
averaged_bal_loss = bal_loss.clone().detach()
torch.distributed.all_reduce(averaged_bal_loss, group=moe_group)
averaged_bal_loss /= world_size
loss += bal_loss
state_dict[loss_name] = averaged_bal_loss
return loss, state_dict
return loss_func_with_balance_loss
+ def patch_forward_step(forward_step_func, Megatron_Version="v2.2"):
r"""
Patch model's forward_step_func to support balance loss
"""
......@@ -16,7 +73,7 @@ def patch_forward_step(forward_step_func):
if not get_args().balance_strategy:
return forward_step_func
- def forward_step_with_balance_loss(data_iterator, model, input_tensor):
+ def forward_step_with_balance_loss_v2_2(data_iterator, model, input_tensor):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
......@@ -50,13 +107,33 @@ def patch_forward_step(forward_step_func):
return loss, state_dict
- return forward_step_with_balance_loss
def forward_step_with_balance_loss_v2_5(data_iterator, model):
from functools import partial
output, loss_func = forward_step_func(data_iterator, model)
while hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False).view(1)
for l in model.language_model.encoder.layers
if l.mlp.gate.has_loss]
bal_loss = torch.cat(loss_list).mean() * get_args().balance_loss_weight / get_args().pipeline_model_parallel_size
return output, partial(patch_loss_func_v2_5(loss_func), model), bal_loss
if Megatron_Version == "v2.2":
return forward_step_with_balance_loss_v2_2
elif Megatron_Version == "v2.5":
return forward_step_with_balance_loss_v2_5
else:
assert False, f"megatron version {Megatron_Version} not known."
- def patch_model_provider(model_provider, gate=None):
+ def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
from megatron import get_args
- def fmoefied_model_provider():
+ def fmoefied_model_provider_v2_2():
from .layers import fmoefy
args = get_args()
hhs = args.hidden_size * 4
......@@ -69,7 +146,30 @@ def patch_model_provider(model_provider, gate=None):
num_experts=args.num_experts,
hidden_hidden_size=hhs,
top_k=args.top_k,
- gate=gate
+ gate=gate,
megatron_version="v2.2"
)
def fmoefied_model_provider_v2_5(pre_process, post_process):
from .layers import fmoefy
args = get_args()
hhs = args.hidden_size * 4
assert hhs % args.top_k == 0
hhs = hhs // args.top_k
assert hhs % args.tensor_model_parallel_size == 0
hhs = hhs // args.tensor_model_parallel_size
return fmoefy(
model_provider(pre_process=pre_process, post_process=post_process),
num_experts=args.num_experts,
hidden_hidden_size=hhs,
top_k=args.top_k,
gate=gate,
megatron_version="v2.5"
)
- return fmoefied_model_provider
if Megatron_Version == 'v2.2':
return fmoefied_model_provider_v2_2
elif Megatron_Version == 'v2.5':
return fmoefied_model_provider_v2_5
else:
assert False, f"Megatron Version {Megatron_Version} unknown."
......@@ -29,6 +29,9 @@ if os.environ.get('USE_NCCL', '1') == '1':
else:
ext_libs.append('nccl')
if os.environ.get('MOE_DEBUG', '0') == '1':
cxx_flags.append('-DMOE_DEBUG')
if is_rocm_pytorch:
define_macros=[('FMOE_USE_HIP', None)]
else:
......