Commit 4eec9807 authored by Sengxian's avatar Sengxian
Browse files

Adapt balance loss for new gate interface & update patch

parent fa5f45f0
diff --git a/megatron/training.py b/megatron/training.py
index 96aec98..fe55dbd 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -33,7 +33,7 @@ from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
diff --git a/pretrain_bert.py b/pretrain_bert.py
index b937b36..5841256 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -37,6 +37,8 @@ def model_provider():
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
+ from fmoe.megatron import fmoefy
+ model = fmoefy(model, num_experts=4)
return model
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..9c624d2 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -43,7 +43,8 @@ from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
diff --git a/pretrain_bert.py b/pretrain_bert.py
index 48bc6ad..48628ce 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -52,6 +52,8 @@ def model_provider():
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
+ from fmoe.megatron import fmoefy
+ model = fmoefy(model, num_experts=4)
return model
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 26a7cec..0acfb22 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -21,6 +21,8 @@ import os
import torch
from megatron import fused_kernels
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -40,6 +42,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
+ parser = _add_fmoe_args(parser)
# Custom arguments.
if extra_args_provider is not None:
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 9d42260..2583db2 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -177,6 +177,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_fp16_params_this_group.append(main_param)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..f825bf3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,20 +35,24 @@ from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
from megatron.model import FP16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+from fmoe.megatron import add_balance_log
def print_datetime(string):
"""Note that this call will sync across all ranks."""
@@ -102,6 +106,13 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func)
+ model_provider = patch_model_provider(model_provider)
+
# Model, optimizer, and learning rate.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -643,7 +654,7 @@ def train_step(forward_step_func, data_iterator,
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
- loss_scale, report_memory_flag, skipped_iter):
+ loss_scale, report_memory_flag, skipped_iter, model):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
@@ -725,6 +736,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
args.consumed_train_samples)
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
+ if args.fmoefy and args.balance_strategy and args.balance_strategy != 'naive':
+ add_balance_log(model, writer, iteration)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed()
@@ -816,7 +829,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
- report_memory_flag, skipped_iter)
+ report_memory_flag, skipped_iter, model)
# Autoresume
if args.adlr_autoresume and \
......@@ -45,11 +45,14 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
return megatron_gate_hook
def add_balance_log(writer, iteration):
def add_balance_log(model, writer, iteration):
from megatron import is_last_rank
if hasattr(model, 'module'):
model = model.module
balance_dict_tensor = torch.vstack(
[torch.tensor(item, device=item[0].device) for item in balance_dict.values()]
[l.mlp.gate.get_loss(clear=True) for l in model.language_model.transformer.layers]
).detach()
world_group = get_torch_default_comm()
world_size = torch.distributed.get_world_size(group=world_group)
......@@ -68,8 +71,6 @@ def add_balance_log(writer, iteration):
iteration,
)
reset_gate_hook()
def patch_forward_step(forward_step_func):
r"""
......@@ -86,16 +87,19 @@ def patch_forward_step(forward_step_func):
args = get_args()
output = forward_step_func(data_iterator, model, input_tensor)
if not is_pipeline_last_stage():
if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
return output
loss_name = args.balance_strategy + "_loss"
if hasattr(model, 'module'):
model = model.module
loss_list = [l.mlp.gate.get_loss(clear=False) for l in model.language_model.transformer.layers]
(loss, state_dict), bal_loss = (
output,
(
torch.tensor(
balance_dict[loss_name],
device=balance_dict[loss_name][0].device,
loss_list, device=loss_list[0].device
).mean()
* args.balance_loss_weight
).float(),
......
......@@ -84,7 +84,7 @@ class MegatronMLP(FMoETransformerMLP):
else:
world_size = args.world_size
gate = None
if not args.balance_strategy or args.balance_strategy == "gshard":
if not args.balance_strategy or args.balance_strategy == "naive":
from fmoe.gates import NaiveGate
gate = NaiveGate
......@@ -92,6 +92,14 @@ class MegatronMLP(FMoETransformerMLP):
from fmoe.gates import NoisyGate
gate = NoisyGate
elif args.balance_strategy == "gshard":
from fmoe.gates import GShardGate
gate = GShardGate
elif args.balance_strategy == "switch":
from fmoe.gates import SwitchGate
gate = SwitchGate
else:
assert False, "Undefined balance strategy {}" % (args.balance_strategy)
super().__init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment