Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
efef43b5
Unverified
Commit
efef43b5
authored
Feb 08, 2024
by
Frank Lee
Committed by
GitHub
Feb 08, 2024
Browse files
Merge pull request #5372 from hpcaitech/exp/mixtral
parents
4c03347f
06db94fb
Changes
33
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2091 additions
and
6 deletions
+2091
-6
applications/ColossalMoE/README.md
applications/ColossalMoE/README.md
+0
-0
applications/ColossalMoE/colossal_moe/__init__.py
applications/ColossalMoE/colossal_moe/__init__.py
+0
-0
applications/ColossalMoE/colossal_moe/models/__init__.py
applications/ColossalMoE/colossal_moe/models/__init__.py
+0
-0
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
...ons/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
+629
-0
applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
...ications/ColossalMoE/colossal_moe/models/mixtral_layer.py
+92
-0
applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
...cations/ColossalMoE/colossal_moe/models/mixtral_policy.py
+557
-0
applications/ColossalMoE/colossal_moe/utils.py
applications/ColossalMoE/colossal_moe/utils.py
+84
-0
applications/ColossalMoE/infer.py
applications/ColossalMoE/infer.py
+111
-0
applications/ColossalMoE/infer.sh
applications/ColossalMoE/infer.sh
+7
-0
applications/ColossalMoE/requirements.txt
applications/ColossalMoE/requirements.txt
+5
-0
applications/ColossalMoE/setup.py
applications/ColossalMoE/setup.py
+43
-0
applications/ColossalMoE/tests/__init__.py
applications/ColossalMoE/tests/__init__.py
+0
-0
applications/ColossalMoE/tests/test_mixtral_layer.py
applications/ColossalMoE/tests/test_mixtral_layer.py
+63
-0
applications/ColossalMoE/tests/test_moe_checkpoint.py
applications/ColossalMoE/tests/test_moe_checkpoint.py
+146
-0
applications/ColossalMoE/train.py
applications/ColossalMoE/train.py
+295
-0
applications/ColossalMoE/train.sh
applications/ColossalMoE/train.sh
+19
-0
applications/ColossalMoE/version.txt
applications/ColossalMoE/version.txt
+1
-0
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+25
-3
colossalai/checkpoint_io/checkpoint_io_base.py
colossalai/checkpoint_io/checkpoint_io_base.py
+10
-2
colossalai/moe/__init__.py
colossalai/moe/__init__.py
+4
-1
No files found.
applications/ColossalMoE/README.md
0 → 100644
View file @
efef43b5
File added
applications/ColossalMoE/colossal_moe/__init__.py
0 → 100644
View file @
efef43b5
applications/ColossalMoE/colossal_moe/models/__init__.py
0 → 100644
View file @
efef43b5
applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
0 → 100644
View file @
efef43b5
This diff is collapsed.
Click to expand it.
applications/ColossalMoE/colossal_moe/models/mixtral_layer.py
0 → 100644
View file @
efef43b5
import
torch
import
torch.distributed
as
dist
import
torch.nn.functional
as
F
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
from
colossalai.lazy
import
LazyInitContext
from
colossalai.moe
import
MOE_MANAGER
from
colossalai.moe._operation
import
MoeInGradScaler
,
MoeOutGradScaler
,
all_to_all_uneven
from
colossalai.shardformer.shard.utils
import
set_tensors_to_none
from
colossalai.tensor.moe_tensor.api
import
set_moe_tensor_info
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
    """Expert-parallel (EP) drop-in replacement for HF's MixtralSparseMoeBlock.

    Each EP rank keeps only ``num_experts // ep_size`` experts; tokens are
    exchanged between ranks with uneven all-to-all collectives so that every
    token is processed by the rank owning its routed expert.
    """

    def __init__(self, config):
        super().__init__(config)
        self.setup_ep()

    def setup_ep(self):
        """Partition the expert list across the EP process group and free the rest."""
        # Ask the global MoE manager for the process-group info of this expert count.
        _, moe_info = MOE_MANAGER.get_info(self.num_experts)
        ep_group = moe_info.ep_group
        # ep_group is None when running without expert parallelism (single group).
        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
        assert self.num_experts % self.ep_size == 0
        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
        # Release the parameters of experts this rank does not own.
        set_tensors_to_none(self.experts, exclude=set(held_experts))
        # Tag remaining parameters so the checkpoint/ZeRO machinery knows they are MoE tensors.
        for p in self.experts.parameters():
            set_moe_tensor_info(p, moe_info)

    @staticmethod
    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
        """Convert a native HF block in place (materialize lazy tensors, swap class, shard experts)."""
        LazyInitContext.materialize(module)
        module.__class__ = EPMixtralSparseMoeBlock
        module.setup_ep()
        return module

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize the selected top-k weights so they sum to 1 per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Flatten (token, k) assignments and sort by expert id so tokens bound
        # for the same expert are contiguous before the all-to-all.
        selected_experts = selected_experts.t().reshape(-1)
        selected_experts_idx = selected_experts.argsort()
        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
        output_split_sizes = torch.zeros_like(input_split_sizes)
        # Exchange per-expert counts so each rank knows how much it will receive.
        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

        # Collapse per-expert counts to per-rank counts for the uneven all-to-all.
        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
        # compute expert output
        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
        if output_states.size(0) > 0:
            if self.num_experts_per_ep == 1:
                # no need to split
                expert = self.experts[self.expert_start_idx]
                output_states = expert.act_fn(expert.w1(output_states)) * expert.w3(output_states)
                output_states = expert.w2(output_states)
            else:
                # Received data is grouped by (source rank, local expert); route
                # each chunk through its owning local expert, skipping empties.
                output_states_splits = output_states.split(output_split_sizes.tolist())
                output_states_list = []
                for i, split_states in enumerate(output_states_splits):
                    if split_states.size(0) == 0:
                        continue
                    expert = self.experts[self.expert_start_idx + i % self.num_experts_per_ep]
                    split_states = expert.act_fn(expert.w1(split_states)) * expert.w3(split_states)
                    split_states = expert.w2(split_states)
                    output_states_list.append(split_states)
                output_states = torch.cat(output_states_list)
        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
        # Send expert outputs back to the ranks the tokens came from.
        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
        # Invert the earlier argsort to restore the original token order.
        recover_experts_idx = torch.empty_like(selected_experts_idx)
        recover_experts_idx[selected_experts_idx] = torch.arange(
            selected_experts_idx.size(0), device=selected_experts_idx.device
        )
        dispatch_states = dispatch_states[recover_experts_idx]
        # Weighted sum over the k expert outputs per token.
        k_hidden_states = dispatch_states.chunk(self.top_k)
        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
        for i in range(1, self.top_k):
            output_states += k_hidden_states[i] * routing_weights[:, i, None]
        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
        return output_states, router_logits
applications/ColossalMoE/colossal_moe/models/mixtral_policy.py
0 → 100644
View file @
efef43b5
This diff is collapsed.
Click to expand it.
applications/ColossalMoE/colossal_moe/utils.py
0 → 100644
View file @
efef43b5
import
json
import
os
from
typing
import
Any
,
Dict
,
Tuple
,
Union
import
torch
from
torch.optim.lr_scheduler
import
_LRScheduler
from
torch.optim.optimizer
import
Optimizer
from
colossalai.booster
import
Booster
from
colossalai.cluster
import
DistCoordinator
def move_to_cuda(batch, device):
    """Return a copy of ``batch`` with every value moved onto ``device``."""
    moved = {}
    for key, value in batch.items():
        moved[key] = value.to(device)
    return moved
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load file in JSON format
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        content = json.load(handle)
    return content
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save as JSON format
    """
    with open(file_path, mode="w", encoding="utf-8") as handle:
        json.dump(data, fp=handle, ensure_ascii=False, indent=4)
def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
    """
    # Each checkpoint lives in its own epoch/step-stamped subdirectory.
    checkpoint_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    model_dir = os.path.join(checkpoint_dir, "modeling")
    os.makedirs(model_dir, exist_ok=True)

    booster.save_model(model, model_dir, shard=True)
    booster.save_optimizer(optimizer, os.path.join(checkpoint_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(checkpoint_dir, "lr_scheduler"))

    # Record where in the dataset to resume from; only rank 0 writes the file.
    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(checkpoint_dir, "running_states.json"))
def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
    """
    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
    """
    # Update booster params states.
    booster.load_model(model, os.path.join(load_dir, "modeling"))
    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    # Resume bookkeeping written by save_checkpoint.
    states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
    return states["epoch"], states["step"], states["sample_start_index"]
applications/ColossalMoE/infer.py
0 → 100644
View file @
efef43b5
import
argparse
import
torch
import
torch.distributed
as
dist
from
colossal_moe.models.mixtral_checkpoint
import
MixtralMoEHybridParallelCheckpointIO
from
colossal_moe.models.mixtral_policy
import
MixtralForCausalLMPolicy
from
transformers
import
AutoTokenizer
from
transformers.models.mixtral
import
MixtralConfig
,
MixtralForCausalLM
import
colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin.moe_hybrid_parallel_plugin
import
MoeHybridParallelPlugin
from
colossalai.cluster
import
DistCoordinator
def parse_args():
    """Parse command-line arguments for Mixtral EP inference.

    Returns:
        argparse.Namespace with model_name, plugin, precision, seed,
        use_kernel, use_layernorm_kernel.
    """
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="ep",
        choices=["ep"],
        # Fixed typo "Parallel methos." -> "Parallel methods." (consistent with train.py).
        help="Parallel methods.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )
    args = parser.parse_args()
    return args
def main():
    """Run multi-GPU expert-parallel text generation with a Mixtral model."""
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # EP size cannot exceed the number of experts in the checkpoint.
    config = MixtralConfig.from_pretrained(args.model_name)
    ep_size = min(dist.get_world_size(), config.num_local_experts)

    # Set plugin
    if args.plugin == "ep":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=1,
            ep_size=ep_size,
            zero_stage=1,
            precision=args.precision,
            custom_policy=MixtralForCausalLMPolicy(),
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master(f"Finish load model")

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Set booster
    booster = Booster(plugin=plugin)
    model, _, _, _, _ = booster.boost(model=model)
    coordinator.print_on_master(f"Finish init booster")

    model.eval()

    # Each rank generates on its own prompts; rank 0 uses a single short prompt.
    if coordinator.rank == 0:
        text = ["Hello my name is"]
    else:
        text = ["What's the largest country in the world?", "How many people live in China?", "帮我续写这首诗:离离原上草"]
    # Mixtral has no pad token by default; reuse unk for batched padding.
    tokenizer.pad_token = tokenizer.unk_token
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch.cuda.current_device())

    with torch.no_grad():
        # model is wrapped by the booster; the HF module lives at model.module.
        outputs = model.module.generate(**inputs, max_new_tokens=20)
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    print(f"[{coordinator.rank}] {outputs}")


if __name__ == "__main__":
    main()
applications/ColossalMoE/infer.sh
0 → 100644
View file @
efef43b5
# Launch expert-parallel Mixtral inference on NUM_GPU local GPUs.
NUM_GPU=2
MODEL="mistralai/Mixtral-8x7B-v0.1"

# ep
# Note: removed the dangling trailing "\" line continuation after the last
# argument, which continued the command into nothing.
torchrun --standalone --nproc_per_node $NUM_GPU infer.py \
    --model_name $MODEL \
    --plugin "ep"
applications/ColossalMoE/requirements.txt
0 → 100644
View file @
efef43b5
colossalai >= 0.3.3
torch >= 1.8.1
transformers == 4.36.0
sentencepiece
datasets
applications/ColossalMoE/setup.py
0 → 100644
View file @
efef43b5
from
setuptools
import
find_packages
,
setup
def fetch_requirements(path):
    """Return the stripped lines of the requirements file at ``path``."""
    with open(path, "r") as handle:
        lines = handle.readlines()
    return [line.strip() for line in lines]
def fetch_readme():
    """Return the full text of README.md (UTF-8) from the current directory."""
    with open("README.md", encoding="utf-8") as handle:
        content = handle.read()
    return content
def fetch_version():
    """Return the version string from version.txt, stripped of surrounding whitespace."""
    with open("version.txt", "r") as handle:
        return handle.read().strip()
# Package metadata for the standalone `colossal_moe` application package.
setup(
    name="colossal_moe",
    version=fetch_version(),
    packages=find_packages(
        exclude=(
            "tests",
            "benchmarks",
            "*.egg-info",
        )
    ),
    description="Colossal-AI MoE",
    long_description=fetch_readme(),
    long_description_content_type="text/markdown",
    license="Apache Software License 2.0",
    url="https://github.com/hpcaitech",
    install_requires=fetch_requirements("requirements.txt"),
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Environment :: GPU :: NVIDIA CUDA",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: System :: Distributed Computing",
    ],
)
applications/ColossalMoE/tests/__init__.py
0 → 100644
View file @
efef43b5
applications/ColossalMoE/tests/test_mixtral_layer.py
0 → 100644
View file @
efef43b5
from
copy
import
deepcopy
import
pytest
import
torch
import
torch.distributed
as
dist
from
colossal_moe.models.mixtral_layer
import
EPMixtralSparseMoeBlock
from
torch.testing
import
assert_close
from
transformers.models.mixtral.configuration_mixtral
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralSparseMoeBlock
import
colossalai
from
colossalai.moe
import
MOE_MANAGER
from
colossalai.testing.utils
import
spawn
# Toy dimensions for the MoE layer parity test: 7 tokens routed across 4
# experts of hidden size 8, top-2 routing.
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_mixtral_moe_layer():
    """Verify the EP block matches the native HF block in forward and backward."""
    torch.cuda.set_device(dist.get_rank())
    # Pure expert parallelism: dp=1, pp=1, ep=world size.
    MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=1, fixed_ep_size=dist.get_world_size(), fixed_pp_size=1)
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
    )
    torch.manual_seed(0)
    orig_model = MixtralSparseMoeBlock(config).cuda()
    x = torch.rand(1, tokens, hidden_size, requires_grad=True).cuda()
    orig_output, orig_logits = orig_model(x)
    # Convert a deep copy so the reference model keeps its full expert set.
    model = deepcopy(orig_model)
    model = EPMixtralSparseMoeBlock.from_native_module(model)
    ep_output, ep_logits = model(x)
    assert_close(orig_logits, ep_logits)
    assert_close(orig_output, ep_output)
    orig_loss = orig_output.mean()
    orig_loss.backward()
    ep_loss = ep_output.mean()
    ep_loss.backward()
    assert_close(orig_loss, ep_loss)
    # Gradients on the experts each rank still holds must match the reference.
    name_to_p = {n: p for n, p in orig_model.named_parameters()}
    for n, ep_p in model.named_parameters():
        p = name_to_p[n]
        if ep_p.grad is not None:
            assert_close(p.grad, ep_p.grad)
def run_dist(rank: int, world_size: int, port: int):
    """Per-process entry: initialize the distributed backend, then run the check."""
    colossalai.launch({}, rank, world_size, "localhost", port)
    check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [2, 4])
def test_mixtral_moe_layer(world_size: int):
    """Spawn world_size processes and run the EP parity check (requires GPUs)."""
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(2)
applications/ColossalMoE/tests/test_moe_checkpoint.py
0 → 100644
View file @
efef43b5
from
copy
import
deepcopy
import
pytest
import
torch
import
torch.distributed
as
dist
from
colossal_moe.models.mixtral_checkpoint
import
MixtralMoEHybridParallelCheckpointIO
from
colossal_moe.models.mixtral_policy
import
MixtralForCausalLMPolicy
from
torch.optim
import
Adam
from
transformers.models.mixtral.configuration_mixtral
import
MixtralConfig
from
transformers.models.mixtral.modeling_mixtral
import
MixtralForCausalLM
import
colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin.moe_hybrid_parallel_plugin
import
MoeHybridParallelPlugin
from
colossalai.testing.utils
import
spawn
# Toy dimensions for the checkpoint round-trip test: 7 tokens, 4 experts,
# hidden size 8, top-2 routing.
tokens, n_experts = 7, 4
hidden_size = 8
top_k = 2
def check_model_equal(model1, model2):
    """Assert two models have identical state-dict keys and equal parameters (compared in half precision)."""
    keys1 = set(model1.state_dict().keys())
    keys2 = set(model2.state_dict().keys())
    assert keys1 == keys2
    for param_a, param_b in zip(model1.parameters(), model2.parameters()):
        assert torch.equal(param_a.half(), param_b.half())
def get_optimizer_snapshot(optim):
    """Capture a deep-copied snapshot of an optimizer's state, keyed by parameter id."""
    snapshot_state = {}
    for param, param_state in optim.state.items():
        snapshot_state[id(param)] = deepcopy(param_state)
    snapshot_groups = []
    for group in optim.param_groups:
        # Replace live parameter objects with their ids; keep all hyperparameters.
        new_group = {"params": [id(p) for p in group["params"]]}
        for key, value in group.items():
            if key != "params":
                new_group[key] = value
        snapshot_groups.append(new_group)
    return {
        "state": snapshot_state,
        "param_groups": snapshot_groups,
    }
def check_optimizer_snapshot_equal(snapshot1, snapshot2):
    """Assert two optimizer snapshots (see get_optimizer_snapshot) are identical."""
    # check param_groups
    groups1 = snapshot1["param_groups"]
    groups2 = snapshot2["param_groups"]
    assert len(groups1) == len(groups2)
    for group1, group2 in zip(groups1, groups2):
        assert set(group1.keys()) == set(group2.keys())
        for key in group1:
            assert group1[key] == group2[key]
    # check state
    all_state1 = snapshot1["state"]
    all_state2 = snapshot2["state"]
    assert set(all_state1.keys()) == set(all_state2.keys()), f"{all_state1.keys()}, {all_state2.keys()}"
    for pid in all_state1:
        state1, state2 = all_state1[pid], all_state2[pid]
        assert set(state1.keys()) == set(state2.keys())
        for k in state1:
            v1, v2 = state1[k], state2[k]
            if isinstance(v1, torch.Tensor):
                assert torch.equal(v1, v2), f"{k}, {v1}, {v2}"
            else:
                assert v1 == v2
def check_mixtral_moe_layer():
    """Round-trip model and optimizer checkpoints through the MoE checkpoint IO and verify equality."""
    torch.cuda.set_device(dist.get_rank())
    config = MixtralConfig(
        hidden_size=hidden_size,
        intermediate_size=hidden_size * 2,
        num_local_experts=n_experts,
        num_experts_per_tok=top_k,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    torch.manual_seed(0)
    input_ids = torch.randint(0, 100, (2, tokens)).cuda()
    orig_model = MixtralForCausalLM(config).cuda()
    model = deepcopy(orig_model)
    optimizer = Adam(model.parameters(), lr=1e-3)
    # pp=2 x ep=2 hybrid layout with ZeRO-1 (expects world size 4).
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        ep_size=2,
        custom_policy=MixtralForCausalLMPolicy(),
        checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        microbatch_size=1,
        zero_stage=1,
    )
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model=model, optimizer=optimizer)
    # initialize grads
    data_iter = iter(
        [{"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": input_ids.clone()}]
    )
    booster.execute_pipeline(
        data_iter,
        model,
        lambda outputs, inputs: outputs.loss,
        optimizer,
    )

    # check save model
    booster.save_model(model, "mixtral_model", shard=True)
    dist.barrier()
    if dist.get_rank() == 0:
        # Reload the sharded save as a plain HF model and compare to the original.
        saved_model = MixtralForCausalLM.from_pretrained("mixtral_model").cuda()
        check_model_equal(orig_model, saved_model)
        saved_model.save_pretrained("mixtral_hf_model")
    dist.barrier()

    # check load model
    new_model = MixtralForCausalLM(config).cuda()
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    new_model, new_optimizer, *_ = booster.boost(model=new_model, optimizer=new_optimizer)
    booster.load_model(new_model, "mixtral_hf_model")
    check_model_equal(model, new_model)

    # check save optimizer
    optimizer.step()
    # Change a hyperparameter so the snapshot is distinguishable from defaults.
    for group in optimizer.param_groups:
        group["lr"] = 0.1
    snapshot = get_optimizer_snapshot(optimizer.unwrap())
    booster.save_optimizer(optimizer, "mixtral_optim", shard=True)
    dist.barrier()
    # reset optimizer state
    for state in optimizer.unwrap().state.values():
        for v in state.values():
            if isinstance(v, torch.Tensor):
                v.zero_()
    booster.load_optimizer(optimizer, "mixtral_optim")
    loaded_snapshot = get_optimizer_snapshot(optimizer.unwrap())
    check_optimizer_snapshot_equal(snapshot, loaded_snapshot)
def run_dist(rank: int, world_size: int, port: int):
    """Per-process entry: initialize the distributed backend, then run the checkpoint check."""
    colossalai.launch({}, rank, world_size, "localhost", port)
    check_mixtral_moe_layer()
@pytest.mark.parametrize("world_size", [4])
def test_mixtral_moe_layer(world_size: int):
    """Spawn 4 processes (pp=2 x ep=2) and run the checkpoint round-trip (requires GPUs)."""
    spawn(run_dist, world_size)


if __name__ == "__main__":
    test_mixtral_moe_layer(4)
applications/ColossalMoE/train.py
0 → 100644
View file @
efef43b5
import
argparse
import
torch
import
torch.distributed
as
dist
from
colossal_moe.models.mixtral_checkpoint
import
MixtralMoEHybridParallelCheckpointIO
from
colossal_moe.models.mixtral_policy
import
MixtralForCausalLMPolicy
from
colossal_moe.utils
import
load_checkpoint
,
move_to_cuda
,
save_checkpoint
from
torch.utils.data
import
Dataset
from
tqdm
import
tqdm
from
transformers
import
AutoTokenizer
from
transformers.models.mixtral
import
MixtralForCausalLM
import
colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin.moe_hybrid_parallel_plugin
import
MoeHybridParallelPlugin
from
colossalai.cluster
import
DistCoordinator
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.utils
import
get_current_device
@torch.no_grad()
def get_global_loss(loss, booster):
    """Average a per-rank loss across the data-parallel group (gradient-free)."""
    global_loss = loss.clone().detach()
    dist.all_reduce(tensor=global_loss, op=dist.ReduceOp.SUM, group=booster.plugin.dp_group)
    global_loss.div_(booster.plugin.dp_size)
    return global_loss
class RandomDataset(Dataset):
    """Synthetic language-modeling dataset: random token ids, all-ones attention
    masks, and labels that mirror the inputs.

    The ``tokenizer`` argument is accepted for interface compatibility but unused.
    """

    def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 100, tokenizer=None):
        self.num_samples = num_samples
        self.max_length = max_length
        shape = (num_samples, max_length)
        self.input_ids = torch.randint(0, vocab_size, shape, device=get_current_device())
        self.attention_mask = torch.ones_like(self.input_ids)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        sample = {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.input_ids[idx],
        }
        return sample
def parse_args():
    """Parse command-line arguments for Mixtral finetuning.

    Returns:
        argparse.Namespace covering model/checkpoint paths, parallel layout
        (pp/dp/ep/microbatch), optimizer and scheduler settings, kernels,
        and MoE load-balance / communication options.
    """
    # basic settings
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        type=str,
        default="mistralai/Mixtral-8x7B-v0.1",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument(
        "--plugin",
        type=str,
        default="hybrid",
        choices=["hybrid"],
        help="Parallel methods.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./outputs",
        help="The path of your saved model after finetuning.",
    )
    # NOTE(review): both --num_epoch and --num_epochs exist (main() uses
    # num_epoch for the loop, num_epochs for the scheduler). Kept for
    # backward compatibility; consider consolidating.
    parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=1000,
        help="The interval (steps) of saving checkpoints.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=["fp32", "bf16", "fp16"],
        help="The mixed precision training.",
    )
    parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    # optim
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    # lr scheduler
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    # zero stage for all plugins
    parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
    # hybrid plugin
    parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
    parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
    parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
    parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
    # kernel
    parser.add_argument(
        "--use_kernel",
        action="store_true",
        help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
    )
    parser.add_argument(
        "--use_layernorm_kernel",
        action="store_true",
        help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
    )
    # load balance
    parser.add_argument(
        "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
    )
    parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
    # communicate overlap
    parser.add_argument(
        "--comm_overlap",
        action="store_true",
        # Fixed typo "muiti-node" -> "multi-node".
        help="Use communication overlap for MoE. Recommended to enable for multi-node training.",
    )
    # hierarchical all-to-all
    parser.add_argument(
        "--hierarchical_alltoall",
        action="store_true",
        # Fixed typo "muiti-node" -> "multi-node".
        help="Use hierarchical all-to-all for MoE. Recommended to enable for multi-node training.",
    )
    args = parser.parse_args()
    return args
def main():
    """Finetune a Mixtral model with the hybrid (pp + ep + ZeRO) plugin on random data."""
    args = parse_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()

    # Set plugin
    if args.plugin == "hybrid":
        plugin = MoeHybridParallelPlugin(
            tp_size=1,
            pp_size=args.pp_size,
            ep_size=args.ep_size,
            microbatch_size=args.microbatch_size,
            custom_policy=MixtralForCausalLMPolicy(),
            enable_fused_normalization=args.use_layernorm_kernel,
            enable_jit_fused=args.use_kernel,
            precision=args.precision,
            zero_stage=args.zero_stage,
            checkpoint_io=MixtralMoEHybridParallelCheckpointIO,
        )
    else:
        raise ValueError(f"Invalid plugin {args.plugin}")
    coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")

    # Build Mixtral model
    model = MixtralForCausalLM.from_pretrained(args.model_name)
    coordinator.print_on_master(f"Finish init model")

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Synthetic data; the tokenizer is passed through but unused by RandomDataset.
    dataset = RandomDataset(num_samples=100, tokenizer=tokenizer)
    collate_fn = None
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
    )

    # Set optimizer
    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Set lr scheduler
    # Warmup defaults to 2.5% of total steps when --warmup_steps is not given.
    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * len(dataloader),
        warmup_steps=args.warmup_steps
        if args.warmup_steps is not None
        else int(args.num_epochs * len(dataloader) * 0.025),
        eta_min=0.1 * args.lr,
    )

    # Set booster
    booster = Booster(plugin=plugin)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )
    use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    coordinator.print_on_master(f"Finish init booster")

    # Load ckpt
    if args.load_checkpoint is not None:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master(f"Finish load optimizer")

    # Start finetuning
    coordinator.print_on_master(f"Start finetuning")
    for epoch in range(args.num_epoch):
        model.train()
        train_dataloader_iter = iter(dataloader)
        total_len = len(train_dataloader_iter)
        # With pipeline parallelism only the last stage sees the loss, so only
        # that stage shows the progress bar.
        with tqdm(
            range(total_len),
            desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
            disable=not coordinator.is_master() if use_pipeline == False else not is_pp_last_stage,
        ) as pbar:
            for step in pbar:
                if use_pipeline:
                    # Forward pass
                    outputs = booster.execute_pipeline(
                        train_dataloader_iter,
                        model,
                        lambda x, y: x.loss,
                        optimizer,
                        return_loss=True,
                        return_outputs=True,
                    )
                    # Backward and optimize
                    if is_pp_last_stage:
                        loss = outputs["loss"]
                        global_loss = get_global_loss(loss, booster)
                        if coordinator._local_rank == "0":
                            pbar.set_postfix({"Loss": global_loss.item()})
                else:
                    # Forward pass
                    data = next(train_dataloader_iter)
                    data = move_to_cuda(data, torch.cuda.current_device())
                    outputs = model(**data)
                    loss = outputs["loss"]
                    # Backward
                    booster.backward(loss, optimizer)
                    pbar.set_postfix({"loss": loss.item()})

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Apply load balance
                # if (
                #     args.load_balance
                #     and args.load_balance_interval > 0
                #     and (step + 1) % args.load_balance_interval == 0
                # ):
                #     coordinator.print_on_master(f"Apply load balance")
                #     apply_load_balance(model, optimizer)
                # save ckeckpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

    # save checkpoint at the end of each epochs
    booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")

    # Finish training
    coordinator.print_on_master(f"Finish training")


if __name__ == "__main__":
    main()
applications/ColossalMoE/train.sh
0 → 100644
View file @
efef43b5
# Launch hybrid-parallel (pp=2, ep=8, ZeRO-1) Mixtral finetuning across hosts
# listed in "hostfile".
NUM_GPU=8
MODEL="mistralai/Mixtral-8x7B-v0.1"
SEQ_LENGTH=2048
BATCH_SIZE=1
LR=0.00001

# hybrid
# torchrun --standalone --nproc_per_node $NUM_GPU \
# Note: removed the dangling trailing "\" line continuation after the last
# argument, which continued the command into nothing.
colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile" \
    train.py \
    --num_epoch 1 \
    --model_name $MODEL \
    --plugin "hybrid" \
    --batch_size $BATCH_SIZE \
    --lr $LR \
    --zero_stage 1 \
    --pp_size 2 \
    --dp_size 1 \
    --ep_size 8
