OpenDAS / FastMoE / Commits

Commit a6d202a6 (unverified)
Authored Jul 04, 2023 by Jiezhong Qiu; committed by GitHub on Jul 04, 2023
Megatron v3.0.2 Patch (#159)

* support megatron v3.0.2
* keep num_experts for lower version of megatron
parent d56522bc

Showing 8 changed files with 570 additions and 27 deletions (+570 -27):

- README.md (+1 -1)
- doc/readme-cn.md (+1 -1)
- examples/megatron/README.md (+1 -1)
- examples/megatron/v3.0.2.patch (+498 -0)
- fmoe/megatron/checkpoint.py (+27 -10)
- fmoe/megatron/layers.py (+11 -10)
- fmoe/megatron/patch.py (+24 -3)
- fmoe/megatron/utils.py (+7 -1)
README.md
...
@@ -59,7 +59,7 @@ multiple experts.
 model = ...
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=<number of experts per worker>)
+model = fmoefy(model, fmoe_num_experts=<number of experts per worker>)
 train(model, ...)
 ```
...
doc/readme-cn.md
...
@@ -50,7 +50,7 @@ Transformer 模型变为一个 MoE 的模型. 其使用方法如下.
 model = ...
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=<number of experts per worker>)
+model = fmoefy(model, fmoe_num_experts=<number of experts per worker>)
 train(model, ...)
 ```
...
examples/megatron/README.md
...
@@ -25,7 +25,7 @@ transformer language models.
 ```python
 from fmoe.megatron import fmoefy
-model = fmoefy(model, num_experts=4)
+model = fmoefy(model, fmoe_num_experts=4)
 ```
 Note that the `fmoefy` function currently only takes a standard Megatron-LM's
...
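The three README hunks above rename `fmoefy`'s keyword argument from `num_experts` to `fmoe_num_experts`. For caller code that must work both before and after the rename, a small shim is possible; this is a sketch, not a FastMoE API, and `fmoefy_compat` is a hypothetical helper:

```python
# Sketch only: `fmoefy_compat` is a hypothetical helper, not part of FastMoE.
import inspect
from fmoe.megatron import fmoefy

def fmoefy_compat(model, n_experts_per_worker):
    # This commit renames the keyword from `num_experts` to `fmoe_num_experts`;
    # pick whichever one the installed fmoefy accepts.
    params = inspect.signature(fmoefy).parameters
    key = "fmoe_num_experts" if "fmoe_num_experts" in params else "num_experts"
    return fmoefy(model, **{key: n_experts_per_worker})
```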
examples/megatron/v3.0.2.patch
new file (mode 100644)
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 102e890..c3504bd 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -20,6 +20,9 @@
import os
import torch
+# FastMoE
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -43,6 +46,9 @@
def parse_args(extra_args_provider=None, defaults={},
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
+ # FastMoE arguments.
+ parser = _add_fmoe_args(parser)
+
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
@@ -316,6 +322,12 @@
def parse_args(extra_args_provider=None, defaults={},
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
+ # if fmoe_num_experts is not specified,
+ # we are using lower version of megatron,
+ # copy num_experts to fmoe_num_experts
+ if not hasattr(args, 'fmoe_num_experts'):
+ args.fmoe_num_experts = args.num_experts
+
_print_args(args)
return args
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index ceba352..01754d0 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -124,6 +124,10 @@
def read_metadata(tracker_filename):
sys.exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
+
+ args = get_args()
+ if args.fmoefy:
+ return iteration, release
# Get the max iteration retrieved across the ranks.
iters_cuda = torch.cuda.LongTensor([iteration])
@@ -134,6 +138,7 @@
def read_metadata(tracker_filename):
# If not, print a warning and chose the maximum
# iteration across all ranks.
if iteration != max_iter:
+ rank = torch.distributed.get_rank()
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
@@ -399,7 +404,8 @@
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
- except KeyError:
+ except KeyError as e:
+ print(e)
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index 2f6e1b8..e2483db 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -95,7 +95,7 @@
dtypes = {
3: np.int16,
4: np.int32,
5: np.int64,
- 6: np.float,
+ 6: np.float32,
7: np.double,
8: np.uint16
}
@@ -268,7 +268,7 @@
class IndexedDatasetBuilder(object):
np.int16: 2,
np.int32: 4,
np.int64: 8,
- np.float: 4,
+ np.float32: 4,
np.double: 8
}
diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py
index d8bee27..6f4ecfb 100644
--- a/megatron/optimizer/__init__.py
+++ b/megatron/optimizer/__init__.py
@@ -101,8 +101,9 @@
def get_megatron_optimizer(model,
# Determine whether the params have main-grad field.
params_have_main_grad = False
- if args.DDP_impl == 'local':
- params_have_main_grad = True
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # params_have_main_grad = True
if args.fp16 or args.bf16:
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index 36cd915..7b4eaaa 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -54,6 +54,8 @@
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
+ # FastMoE
+ grads_in_moe = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
@@ -65,7 +67,11 @@
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
- grads_for_norm.append(grad)
+ # FastMoE
+ if hasattr(param, 'dp_comm') and param.dp_comm in ('none'):
+ grads_in_moe.append(grad)
+ else:
+ grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
@@ -74,6 +80,8 @@
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Calculate norm.
if norm_type == inf:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
@@ -98,7 +106,20 @@
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
+ # FastMoE
+ if len(grads_in_moe) > 0 : # 'cold' experts may not have any grads in one iteration
+ grad_norm, _ = multi_tensor_applier(
+ amp_C.multi_tensor_l2norm,
+ dummy_overflow_buf,
+ [grads_in_moe],
+ False # no per-parameter norm
+ )
+ grad_norm = grad_norm ** norm_type
+ torch.distributed.all_reduce(grad_norm, op=torch.distributed.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+ total_norm += grad_norm
else:
+ # FastMoE TODO
+ assert False, f"norm_type {norm_type} is not supported by FastMoE "
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index d6ac42e..7eecff4 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -257,6 +257,9 @@
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ # FastMoE
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_float16_params_this_group.append(main_param)
@@ -411,18 +414,27 @@
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# We are done with scaling gradients
# so we can update the loss scale.
self.grad_scaler.update(found_inf_flag)
-
+
+ # move to L433-L436
# If we found inf/nan, skip the update.
- if found_inf_flag:
- return False, None, None
+ # if found_inf_flag:
+ # return False, None, None
# Clip the main gradients.
timers('optimizer-clip-main-grad').start()
grad_norm = None
- if self.clip_grad > 0.0:
- grad_norm = self.clip_grad_norm(self.clip_grad)
+
+ # remove if branch to avoid dead-lock in FastMoE
+ # if self.clip_grad > 0.0:
+ # grad_norm = self.clip_grad_norm(self.clip_grad)
+ grad_norm = self.clip_grad_norm(self.clip_grad)
timers('optimizer-clip-main-grad').stop()
+ # move early return to here to avoid dead-lock in FastMoE
+ # If we found inf/nan, skip the update.
+ if found_inf_flag:
+ return False, None, None
+
# count the zeros in the grads
num_zeros_in_grad = self.count_zeros() if \
self.log_num_zeros_in_grad else None
diff --git a/megatron/schedules.py b/megatron/schedules.py
index ac5ba6f..26b717a 100644
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -24,7 +24,10 @@
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.utils import unwrap_model
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model import ModelType
@@ -66,7 +69,7 @@
def deallocate_output_tensor(out):
dtype = out.dtype,
)
-def custom_backward(output, grad_output):
+def custom_backward(output, grad_output, bal_loss):
'''Directly call C++ autograd engine.
To make the 'deallocate_output_tensor' (above) optimization work, the C++
@@ -89,11 +92,16 @@
def custom_backward(output, grad_output):
output,
memory_format = torch.preserve_format,
)
+ tensors = (output,)
+ grad_tensors = (grad_output,)
+ else:
+ tensors = (output, bal_loss)
+ grad_tensors = (grad_output, None)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable._execution_engine.run_backward(
- tensors = (output,),
- grad_tensors = (grad_output,),
+ tensors = tensors,
+ grad_tensors = grad_tensors,
keep_graph = False,
create_graph = False,
inputs = tuple(),
@@ -127,7 +135,8 @@
def forward_step(forward_step_func,
unwrap_output_tensor = True
unwrapped_model.set_input_tensor(input_tensor)
- output_tensor, loss_func = forward_step_func(data_iterator, model)
+ output_tensor, loss_func, bal_loss = forward_step_func(data_iterator, model)
+ bal_loss = bal_loss / get_num_microbatches()
if mpu.is_pipeline_last_stage():
if not collect_non_loss_data:
output_tensor = loss_func(output_tensor)
@@ -145,13 +154,14 @@
def forward_step(forward_step_func,
# downstream as well.
if mpu.is_pipeline_stage_after_split() and \
args.model_type == ModelType.encoder_and_decoder:
+ assert False, f"encoder-decoder model is not supported by FastMoE "
return [output_tensor, input_tensor[-1]]
if unwrap_output_tensor:
- return output_tensor
- return [output_tensor]
+ return output_tensor, bal_loss
+ return [output_tensor, bal_loss]
-def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
+def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
@@ -185,7 +195,7 @@
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
# Backward pass.
if output_tensor_grad[0] is None:
output_tensor = optimizer.scale_loss(output_tensor[0])
- custom_backward(output_tensor[0], output_tensor_grad[0])
+ custom_backward(output_tensor[0], output_tensor_grad[0], bal_loss)
# Collect the grad of the input_tensor.
input_tensor_grad = [None]
@@ -241,20 +251,20 @@
def forward_backward_no_pipelining(forward_step_func,
input_tensor, output_tensor_grad = None, None
with context_handler():
for i in range(get_num_microbatches() - 1):
- output_tensor = forward_step(forward_step_func, data_iterator,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
- output_tensor = forward_step(forward_step_func, data_iterator,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator,
model, input_tensor, forward_data_store,
collect_non_loss_data)
if not forward_only:
- backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+ backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad, bal_loss)
return forward_data_store
@@ -269,6 +279,8 @@
def forward_backward_pipelining_with_interleaving(forward_step_func,
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
+ # FastMoE TODO
+ assert False, "FastMoE not supports pipeline with interleaving"
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
forward_data_store = []
@@ -646,15 +658,17 @@
def forward_backward_pipelining_without_interleaving(forward_step_func,
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
+ bal_losses = None
if not forward_only:
input_tensors = []
output_tensors = []
+ bal_losses = []
forward_data_store = []
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes, timers=timers)
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
send_forward(output_tensor, send_tensor_shapes, timers=timers)
@@ -662,6 +676,7 @@
def forward_backward_pipelining_without_interleaving(forward_step_func,
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
deallocate_output_tensor(output_tensor[0])
# Before running 1F1B, need to receive first forward tensor.
@@ -674,7 +689,7 @@
def forward_backward_pipelining_without_interleaving(forward_step_func,
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
- output_tensor = forward_step(forward_step_func, data_iterator, model,
+ output_tensor, bal_loss = forward_step(forward_step_func, data_iterator, model,
input_tensor, forward_data_store,
collect_non_loss_data)
if forward_only:
@@ -692,16 +707,18 @@
def forward_backward_pipelining_without_interleaving(forward_step_func,
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
+ bal_losses.append(bal_loss)
deallocate_output_tensor(output_tensor[0])
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_loss.pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
if last_iteration:
input_tensor = None
@@ -716,12 +733,13 @@
def forward_backward_pipelining_without_interleaving(forward_step_func,
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
+ bal_loss = bal_losses.pop(0)
output_tensor_grad = recv_backward(send_tensor_shapes, timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
- output_tensor_grad)
+ output_tensor_grad, bal_loss)
send_backward(input_tensor_grad, recv_tensor_shapes, timers=timers)
diff --git a/megatron/training.py b/megatron/training.py
index 023bdf1..caefb88 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -36,8 +36,13 @@
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+
+# FastMoE
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
+
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.optimizer import get_megatron_optimizer
@@ -45,7 +50,11 @@
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.initialize import set_jit_fusion_options
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
-from megatron.model import DistributedDataParallel as LocalDDP
+
+# FastMoE
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
@@ -119,6 +128,13 @@
def pretrain(train_valid_test_dataset_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v3.0.2")
+ model_provider = patch_model_provider(model_provider, Megatron_Version='v3.0.2')
+
# Model, optimizer, and learning rate.
timers('model-and-optimizer-setup').start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
@@ -466,10 +482,12 @@
def train_step(forward_step_func, data_iterator,
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
- if args.DDP_impl == 'local':
- grad = word_embeddings_weight.main_grad
- else:
- grad = word_embeddings_weight.grad
+ grad = word_embeddings_weight.grad
+ # FastMoE does not have main_grad field
+ # if args.DDP_impl == 'local':
+ # grad = word_embeddings_weight.main_grad
+ # else:
+ # grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
@@ -568,26 +586,13 @@
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
# Logging.
timers_to_log = []
- def add_to_logging(name):
- if name in timers.timers:
- timers_to_log.append(name)
- add_to_logging('forward-compute')
- add_to_logging('forward-recv')
- add_to_logging('forward-send')
- add_to_logging('forward-backward-send-forward-backward-recv')
- add_to_logging('backward-compute')
- add_to_logging('backward-recv')
- add_to_logging('backward-send')
- add_to_logging('backward-send-forward-recv')
- add_to_logging('backward-send-backward-recv')
- add_to_logging('backward-params-all-reduce')
- add_to_logging('backward-embedding-all-reduce')
- add_to_logging('optimizer-copy-to-main-grad')
- add_to_logging('optimizer-unscale-and-check-inf')
- add_to_logging('optimizer-clip-main-grad')
- add_to_logging('optimizer-copy-main-to-model-params')
- add_to_logging('optimizer')
- add_to_logging('batch-generator')
+ # FastMoE add several timers.
+ # For simplicity, add all timers to log.
+ def add_all():
+ for name in timers.timers:
+ timers_to_log.append(name)
+
+ add_all()
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * \
--
2.25.1
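The `megatron/optimizer/clip_grads.py` hunk above splits expert gradients (parameters tagged `dp_comm == 'none'`) out of the regular norm and all-reduces their contribution across the model-parallel group before folding it back in. Below is a minimal PyTorch sketch of that bookkeeping, assuming `norm_type == 2`, fp32 CUDA gradients, and omitting the shared-parameter and tensor-parallel-duplicate checks that the real `clip_grad_norm_fp32` performs:

```python
# Minimal sketch of the MoE-aware norm accumulation, under the assumptions stated above.
import torch
import torch.distributed as dist

def moe_aware_grad_norm_sum(parameters, model_parallel_group, norm_type=2.0):
    grads_for_norm, grads_in_moe = [], []
    for p in parameters:
        if p.grad is None:
            continue
        # FastMoE tags expert parameters with dp_comm == 'none'.
        if getattr(p, "dp_comm", None) == "none":
            grads_in_moe.append(p.grad.detach())
        else:
            grads_for_norm.append(p.grad.detach())

    device = torch.device("cuda")
    total = torch.zeros(1, device=device)
    for g in grads_for_norm:
        total += torch.norm(g, norm_type) ** norm_type

    if grads_in_moe:  # 'cold' experts may produce no gradients in an iteration
        moe = torch.zeros(1, device=device)
        for g in grads_in_moe:
            moe += torch.norm(g, norm_type) ** norm_type
        # Expert shards live on different ranks, so their contribution is summed
        # across the model-parallel group before being added to the total.
        dist.all_reduce(moe, op=dist.ReduceOp.SUM, group=model_parallel_group)
        total += moe

    # The surrounding Megatron code then reduces this sum across model-parallel
    # ranks and takes the (1 / norm_type)-th root to obtain the clipping norm.
    return total
```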
fmoe/megatron/checkpoint.py
...
@@ -54,6 +54,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
+    from megatron import utils

     expert_dp_comm = "none"
...
@@ -67,8 +68,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    if hasattr(model, 'module'):
-        model = model.module
+    try:
+        model = utils.unwrap_model(model)
+    except AttributeError:
+        # fallback to the old way of unwrapping a model
+        if hasattr(model, 'module'):
+            model = model.module
+        model = [model, ]

     print_rank_last("saving checkpoint at iteration {:7d} to {}".format(
         iteration, args.save)
...
@@ -76,7 +82,8 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     # Arguments, iteration, and model.
     state_dict = {}
-    state_dict["model"] = model.state_dict_for_save_checkpoint(
+    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
+    state_dict["model"] = model[0].state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0)
     )
...
@@ -215,6 +222,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
+    from megatron import utils
     from megatron.checkpointing import get_checkpoint_tracker_filename
     from megatron.checkpointing import set_checkpoint_version
     from megatron.checkpointing import check_checkpoint_args
...
@@ -229,8 +237,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
     args = get_args()
     load_dir = getattr(args, load_arg)

-    if hasattr(model, 'module'):
-        model = model.module
+    # Only rank zero of the data parallel writes to the disk.
+    try:
+        model = utils.unwrap_model(model)
+    except AttributeError:
+        # fallback to the old way of unwrapping a model
+        if hasattr(model, 'module'):
+            model = model.module
+        model = [model, ]

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
...
@@ -341,7 +356,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
         print_rank_last("could not find arguments in the checkpoint ...")

     # Model.
-    model.load_state_dict(state_dict["model"])
+    assert len(model) == 1, "FMoE does not support interleaved pipelining, i.e., only supports len(model) == 1 for now."
+    model[0].load_state_dict(state_dict["model"])

     # Optimizer.
     if not release and not args.finetune and not args.no_load_optim:
...
@@ -350,9 +366,9 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
             optimizer.load_state_dict(state_dict["optimizer"])
             if lr_scheduler is not None:
                 lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
-        except KeyError:
-            print_rank_last("Unable to load optimizer from checkpoint {}. "
+        except KeyError as e:
+            print_rank_last("FMoE is unable to load optimizer from checkpoint {}. "
                 "Specify --no-load-optim or --finetune to prevent "
                 "attempting to load the optimizer state, "
                 "exiting ...".format(checkpoint_name_local)
...
@@ -367,9 +383,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg="load"):
             torch.set_rng_state(state_dict["torch_rng_state"])
             torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
             mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
-        except KeyError:
-            print_rank_last("Unable to load optimizer from checkpoint {}. "
+        except KeyError as e:
+            print_rank_last(e)
+            print_rank_last("FMoE is unable to load rng state from checkpoint {}. "
                 "Specify --no-load-rng or --finetune to prevent "
                 "attempting to load the optimizer state, "
                 "exiting ...".format(checkpoint_name_local)
...
fmoe/megatron/layers.py
...
@@ -102,7 +102,7 @@ class MegatronMLP(FMoETransformerMLP):
             assert False, "Undefined balance strategy {}" % (args.balance_strategy)

         super().__init__(
-            args.num_experts,
+            args.fmoe_num_experts,
             top_k=args.top_k,
             d_model=args.hidden_size,
             d_hidden=args.hidden_hidden_size,
...
@@ -110,7 +110,7 @@ class MegatronMLP(FMoETransformerMLP):
             moe_group=moe_group,
             expert_dp_comm="none" if args.distributed_experts else "dp",
             gate_hook=generate_megatron_gate_hook(
-                layer_idx, args.num_experts * world_size
+                layer_idx, args.fmoe_num_experts * world_size
             ),
             gate=gate,
         )
...
@@ -157,7 +157,7 @@ class MegatronMLP(FMoETransformerMLP):
 def fmoefy(
     model,
-    num_experts=None,
+    fmoe_num_experts=None,
     distributed_experts=True,
     hidden_hidden_size=None,
     top_k=None,
...
@@ -183,11 +183,11 @@ def fmoefy(
     if distributed_experts is not None:
         args.distributed_experts = distributed_experts
-    if num_experts is not None:
-        args.num_experts = num_experts
+    if fmoe_num_experts is not None:
+        args.fmoe_num_experts = fmoe_num_experts
     assert (
-        "num_experts" in args
-    ), "num_experts should be specified in arguments or fmoefy function"
+        "fmoe_num_experts" in args
+    ), "fmoe_num_experts should be specified in arguments or fmoefy function"
     if top_k is not None:
         args.top_k = top_k
...
@@ -203,19 +203,20 @@ def fmoefy(
         # initialize gate hook
         num_layers = len(model.language_model.transformer.layers)
-    elif megatron_version == "v2.5":
+    elif megatron_version in ["v2.5", "v3.0.2"]:
         for idx, l in enumerate(model.language_model.encoder.layers):
             l.mlp = MegatronMLP(args, idx, gate=gate)
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
             for idx, l in enumerate(model.language_model.decoder.layers):
                 l.mlp = MegatronMLP(args, idx, gate=gate)

         # initialize gate hook
         num_layers = len(model.language_model.encoder.layers)
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
             num_layers += len(model.language_model.decoder.layers)
     else:
+        print(model.language_model)
         assert False, f"megatron_version {megatron_version} not known."

     reset_gate_hook(num_layers)
...
fmoe/megatron/patch.py
...
@@ -30,7 +30,7 @@ def patch_loss_func_v2_5(loss_func):
             for l in model.language_model.encoder.layers
             if l.mlp.gate.has_loss
         ]
-        if hasattr(model.language_model, "decoder"):
+        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
             loss_list_decoder = [
                 l.mlp.gate.get_loss(clear=False).view(1)
                 for l in model.language_model.decoder.layers
                 if l.mlp.gate.has_loss
             ]
...
@@ -125,6 +125,8 @@ def patch_forward_step(forward_step_func, Megatron_Version="v2.2"):
         return forward_step_with_balance_loss_v2_2
     elif Megatron_Version == "v2.5":
         return forward_step_with_balance_loss_v2_5
+    elif Megatron_Version == "v3.0.2":
+        return forward_step_with_balance_loss_v2_5
     else:
         assert False, f"megatron version {Megatron_Version} not known."
...
@@ -143,7 +145,7 @@ def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
         hhs = hhs // args.tensor_model_parallel_size
         return fmoefy(
             model_provider(),
-            num_experts=args.num_experts,
+            fmoe_num_experts=args.fmoe_num_experts,
             hidden_hidden_size=hhs,
             top_k=args.top_k,
             gate=gate,
...
@@ -160,16 +162,35 @@ def patch_model_provider(model_provider, gate=None, Megatron_Version='v2.2'):
         hhs = hhs // args.tensor_model_parallel_size
         return fmoefy(
             model_provider(pre_process=pre_process, post_process=post_process),
-            num_experts=args.num_experts,
+            fmoe_num_experts=args.fmoe_num_experts,
             hidden_hidden_size=hhs,
             top_k=args.top_k,
             gate=gate,
             megatron_version="v2.5"
         )

+    def fmoefied_model_provider_v3_0_2(pre_process, post_process):
+        from .layers import fmoefy
+        args = get_args()
+        hhs = args.hidden_size * 4
+        assert hhs % args.top_k == 0
+        hhs = hhs // args.top_k
+        assert hhs % args.tensor_model_parallel_size == 0
+        hhs = hhs // args.tensor_model_parallel_size
+        return fmoefy(
+            model_provider(pre_process=pre_process, post_process=post_process),
+            fmoe_num_experts=args.fmoe_num_experts,
+            hidden_hidden_size=hhs,
+            top_k=args.top_k,
+            gate=gate,
+            megatron_version="v3.0.2"
+        )
+
     if Megatron_Version == 'v2.2':
         return fmoefied_model_provider_v2_2
     elif Megatron_Version == 'v2.5':
         return fmoefied_model_provider_v2_5
+    elif Megatron_Version == 'v3.0.2':
+        return fmoefied_model_provider_v3_0_2
     else:
         assert False, f"Megatron Version {Megatron_Version} unknown."
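
A minimal sketch of how these two helpers are applied for Megatron v3.0.2, mirroring the `megatron/training.py` hunk in `v3.0.2.patch`; `forward_step_func`, `model_provider`, and `args` are assumed to come from an existing Megatron pretrain script, and `apply_fastmoe_patches` is a hypothetical wrapper:

```python
# Sketch only: mirrors the pretrain() hunk in examples/megatron/v3.0.2.patch.
from fmoe.megatron import patch_forward_step, patch_model_provider

def apply_fastmoe_patches(forward_step_func, model_provider, args):
    if args.fmoefy:
        forward_step_func = patch_forward_step(forward_step_func, Megatron_Version="v3.0.2")
        model_provider = patch_model_provider(model_provider, Megatron_Version="v3.0.2")
    # Megatron's usual setup_model_and_optimizer / train loop would follow.
    return forward_step_func, model_provider
```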
fmoe/megatron/utils.py
 r"""
 Utility in Megatron
 """
+import argparse

 def add_fmoe_args(parser):
     group = parser.add_argument_group(title="fastmoe")
     group.add_argument("--fmoefy", action="store_true")
-    group.add_argument("--num-experts", type=int, default=None)
+    try:
+        group.add_argument("--num-experts", type=int, default=None)
+    except argparse.ArgumentError:
+        group.add_argument("--fmoe-num-experts", type=int, default=None)
     group.add_argument("--top-k", type=int, default=2)
     group.add_argument("--balance-loss-weight", type=float, default=1)
     group.add_argument("--balance-strategy", type=str, default=None)
...
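A standalone sketch of the argument flow, assuming FastMoE is installed. The `try/except argparse.ArgumentError` above suggests that when the parser already defines `--num-experts` (as Megatron v3.0.2 does), FastMoE registers `--fmoe-num-experts` instead; for older Megatron versions, the `megatron/arguments.py` hunk in `v3.0.2.patch` copies `args.num_experts` into `args.fmoe_num_experts`:

```python
# Sketch only: a bare parser has no pre-existing --num-experts, so add_fmoe_args
# registers it directly and --fmoe-num-experts is not defined in this example.
import argparse
from fmoe.megatron import add_fmoe_args

parser = argparse.ArgumentParser(description="fastmoe argument sketch")
parser = add_fmoe_args(parser)
args = parser.parse_args(["--fmoefy", "--num-experts", "4", "--top-k", "2"])
print(args.fmoefy, args.num_experts, args.top_k)  # True 4 2
```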