OpenDAS / FastMoE · Commits

Commit 4d59a9db (unverified), authored May 31, 2021 by Rick Ho, committed by GitHub on May 31, 2021

Merge pull request #40 from laekov/new-gate-patch

Adapt balance loss for new gate interface & update patch

Parents: c77f676d, 4eec9807
Showing 5 changed files with 127 additions and 61 deletions (+127 −61)
examples/megatron/fmoefy-v2.0.patch  +0 −26
examples/megatron/fmoefy-v2.1.patch  +0 −27
examples/megatron/fmoefy-v2.2.patch  +107 −0
fmoe/megatron/balance.py             +11 −7
fmoe/megatron/layers.py              +9 −1
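For context on the commit title: the balance-loss code below is rewritten against a gate interface in which each gate object accumulates its own loss and exposes it through get_loss(clear=...). The sketch below only illustrates that contract as it can be inferred from the diffs; the class name SketchGate and the set_loss helper are hypothetical stand-ins, not the actual fmoe.gates implementation.

import torch

class SketchGate(torch.nn.Module):
    # Hypothetical stand-in illustrating the get_loss(clear=...) contract
    # that the balance-loss code relies on after this commit.
    def __init__(self):
        super().__init__()
        self.loss = None

    def set_loss(self, loss):
        # store the balance loss computed during the forward pass
        self.loss = loss

    def get_loss(self, clear=True):
        # hand the accumulated loss to the caller, optionally resetting it
        loss = self.loss
        if clear:
            self.loss = None
        return loss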
examples/megatron/fmoefy-v2.0.patch (deleted, 100644 → 0)
diff --git a/megatron/training.py b/megatron/training.py
index 96aec98..fe55dbd 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -33,7 +33,7 @@
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
diff --git a/pretrain_bert.py b/pretrain_bert.py
index b937b36..5841256 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -37,6 +37,8 @@ def model_provider():
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
+ from fmoe.megatron import fmoefy
+ model = fmoefy(model, num_experts=4)
return model
examples/megatron/fmoefy-v2.1.patch (deleted, 100644 → 0)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..9c624d2 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -43,7 +43,8 @@
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
+from fmoe.megatron import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
diff --git a/pretrain_bert.py b/pretrain_bert.py
index 48bc6ad..48628ce 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -52,6 +52,8 @@ def model_provider():
num_tokentypes=2,
add_binary_head=True,
parallel_output=True)
+ from fmoe.megatron import fmoefy
+ model = fmoefy(model, num_experts=4)
return model
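Both deleted patches (v2.0 and v2.1) inject FastMoE the same way: the model returned by model_provider is wrapped with fmoefy. A minimal sketch of that pattern, assuming a Megatron-LM checkout where megatron.model.BertModel is importable and using num_experts=4 as in the patches:

# Sketch of the fmoefy pattern used by both deleted patches.
def model_provider():
    from megatron.model import BertModel
    from fmoe.megatron import fmoefy

    model = BertModel(num_tokentypes=2, add_binary_head=True,
                      parallel_output=True)
    # Replace the Transformer MLP blocks with FastMoE expert layers.
    model = fmoefy(model, num_experts=4)
    return model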
examples/megatron/fmoefy-v2.2.patch (new file, 0 → 100644)
diff --git a/megatron/arguments.py b/megatron/arguments.py
index 26a7cec..0acfb22 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -21,6 +21,8 @@
import os
import torch
from megatron import fused_kernels
+from fmoe.megatron import add_fmoe_args as _add_fmoe_args
+
def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -40,6 +42,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
+ parser = _add_fmoe_args(parser)
# Custom arguments.
if extra_args_provider is not None:
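The arguments.py hunk registers FastMoE's command-line options through add_fmoe_args. The real flag definitions live in fmoe.megatron and are not part of this diff; as an illustration only, they presumably cover the attributes read elsewhere in this commit (args.fmoefy, args.num_experts, args.balance_strategy, args.balance_loss_weight), roughly along these lines:

import argparse

# Illustrative only: the flag spellings below are assumptions derived from
# the attribute names used in this commit, not copied from fmoe/megatron.
def add_fmoe_args_sketch(parser):
    group = parser.add_argument_group(title="fastmoe")
    group.add_argument("--fmoefy", action="store_true")
    group.add_argument("--num-experts", type=int, default=4)
    group.add_argument("--balance-strategy", type=str, default=None)
    group.add_argument("--balance-loss-weight", type=float, default=1.0)
    return parser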
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index 9d42260..2583db2 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -177,6 +177,8 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
+ if hasattr(param, 'dp_comm'):
+ main_param.dp_comm = param.dp_comm
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = main_param
fp32_from_fp16_params_this_group.append(main_param)
diff --git a/megatron/training.py b/megatron/training.py
index 56d1c7c..f825bf3 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -35,20 +35,24 @@
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
+# from megatron.checkpointing import load_checkpoint
+from fmoe.megatron.checkpoint import load_checkpoint
+# from megatron.checkpointing import save_checkpoint
+from fmoe.megatron.checkpoint import save_checkpoint
from megatron.model import FP16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory
+from fmoe.megatron import DistributedDataParallel as LocalDDP
+from fmoe.megatron import add_balance_log
def print_datetime(string):
"""Note that this call will sync across all ranks."""
@@ -102,6 +106,13 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
args = get_args()
timers = get_timers()
+ # Initialize FastMoE
+ if args.fmoefy:
+ from fmoe.megatron import patch_forward_step, patch_model_provider
+
+ forward_step_func = patch_forward_step(forward_step_func)
+ model_provider = patch_model_provider(model_provider)
+
# Model, optimizer, and learning rate.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -643,7 +654,7 @@ def train_step(forward_step_func, data_iterator,
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
- loss_scale, report_memory_flag, skipped_iter):
+ loss_scale, report_memory_flag, skipped_iter, model):
"""Log training information such as losses, timing, ...."""
args = get_args()
timers = get_timers()
@@ -725,6 +736,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
args.consumed_train_samples)
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
+ if args.fmoefy and args.balance_strategy and args.balance_strategy != 'naive':
+ add_balance_log(model, writer, iteration)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed()
@@ -816,7 +829,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
report_memory_flag = training_log(loss_dict, total_loss_dict,
optimizer.param_groups[0]['lr'],
iteration, loss_scale,
- report_memory_flag, skipped_iter)
+ report_memory_flag, skipped_iter, model)
# Autoresume
if args.adlr_autoresume and \
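The training.py hunk above wires FastMoE in by wrapping the user callbacks before setup_model_and_optimizer runs. A sketch of that wrapping pattern, assuming the helper below as a stand-in: the real patch_forward_step and patch_model_provider live in fmoe.megatron and are only partially visible in this commit (the balance.py diff below shows part of patch_forward_step).

# Sketch of the callback-wrapping pattern used in pretrain(); the function
# name patch_model_provider_sketch is hypothetical.
def patch_model_provider_sketch(model_provider, num_experts=4):
    def fmoefied_model_provider():
        from fmoe.megatron import fmoefy
        # Build the plain Megatron model, then convert its MLPs to MoE layers.
        return fmoefy(model_provider(), num_experts=num_experts)
    return fmoefied_model_provider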
fmoe/megatron/balance.py
@@ -45,11 +45,14 @@ def generate_megatron_gate_hook(layer_idx, num_expert_global):
    return megatron_gate_hook

-def add_balance_log(writer, iteration):
+def add_balance_log(model, writer, iteration):
    from megatron import is_last_rank

+    if hasattr(model, 'module'):
+        model = model.module
    balance_dict_tensor = torch.vstack(
-        [torch.tensor(item, device=item[0].device) for item in balance_dict.values()]
+        [l.mlp.gate.get_loss(clear=True)
+            for l in model.language_model.transformer.layers]
    ).detach()
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
@@ -68,8 +71,6 @@ def add_balance_log(writer, iteration):
                iteration,
            )

-    reset_gate_hook()
-

def patch_forward_step(forward_step_func):
    r"""
@@ -86,16 +87,19 @@ def patch_forward_step(forward_step_func):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

-        if not is_pipeline_last_stage():
+        if not is_pipeline_last_stage() or not args.balance_strategy or args.balance_strategy == 'naive':
            return output

        loss_name = args.balance_strategy + "_loss"

+        if hasattr(model, 'module'):
+            model = model.module
+        loss_list = [l.mlp.gate.get_loss(clear=False)
+            for l in model.language_model.transformer.layers]
        (loss, state_dict), bal_loss = (
            output,
            (
                torch.tensor(
-                    balance_dict[loss_name], device=balance_dict[loss_name][0].device
+                    loss_list, device=loss_list[0].device
                ).mean()
                * args.balance_loss_weight
            ).float(),
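To make the reduction in patch_forward_step concrete: the per-layer gate losses are averaged and scaled by balance_loss_weight before being returned alongside the task loss. A small worked example (torch.stack is used here for clarity; the values and weight are made up):

# Worked example of the balance-loss reduction performed above.
import torch

loss_list = [torch.tensor(0.2), torch.tensor(0.4), torch.tensor(0.6)]  # per-layer gate losses
balance_loss_weight = 0.01                                             # stands in for args.balance_loss_weight

bal_loss = (torch.stack(loss_list).mean() * balance_loss_weight).float()
print(bal_loss)  # tensor(0.0040): mean of 0.4 scaled by 0.01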
fmoe/megatron/layers.py
@@ -84,7 +84,7 @@ class MegatronMLP(FMoETransformerMLP):
        else:
            world_size = args.world_size
        gate = None
-        if not args.balance_strategy or args.balance_strategy == "gshard":
+        if not args.balance_strategy or args.balance_strategy == "naive":
            from fmoe.gates import NaiveGate
            gate = NaiveGate
@@ -92,6 +92,14 @@ class MegatronMLP(FMoETransformerMLP):
            from fmoe.gates import NoisyGate
            gate = NoisyGate
+        elif args.balance_strategy == "gshard":
+            from fmoe.gates import GShardGate
+            gate = GShardGate
+        elif args.balance_strategy == "switch":
+            from fmoe.gates import SwitchGate
+            gate = SwitchGate
+        else:
+            assert False, "Undefined balance strategy {}" % (args.balance_strategy)

        super().__init__(
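The gate selection in MegatronMLP now distinguishes each strategy explicitly. As a compact restatement of the mapping visible in this hunk (the branch that selects NoisyGate is outside the visible context, so its strategy string "noisy" is an assumption; the actual code uses an if/elif chain, not a dict):

# Restatement of the balance_strategy -> gate mapping after this commit.
from fmoe.gates import NaiveGate, NoisyGate, GShardGate, SwitchGate

GATE_BY_STRATEGY = {
    None: NaiveGate,       # no balance strategy given
    "naive": NaiveGate,
    "noisy": NoisyGate,    # assumed spelling; branch not shown in the hunk
    "gshard": GShardGate,
    "switch": SwitchGate,
}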