Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FastMoE
Commits
698a12ae
Unverified
Commit
698a12ae
authored
Mar 21, 2023
by
Rick Ho
Committed by
GitHub
Mar 21, 2023
Browse files
Merge pull request #145 from laekov/patch_megatron_v2_5
FastMoE with Megatron-LM v2.5
parents
b8d0e81f
b31b7e3f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
572 additions
and
14 deletions
+572
-14
cuda/fastermoe/smart_schedule.h
cuda/fastermoe/smart_schedule.h
+2
-0
examples/megatron/v2.5.patch
examples/megatron/v2.5.patch
+415
-0
fmoe/distributed.py
fmoe/distributed.py
+5
-1
fmoe/gates/naive_gate.py
fmoe/gates/naive_gate.py
+3
-0
fmoe/megatron/distributed.py
fmoe/megatron/distributed.py
+15
-2
fmoe/megatron/layers.py
fmoe/megatron/layers.py
+22
-4
fmoe/megatron/patch.py
fmoe/megatron/patch.py
+107
-7
setup.py
setup.py
+3
-0
No files found.
cuda/fastermoe/smart_schedule.h
View file @
698a12ae
...
...
@@ -235,6 +235,7 @@ void fmoe_cuda_fused_forward_impl(
NCCL_SAFE_CALL
(
ncclGroupEnd
());
}
}
smgr
->
sync
(
1
);
delete
[]
local_ptr
;
delete
[]
global_ptr
;
...
...
@@ -381,6 +382,7 @@ void fmoe_cuda_fused_backward_impl(
}
}
smgr
->
sync
(
1
);
checkCudaErrors
(
cudaGetLastError
());
delete
[]
local_ptr
;
...
...
examples/megatron/v2.5.patch
0 → 100644
View file @
698a12ae
diff --git a/megatron/arguments.py b/megatron/arguments.py
index b35af1d..2a55699 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@
import os
import torch
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -42,6 +45,9 @@
def parse_args(extra_args_provider=None, defaults={},
parser = _add_vit_args(parser)
parser = _add_logging_args(parser)
+ # FastMoE arguments.
+ parser = _add_fmoe_args(parser)
+
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 1251066..32afb2f 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@
dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
- 6: np.float,
+ 6: np.float32,
7: np.double,
8: np.uint16
}
@@ -268,7 +268,7 @@
class IndexedDatasetBuilder(object):
np.int16: 2,
np.int32: 4,
np.int64: 8,
- np.float: 4,
+ np.float32: 4,
np.double: 8
}
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index 823a51f..32f4b2e 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -69,8 +69,10 @@
def get_megatron_optimizer(model):
# Determine whether the params have main-grad field.
params_have_main_grad = False
- if args.DDP_impl == 'local':
- params_have_main_grad = True
+
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # params_have_main_grad = True
if args.fp16 or args.bf16:
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 036a1d4..81d5bd9 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,17 +54,23 @@
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
+ # FastMoE
+ grads_in_moe = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
- grad = param.grad.detach()
if grad_not_none:
+ grad = param.grad.detach()
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grads_for_norm.append(grad)
+ # FastMoE
+ if hasattr(param, 'dp_comm') and param.dp_comm in ('none'):
+ grads_in_moe.append(grad)
+ else:
+ grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
@@ -73,6 +79,8 @@
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Calculate norm.
if norm_type == inf:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
@@ -97,7 +105,20 @@
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
+ # FastMoE
+ if len(grads_in_moe) > 0 : # 'cold' experts may not have any grads in one iteration
+ grad_norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads_in_moe],
+ False # no per-parameter norm
+ )
+ grad_norm = grad_norm ** norm_type
+ torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+ total_norm += grad_norm
else:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 368f587..080b06f 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -250,6 +250,9 @@
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ # FastMoE
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
@@ -396,17 +399,26 @@
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
- # If we found inf/nan, skip the update.
- if found_inf_flag:
- return False, None, None
+ # move to L417.
+ # if found_inf_flag:
+ # return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
- if self.clip_grad > 0.0:
- grad_norm = self.clip_grad_norm(self.clip_grad)
+
+ # remove if branch to avoid dead-lock in FastMoE
+ # if self.clip_grad > 0.0:
+ # grad_norm = self.clip_grad_norm(self.clip_grad)
+ grad_norm = self.clip_grad_norm(self.clip_grad)
+
timers('optimizer-clip-main-grad').stop()
+ # move early return to here to avoid dead-lock in FastMoE
+ # If we found inf/nan, skip the update.
+ if found_inf_flag:
+ return False, None, None
+
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
diff --git a/megatron/schedules.py b/megatron/schedules.py
index d346c30..8eef46c 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -23,7 +23,11 @@
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
from megatron.model import Float16Module
def get_forward_backward_func():
@@ -54,7 +58,8 @@
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
- output_tensor, loss_func = forward_step_func(data_iterator, model)
+ output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+ bal_loss = bal_loss / get_num_microbatches()
if mpu.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
@@ -62,10 +67,10 @@
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
- return output_tensor
+ return output_tensor, bal_loss
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -85,7 +90,9 @@
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass.
if output_tensor_grad is None:
output_tensor = optimizer.scale_loss(output_tensor)
- torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+ torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
+ else:
+ torch.autograd.backward([output_tensor,bal_loss], grad_tensors=[output_tensor_grad, None])
# Collect the grad of the input_tensor.
input_tensor_grad = None
@@ -122,18 +129,18 @@
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if not forward_only:
- backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+ backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
return losses_reduced
@@ -144,6 +151,9 @@
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
+ # FastMoE TODO
+ assert False, "FastMoE not supports pipeline with interleaving"
+
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
@@ -385,17 +395,19 @@
def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensors = []
output_tensors = []
+ bal_losses = []
losses_reduced = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers=timers)
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers=timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
@@ -407,7 +419,7 @@
def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers=timers)
@@ -420,16 +432,17 @@
def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# start of the list for backward pass.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers=timers)
else:
- input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+ input_tensor, output_tensor, bal_loss = input_tensors.pop(0), output_tensors.pop(0), bal_losses.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
if last_iteration:
input_tensor = None
@@ -444,12 +457,13 @@
def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_losses.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
diff --git a/megatron/training.py b/megatron/training.py
index 1ab57e9..fbe2fe8 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,14 +35,23 @@
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
@@ -107,6 +116,13 @@
def pretrain(train_valid_test_dataset_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v2.5")
+ model_provider = patch_model_provider(model_provider, Megatron_Version='v2.5')
+
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -386,10 +402,12 @@
def train_step(forward_step_func, data_iterator,
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
- if args.DDP_impl == 'local':
- grad = word_embeddings_weight.main_grad
- else:
- grad = word_embeddings_weight.grad
+ grad = word_embeddings_weight.grad
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # grad = word_embeddings_weight.main_grad
+ # else:
+ # grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
timers('backward-embedding-all-reduce').stop()
@@ -458,26 +476,13 @@
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Logging.
timers_to_log = []
- def add_to_logging(name):
- if name in timers.timers:
+ # FastMoE add several timers.
+ # For simplicity, add all timers to log.
+ def add_all():
+ for name in timers.timers:
timers_to_log.append(name)
- add_to_logging('forward-compute')
- add_to_logging('forward-recv')
- add_to_logging('forward-send')
- add_to_logging('forward-backward-send-forward-backward-recv')
- add_to_logging('backward-compute')
- add_to_logging('backward-recv')
- add_to_logging('backward-send')
- add_to_logging('backward-send-forward-recv')
- add_to_logging('backward-send-backward-recv')
- add_to_logging('backward-params-all-reduce')
- add_to_logging('backward-embedding-all-reduce')
- add_to_logging('optimizer-copy-to-main-grad')
- add_to_logging('optimizer-unscale-and-check-inf')
- add_to_logging('optimizer-clip-main-grad')
- add_to_logging('optimizer-copy-main-to-model-params')
- add_to_logging('optimizer')
- add_to_logging('batch-generator')
+
+ add_all()
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
fmoe/distributed.py
View file @
698a12ae
...
...
@@ -42,7 +42,7 @@ class DistributedGroupedDataParallel(nn.Module):
if
k
not
in
self
.
comms
:
self
.
comms
[
k
]
=
get_torch_default_comm
()
def
allreduce_
param
s
(
no_scale
=
False
,
def
allreduce_
gradient
s
(
no_scale
=
False
,
reduce_after
=
False
,
fp32_allreduce
=
False
):
groups
=
dict
()
for
p
in
self
.
module
.
parameters
():
...
...
@@ -74,6 +74,10 @@ class DistributedGroupedDataParallel(nn.Module):
for
g
,
s
in
zip
(
grads
,
synced
):
g
.
copy_
(
s
)
def
allreduce_params
(
*
args
,
**
kwargs
):
return
allreduce_gradients
(
*
args
,
**
kwargs
)
self
.
allreduce_gradients
=
allreduce_gradients
self
.
allreduce_params
=
allreduce_params
if
need_sync
:
self
.
_sync_params
()
...
...
fmoe/gates/naive_gate.py
View file @
698a12ae
...
...
@@ -37,6 +37,9 @@ class NaiveGate(BaseGate):
# (BxL) x 1 x top_k
gate_score
=
F
.
softmax
(
gate_top_k_val
,
dim
=-
1
)
# dummy loss
self
.
set_loss
(
torch
.
zeros
(
1
,
requires_grad
=
True
).
cuda
())
if
return_all_scores
:
return
gate_top_k_idx
,
gate_score
,
gate
return
gate_top_k_idx
,
gate_score
fmoe/megatron/distributed.py
View file @
698a12ae
...
...
@@ -14,6 +14,10 @@ def _set_groups(**kwargs):
_groups
=
kwargs
def
get_moe_group
():
return
_groups
[
"moe_group"
]
def
_init
():
from
megatron
import
get_args
from
megatron
import
mpu
...
...
@@ -39,12 +43,21 @@ class DistributedDataParallel(DistributedGroupedDataParallel):
is adapted to enable the sophiscated parallel and reduction strategies in
Fast MoE.
"""
def
__init__
(
self
,
module
):
def
__init__
(
self
,
module
,
accumulate_allreduce_grads_in_fp32
=
False
,
use_contiguous_buffers_in_ddp
=
False
):
assert
not
accumulate_allreduce_grads_in_fp32
,
"FastMoE not supports accumulate_allrecude_grads_in_fp32"
assert
not
use_contiguous_buffers_in_ddp
,
"FastMoE not supports use_contiguous_buffers_in_ddp"
if
_groups
is
None
:
_init
()
super
().
__init__
(
module
,
**
_groups
)
def
set_input_tensor
(
self
,
*
args
,
**
kwargs
):
r
"""
Keep consitency with Megatron
"""
return
self
.
module
.
set_input_tensor
(
*
args
,
**
kwargs
)
def
state_dict
(
self
,
*
args
,
**
kwargs
):
r
"""
Keep consitency with Megatron
...
...
fmoe/megatron/layers.py
View file @
698a12ae
...
...
@@ -151,6 +151,7 @@ def fmoefy(
hidden_hidden_size
=
None
,
top_k
=
None
,
gate
=
None
,
megatron_version
=
None
):
r
"""
Replace MLP layers in a transformer-based model in Megatron by MoE.
...
...
@@ -184,11 +185,28 @@ def fmoefy(
args
.
hidden_hidden_size
=
hidden_hidden_size
for
idx
,
l
in
enumerate
(
model
.
language_model
.
transformer
.
layers
):
l
.
mlp
=
MegatronMLP
(
args
,
idx
,
gate
=
gate
)
if
megatron_version
==
"v2.2"
:
for
idx
,
l
in
enumerate
(
model
.
language_model
.
transformer
.
layers
):
l
.
mlp
=
MegatronMLP
(
args
,
idx
,
gate
=
gate
)
# initialize gate hook
num_layers
=
len
(
model
.
language_model
.
transformer
.
layers
)
elif
megatron_version
==
"v2.5"
:
for
idx
,
l
in
enumerate
(
model
.
language_model
.
encoder
.
layers
):
l
.
mlp
=
MegatronMLP
(
args
,
idx
,
gate
=
gate
)
if
hasattr
(
model
.
language_model
,
"decoder"
):
for
idx
,
l
in
enumerate
(
model
.
language_model
.
decoder
.
layers
):
l
.
mlp
=
MegatronMLP
(
args
,
idx
,
gate
=
gate
)
# initialize gate hook
num_layers
=
len
(
model
.
language_model
.
encoder
.
layers
)
if
hasattr
(
model
.
language_model
,
"decoder"
):
num_layers
+=
len
(
model
.
language_model
.
decoder
.
layers
)
else
:
assert
False
,
f
"megatron_version
{
megatron_version
}
not known."
# initialize gate hook
num_layers
=
len
(
model
.
language_model
.
transformer
.
layers
)
reset_gate_hook
(
num_layers
)
return
model
fmoe/megatron/patch.py
View file @
698a12ae
...
...
@@ -3,8 +3,65 @@ Patching some of Megatron-LM's functions to create an MoE model
"""
import
torch
def
patch_loss_func_v2_5
(
loss_func
):
r
"""
Patch model's loss_func to support balance loss
"""
from
megatron.mpu
import
is_pipeline_last_stage
from
megatron.mpu
import
get_tensor_model_parallel_group
from
megatron
import
get_args
from
megatron
import
get_num_microbatches
if
not
get_args
().
balance_strategy
:
return
loss_func
def
loss_func_with_balance_loss
(
model
,
output_tensor
):
args
=
get_args
()
assert
args
.
balance_strategy
,
"Only use patched loss_func when having balance_strategy."
assert
is_pipeline_last_stage
(),
"Only call loss_func at pipeline last stage."
output
=
loss_func
(
output_tensor
)
while
hasattr
(
model
,
'module'
):
model
=
model
.
module
loss_list
=
[
l
.
mlp
.
gate
.
get_loss
(
clear
=
False
).
view
(
1
)
for
l
in
model
.
language_model
.
encoder
.
layers
if
l
.
mlp
.
gate
.
has_loss
]
def
patch_forward_step
(
forward_step_func
):
if
hasattr
(
model
.
language_model
,
"decoder"
):
loss_list_decoder
=
[
l
.
mlp
.
gate
.
get_loss
(
clear
=
False
).
view
(
1
)
for
l
in
model
.
language_model
.
decoder
.
layers
if
l
.
mlp
.
gate
.
has_loss
]
loss_list
.
append
(
loss_list_decoder
)
if
len
(
loss_list
)
==
0
:
return
output
loss_name
=
args
.
balance_strategy
+
"_loss"
(
loss
,
state_dict
),
bal_loss
=
(
output
,
torch
.
cat
(
loss_list
).
mean
()
*
args
.
balance_loss_weight
/
args
.
pipeline_model_parallel_size
)
bal_loss
=
bal_loss
/
get_num_microbatches
()
# avarage across moe group
moe_group
=
get_tensor_model_parallel_group
()
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
moe_group
)
averaged_bal_loss
=
bal_loss
.
clone
().
detach
()
torch
.
distributed
.
all_reduce
(
averaged_bal_loss
,
group
=
moe_group
)
averaged_bal_loss
/=
world_size
loss
+=
bal_loss
state_dict
[
loss_name
]
=
averaged_bal_loss
return
loss
,
state_dict
return
loss_func_with_balance_loss
def
patch_forward_step
(
forward_step_func
,
Megatron_Version
=
"v2.2"
):
r
"""
Patch model's forward_step_func to support balance loss
"""
...
...
@@ -16,7 +73,7 @@ def patch_forward_step(forward_step_func):
if
not
get_args
().
balance_strategy
:
return
forward_step_func
def
forward_step_with_balance_loss
(
data_iterator
,
model
,
input_tensor
):
def
forward_step_with_balance_loss
_v2_2
(
data_iterator
,
model
,
input_tensor
):
args
=
get_args
()
output
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
...
...
@@ -50,13 +107,33 @@ def patch_forward_step(forward_step_func):
return
loss
,
state_dict
return
forward_step_with_balance_loss
def
forward_step_with_balance_loss_v2_5
(
data_iterator
,
model
):
from
functools
import
partial
output
,
loss_func
=
forward_step_func
(
data_iterator
,
model
)
while
hasattr
(
model
,
'module'
):
model
=
model
.
module
loss_list
=
[
l
.
mlp
.
gate
.
get_loss
(
clear
=
False
).
view
(
1
)
for
l
in
model
.
language_model
.
encoder
.
layers
if
l
.
mlp
.
gate
.
has_loss
]
bal_loss
=
torch
.
cat
(
loss_list
).
mean
()
*
get_args
().
balance_loss_weight
/
get_args
().
pipeline_model_parallel_size
return
output
,
partial
(
patch_loss_func_v2_5
(
loss_func
),
model
),
bal_loss
if
Megatron_Version
==
"v2.2"
:
return
forward_step_with_balance_loss_v2_2
elif
Megatron_Version
==
"v2.5"
:
return
forward_step_with_balance_loss_v2_5
else
:
assert
False
,
f
"megatron version
{
Megatron_Version
}
not known."
def
patch_model_provider
(
model_provider
,
gate
=
None
):
def
patch_model_provider
(
model_provider
,
gate
=
None
,
Megatron_Version
=
'v2.2'
):
from
megatron
import
get_args
def
fmoefied_model_provider
():
def
fmoefied_model_provider
_v2_2
():
from
.layers
import
fmoefy
args
=
get_args
()
hhs
=
args
.
hidden_size
*
4
...
...
@@ -69,7 +146,30 @@ def patch_model_provider(model_provider, gate=None):
num_experts
=
args
.
num_experts
,
hidden_hidden_size
=
hhs
,
top_k
=
args
.
top_k
,
gate
=
gate
gate
=
gate
,
megatron_version
=
"v2.2"
)
def
fmoefied_model_provider_v2_5
(
pre_process
,
post_process
):
from
.layers
import
fmoefy
args
=
get_args
()
hhs
=
args
.
hidden_size
*
4
assert
hhs
%
args
.
top_k
==
0
hhs
=
hhs
//
args
.
top_k
assert
hhs
%
args
.
tensor_model_parallel_size
==
0
hhs
=
hhs
//
args
.
tensor_model_parallel_size
return
fmoefy
(
model_provider
(
pre_process
=
pre_process
,
post_process
=
post_process
),
num_experts
=
args
.
num_experts
,
hidden_hidden_size
=
hhs
,
top_k
=
args
.
top_k
,
gate
=
gate
,
megatron_version
=
"v2.5"
)
return
fmoefied_model_provider
if
Megatron_Version
==
'v2.2'
:
return
fmoefied_model_provider_v2_2
elif
Megatron_Version
==
'v2.5'
:
return
fmoefied_model_provider_v2_5
else
:
assert
False
,
f
"Megatron Version
{
Megatron_Version
}
unknown."
setup.py
View file @
698a12ae
...
...
@@ -29,6 +29,9 @@ if os.environ.get('USE_NCCL', '1') == '1':
else
:
ext_libs
.
append
(
'nccl'
)
if
os
.
environ
.
get
(
'MOE_DEBUG'
,
'0'
)
==
'1'
:
cxx_flags
.
append
(
'-DMOE_DEBUG'
)
if
is_rocm_pytorch
:
define_macros
=
[(
'FMOE_USE_HIP'
,
None
)]
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment