OpenDAS / FastMoE · Commits

Commit 4eec9807
Authored May 31, 2021 by Sengxian

Adapt balance loss for new gate interface & update patch

Parent: fa5f45f0

Showing 5 changed files with 127 additions and 61 deletions
examples/megatron/fmoefy-v2.0.patch    +0   -26
examples/megatron/fmoefy-v2.1.patch    +0   -27
examples/megatron/fmoefy-v2.2.patch  +107    -0
fmoe/megatron/balance.py              +11    -7
fmoe/megatron/layers.py                +9    -1
examples/megatron/fmoefy-v2.0.patch (deleted, 100644 → 0)
diff --git a/megatron/training.py b/megatron/training.py
index 96aec98..fe55dbd 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -33,7 +33,7 @@
 from megatron.fp16 import FP16_Module
 from megatron.fp16 import FP16_Optimizer
 from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
diff --git a/pretrain_bert.py b/pretrain_bert.py
index b937b36..5841256 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -37,6 +37,8 @@ def model_provider():
         num_tokentypes=2,
         add_binary_head=True,
         parallel_output=True)
+    from fmoe.megatron import fmoefy
+    model = fmoefy(model, num_experts=4)
     return model
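For readers applying the patch by hand: the two added lines are all that changes inside model_provider. A minimal sketch of the resulting function, assuming the surrounding BertModel(...) call from upstream Megatron-LM (only the two fmoe lines come from this patch; num_experts=4 is simply the example value it uses):

def model_provider():
    from megatron.model import BertModel

    model = BertModel(num_tokentypes=2,
                      add_binary_head=True,
                      parallel_output=True)

    # fmoefy replaces the model's feed-forward blocks with FastMoE layers.
    from fmoe.megatron import fmoefy
    model = fmoefy(model, num_experts=4)
    return model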
examples/megatron/fmoefy-v2.1.patch (deleted, 100644 → 0)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..9c624d2 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -43,7 +43,8 @@
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
diff --git a/pretrain_bert.py b/pretrain_bert.py
index 48bc6ad..48628ce 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -52,6 +52,8 @@ def model_provider():
         num_tokentypes=2,
         add_binary_head=True,
         parallel_output=True)
+    from fmoe.megatron import fmoefy
+    model = fmoefy(model, num_experts=4)
     return model
examples/megatron/fmoefy-v2.2.patch (new file, 0 → 100644)
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 26a7cec..0acfb22 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -21,6 +21,8 @@
 import os
 import torch
 from megatron import fused_kernels
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
     """Parse all arguments."""
@@ -40,6 +42,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
     parser = _add_realm_args(parser)
+    parser = _add_fmoe_args(parser)
 
     # Custom arguments.
     if extra_args_provider is not None:
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 9d42260..2583db2 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -177,6 +177,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                             param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared
+                        if hasattr(param, 'dp_comm'):
+                            main_param.dp_comm = param.dp_comm
                         # Replace the optimizer params with the new fp32 copy.
                         param_group['params'][i] = main_param
                         fp32_from_fp16_params_this_group.append(main_param)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..f825bf3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,20 +35,24 @@
 from megatron import update_num_microbatches
 from megatron import mpu
 from megatron import print_rank_0
 from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
 from megatron.model import FP16Module
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+from fmoe.megatron import add_balance_log
def print_datetime(string):
"""Note that this call will sync across all ranks."""
@@ -102,6 +106,13 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()
 
+    # Initialize FastMoE
+    if args.fmoefy:
+        from fmoe.megatron import patch_forward_step, patch_model_provider
+
+        forward_step_func = patch_forward_step(forward_step_func)
+        model_provider = patch_model_provider(model_provider)
+
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -643,7 +654,7 @@ def train_step(forward_step_func, data_iterator,
 def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
-                 loss_scale, report_memory_flag, skipped_iter):
+                 loss_scale, report_memory_flag, skipped_iter, model):
     """Log training information such as losses, timing, ...."""
     args = get_args()
     timers = get_timers()
@@ -725,6 +736,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                               args.consumed_train_samples)
         timers.write(timers_to_log, writer, iteration,
                      normalizer=total_iterations)
+        if args.fmoefy and args.balance_strategy and args.balance_strategy != 'naive':
+            add_balance_log(model, writer, iteration)
 
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval time').elapsed()
@@ -816,7 +829,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         report_memory_flag = training_log(loss_dict, total_loss_dict,
                                           optimizer.param_groups[0]['lr'],
                                           iteration, loss_scale,
-                                          report_memory_flag, skipped_iter)
+                                          report_memory_flag, skipped_iter, model)
 
         # Autoresume
         if args.adlr_autoresume and \
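The v2.2 patch wires FastMoE's command-line options into Megatron via _add_fmoe_args(parser) and then reads args.fmoefy, args.balance_strategy and args.balance_loss_weight in training.py. As a rough, hypothetical sketch only (the flag names below are guessed from those attribute names and from the num_experts value used in the older patches; the real definitions live in fmoe/megatron), the registered argument group could look like:

import argparse

def _add_fmoe_args_sketch(parser):
    # Hypothetical flags, inferred from the attributes this commit reads.
    group = parser.add_argument_group(title='fastmoe')
    group.add_argument('--fmoefy', action='store_true',
                       help='Replace Megatron MLPs with FastMoE expert layers.')
    group.add_argument('--num-experts', type=int, default=4,
                       help='Number of experts per MoE layer.')
    group.add_argument('--balance-strategy', type=str, default=None,
                       choices=['naive', 'noisy', 'gshard', 'switch'],
                       help='Gate / load-balancing strategy.')
    group.add_argument('--balance-loss-weight', type=float, default=1e-2,
                       help='Weight of the balance loss added to the LM loss.')
    return parser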
fmoe/megatron/balance.py
...
@@ -45,11 +45,14 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
     return megatron_gate_hook
 
-def add_balance_log(writer, iteration):
+def add_balance_log(model, writer, iteration):
     from megatron import is_last_rank
 
+    if hasattr(model, 'module'):
+        model = model.module
+
     balance_dict_tensor = torch.vstack(
-        [torch.tensor(item, device=item[0].device) for item in balance_dict.values()]
+        [l.mlp.gate.get_loss(clear=True) for l in model.language_model.transformer.layers]
     ).detach()
     world_group = get_torch_default_comm()
     world_size = torch.distributed.get_world_size(group=world_group)
...
@@ -68,8 +71,6 @@ def add_balance_log(writer, iteration):
                 iteration,
             )
-
-    reset_gate_hook()
 
 
 def patch_forward_step(forward_step_func):
     r"""
...
@@ -86,16 +87,19 @@ def patch_forward_step(forward_step_func):
         args = get_args()
         output = forward_step_func(data_iterator, model, input_tensor)
-        if not is_pipeline_last_stage():
+        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
             return output
         loss_name = args.balance_strategy + "_loss"
+        if hasattr(model, 'module'):
+            model = model.module
+        loss_list = [l.mlp.gate.get_loss(clear=False) for l in model.language_model.transformer.layers]
         (loss, state_dict), bal_loss = (
             output,
-            (torch.tensor(balance_dict[loss_name], device=balance_dict[loss_name][0].device
+            (torch.tensor(loss_list, device=loss_list[0].device
             ).mean() * args.balance_loss_weight).float(),
...
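The "new gate interface" this commit adapts to boils down to this: each MoE layer's gate keeps its own balance loss, and gate.get_loss(clear=...) returns it, with clear=True also resetting the accumulator (which is why reset_gate_hook() disappears above). Below is a self-contained sketch of that contract and of the averaging done in patch_forward_step; DemoGate and the 1e-2 weight are illustrative stand-ins, not FastMoE's implementation:

import torch

class DemoGate:
    # Toy stand-in for a FastMoE gate that accumulates a balance loss.
    def __init__(self):
        self.loss = torch.tensor(0.0)

    def set_loss(self, loss):
        self.loss = loss

    def get_loss(self, clear=True):
        loss = self.loss
        if clear:
            # clear=True resets the accumulator, as add_balance_log does above.
            self.loss = torch.tensor(0.0)
        return loss

# Combine per-layer losses the way patch_forward_step does.
gates = [DemoGate() for _ in range(3)]
for i, g in enumerate(gates):
    g.set_loss(torch.tensor(0.1 * (i + 1)))

balance_loss_weight = 1e-2   # stands in for args.balance_loss_weight
loss_list = [g.get_loss(clear=False) for g in gates]
bal_loss = (torch.tensor(loss_list, device=loss_list[0].device).mean()
            * balance_loss_weight).float()
print(bal_loss)   # tensor(0.0020)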
fmoe/megatron/layers.py
...
@@ -84,7 +84,7 @@ class MegatronMLP(FMoETransformerMLP):
         else:
             world_size = args.world_size
         gate = None
-        if not args.balance_strategy or args.balance_strategy == "gshard":
+        if not args.balance_strategy or args.balance_strategy == "naive":
             from fmoe.gates import NaiveGate
             gate = NaiveGate
...
@@ -92,6 +92,14 @@ class MegatronMLP(FMoETransformerMLP):
             from fmoe.gates import NoisyGate
             gate = NoisyGate
+        elif args.balance_strategy == "gshard":
+            from fmoe.gates import GShardGate
+            gate = GShardGate
+        elif args.balance_strategy == "switch":
+            from fmoe.gates import SwitchGate
+            gate = SwitchGate
+        else:
+            assert False, "Undefined balance strategy {}" % (args.balance_strategy)
 
         super().__init__(
...
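One caveat in the new fallback branch: assert False, "Undefined balance strategy {}" % (args.balance_strategy) mixes str.format-style braces with the % operator, so if the branch is ever hit the message expression itself raises a TypeError instead of naming the strategy. An illustrative, dictionary-based way to express the same selection (not FastMoE's code; the "noisy" key is an assumption, since the NoisyGate condition sits outside the visible hunk):

def select_gate(balance_strategy):
    # Map each balance_strategy string to its gate class.
    from fmoe.gates import NaiveGate, NoisyGate, GShardGate, SwitchGate

    gates = {
        None: NaiveGate,       # no strategy -> plain naive gate
        "naive": NaiveGate,
        "noisy": NoisyGate,    # assumed strategy string for NoisyGate
        "gshard": GShardGate,
        "switch": SwitchGate,
    }
    if balance_strategy not in gates:
        # Use .format (or an f-string) so the message actually renders.
        raise ValueError("Undefined balance strategy {}".format(balance_strategy))
    return gates[balance_strategy]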