applications/ColossalMoE/version.txt
0 → 100644
View file @
efef43b5
1.0.0
colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
View file @
efef43b5
...
...
@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from
colossalai.cluster
import
ProcessGroupMesh
from
colossalai.interface
import
ModelWrapper
,
OptimizerWrapper
from
colossalai.moe
import
MoECheckpintIO
from
colossalai.moe
import
MOE_MANAGER
,
MoECheckpintIO
from
colossalai.pipeline.schedule
import
OneForwardOneBackwardSchedule
from
colossalai.pipeline.stage_manager
import
PipelineStageManager
from
colossalai.shardformer
import
ShardConfig
...
...
@@ -150,6 +150,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self
,
tp_size
:
int
,
pp_size
:
int
,
ep_size
:
int
,
extra_dp_size
:
int
=
1
,
precision
:
str
=
"fp16"
,
zero_stage
:
int
=
0
,
...
...
@@ -181,6 +182,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_communication
:
bool
=
True
,
use_ep_inside
:
bool
=
True
,
custom_policy
:
Policy
=
None
,
checkpoint_io
:
Optional
[
MoECheckpintIO
]
=
None
,
)
->
None
:
assert
(
dist
.
get_world_size
()
%
(
tp_size
*
pp_size
)
==
0
...
...
@@ -188,10 +190,26 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
if
enable_sequence_parallelism
:
assert
tp_size
>
1
,
"Sequence parallelism must be enabled when using tensor parallelism"
assert
(
dist
.
get_world_size
()
%
(
tp_size
*
pp_size
)
==
0
),
f
"world size
{
dist
.
get_world_size
()
}
is not divisible by tp_size
{
tp_size
}
* pp_size
{
pp_size
}
"
assert
(
dist
.
get_world_size
()
%
(
tp_size
*
pp_size
*
ep_size
)
==
0
),
f
"world size
{
dist
.
get_world_size
()
}
is not divisible by tp_size
{
tp_size
}
* pp_size
{
pp_size
}
* ep_size
{
ep_size
}
"
self
.
real_dp_size
=
dist
.
get_world_size
()
//
(
tp_size
*
pp_size
*
ep_size
)
MOE_MANAGER
.
setup
(
parallel
=
"EP"
,
mode
=
"fixed"
,
fixed_dp_size
=
self
.
real_dp_size
,
fixed_ep_size
=
ep_size
,
fixed_pp_size
=
pp_size
,
use_ep_inside
=
use_ep_inside
,
)
self
.
tp_size
=
tp_size
self
.
pp_size
=
pp_size
self
.
dp_size
=
dist
.
get_world_size
()
//
(
tp_size
*
pp_size
)
self
.
ep_size
=
ep_size
self
.
moe_info
=
MOE_MANAGER
.
get_info
(
0
)[
1
]
self
.
precision
=
precision
self
.
zero_stage
=
zero_stage
self
.
cpu_offload
=
cpu_offload
...
...
@@ -200,6 +218,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self
.
enable_flash_attention
=
enable_flash_attention
self
.
enable_jit_fused
=
enable_jit_fused
self
.
enable_sequence_parallelism
=
enable_sequence_parallelism
self
.
checkpoint_io
=
checkpoint_io
# we change pg mesh to (pp, dp, tp) for better moe performance
self
.
pg_mesh
=
ProcessGroupMesh
(
self
.
pp_size
,
self
.
dp_size
,
self
.
tp_size
)
...
...
@@ -323,7 +342,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
)
def
get_checkpoint_io
(
self
)
->
MoECheckpintIO
:
self
.
checkpoint_io
=
MoECheckpintIO
(
self
.
dp_group
,
self
.
pp_group
,
self
.
tp_group
,
self
.
zero_stage
)
if
self
.
checkpoint_io
is
None
:
self
.
checkpoint_io
=
MoECheckpintIO
(
self
.
dp_group
,
self
.
pp_group
,
self
.
tp_group
,
self
.
zero_stage
)
else
:
self
.
checkpoint_io
=
self
.
checkpoint_io
(
self
.
dp_group
,
self
.
pp_group
,
self
.
tp_group
,
self
.
zero_stage
)
return
self
.
checkpoint_io
def
configure
(
...
...
colossalai/checkpoint_io/checkpoint_io_base.py
View file @
efef43b5
...
...
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from
colossalai.interface
import
ModelWrapper
from
.utils
import
has_index_file
from
.utils
import
SAFE_WEIGHTS_NAME
,
WEIGHTS_NAME
,
has_index_file
__all__
=
[
"CheckpointIO"
]
...
...
@@ -90,7 +90,15 @@ class CheckpointIO(ABC):
if
index_file_exists
:
self
.
load_sharded_model
(
model
,
index_file_path
,
strict
)
else
:
self
.
load_unsharded_model
(
model
,
checkpoint
,
strict
)
path
=
Path
(
checkpoint
,
SAFE_WEIGHTS_NAME
)
if
path
.
is_file
():
self
.
load_unsharded_model
(
model
,
str
(
path
),
strict
)
else
:
path
=
Path
(
checkpoint
,
WEIGHTS_NAME
)
if
path
.
is_file
():
self
.
load_unsharded_model
(
model
,
str
(
path
),
strict
)
else
:
self
.
load_unsharded_model
(
model
,
checkpoint
,
strict
)
return
origin_model
...
...
colossalai/moe/__init__.py
View file @
efef43b5
from
.checkpoint
import
MoECheckpintIO
from
.experts
import
MLPExperts
from
.layers
import
SparseMLP
from
.layers
import
SparseMLP
,
apply_load_balance
from
.manager
import
MOE_MANAGER
from
.routers
import
MoeRouter
,
Top1Router
,
Top2Router
,
TopKRouter
from
.utils
import
NormalNoiseGenerator
,
UniformNoiseGenerator
...
...
@@ -14,4 +15,6 @@ __all__ = [
"UniformNoiseGenerator"
,
"SparseMLP"
,
"MoECheckpintIO"
,
"MOE_MANAGER"
,
"apply_load_balance"
,
]
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment