Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
dceae851
Unverified
Commit
dceae851
authored
Jan 07, 2022
by
HELSON
Committed by
GitHub
Jan 07, 2022
Browse files
Added MoE parallel (#127)
parent
42741dd4
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
627 additions
and
16 deletions
+627
-16
colossalai/constants.py
colossalai/constants.py
+2
-1
colossalai/context/parallel_context.py
colossalai/context/parallel_context.py
+8
-0
colossalai/context/parallel_mode.py
colossalai/context/parallel_mode.py
+4
-0
colossalai/context/process_group_initializer/__init__.py
colossalai/context/process_group_initializer/__init__.py
+3
-1
colossalai/context/process_group_initializer/initializer_moe.py
...alai/context/process_group_initializer/initializer_moe.py
+97
-0
colossalai/context/random/__init__.py
colossalai/context/random/__init__.py
+3
-2
colossalai/context/random/_helper.py
colossalai/context/random/_helper.py
+15
-2
colossalai/context/random/seed_manager.py
colossalai/context/random/seed_manager.py
+6
-2
colossalai/engine/gradient_handler/__init__.py
colossalai/engine/gradient_handler/__init__.py
+3
-1
colossalai/engine/gradient_handler/_moe_gradient_handler.py
colossalai/engine/gradient_handler/_moe_gradient_handler.py
+61
-0
colossalai/engine/schedule/_base_schedule.py
colossalai/engine/schedule/_base_schedule.py
+3
-2
colossalai/global_variables.py
colossalai/global_variables.py
+36
-0
colossalai/initialize.py
colossalai/initialize.py
+15
-3
colossalai/nn/layer/moe/__init__.py
colossalai/nn/layer/moe/__init__.py
+8
-0
colossalai/nn/layer/moe/_operation.py
colossalai/nn/layer/moe/_operation.py
+29
-0
colossalai/nn/layer/moe/layers.py
colossalai/nn/layer/moe/layers.py
+242
-0
colossalai/nn/layer/vanilla/__init__.py
colossalai/nn/layer/vanilla/__init__.py
+4
-2
colossalai/nn/layer/vanilla/layers.py
colossalai/nn/layer/vanilla/layers.py
+53
-0
colossalai/nn/loss/__init__.py
colossalai/nn/loss/__init__.py
+1
-0
colossalai/nn/loss/loss_moe.py
colossalai/nn/loss/loss_moe.py
+34
-0
No files found.
colossalai/constants.py
View file @
dceae851
...
@@ -15,7 +15,8 @@ INITIALIZER_MAPPING = {
...
@@ -15,7 +15,8 @@ INITIALIZER_MAPPING = {
'2.5d'
:
'Initializer_2p5D'
,
'2.5d'
:
'Initializer_2p5D'
,
'3d'
:
'Initializer_3D'
,
'3d'
:
'Initializer_3D'
,
'sequence'
:
'Initializer_Sequence'
,
'sequence'
:
'Initializer_Sequence'
,
'model'
:
'Initializer_Model'
'model'
:
'Initializer_Model'
,
'moe'
:
'Initializer_Moe'
}
}
# 1D parallel
# 1D parallel
...
...
colossalai/context/parallel_context.py
View file @
dceae851
...
@@ -15,6 +15,7 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
...
@@ -15,6 +15,7 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
from
.parallel_mode
import
ParallelMode
from
.parallel_mode
import
ParallelMode
from
.random
import
add_seed
,
get_seeds
,
set_mode
from
.random
import
add_seed
,
get_seeds
,
set_mode
from
colossalai.global_variables
import
moe_env
class
ParallelContext
:
class
ParallelContext
:
...
@@ -412,6 +413,13 @@ class ParallelContext:
...
@@ -412,6 +413,13 @@ class ParallelContext:
# add this config to initialize later
# add this config to initialize later
pg_init
.
append
(
dict
(
type
=
INITIALIZER_MAPPING
[
tensor_parallel_mode
.
lower
()],
**
tensor_parallel_cfg
))
pg_init
.
append
(
dict
(
type
=
INITIALIZER_MAPPING
[
tensor_parallel_mode
.
lower
()],
**
tensor_parallel_cfg
))
# initialization for moe environment
if
parallel_config
is
not
None
and
'moe'
in
parallel_config
:
param
=
parallel_config
[
'moe'
]
assert
'size'
in
param
,
"Moe model parallel size should be given"
moe_env
.
setup
(
param
[
'size'
])
pg_init
.
append
(
dict
(
type
=
INITIALIZER_MAPPING
[
'moe'
]))
# run initialization of different process groups
# run initialization of different process groups
for
initializer_cfg
in
pg_init
:
for
initializer_cfg
in
pg_init
:
cfg
=
initializer_cfg
.
copy
()
cfg
=
initializer_cfg
.
copy
()
...
...
colossalai/context/parallel_mode.py
View file @
dceae851
...
@@ -44,3 +44,7 @@ class ParallelMode(Enum):
...
@@ -44,3 +44,7 @@ class ParallelMode(Enum):
PARALLEL_2P5D_COL
=
'2p5d_col'
PARALLEL_2P5D_COL
=
'2p5d_col'
PARALLEL_2P5D_DEP
=
'2p5d_dep'
PARALLEL_2P5D_DEP
=
'2p5d_dep'
PARALLEL_2P5D_XZ
=
'2p5d_xz'
PARALLEL_2P5D_XZ
=
'2p5d_xz'
# MOE parallel
MOE_DATA
=
'moe_data'
MOE_MODEL
=
'moe_model'
colossalai/context/process_group_initializer/__init__.py
View file @
dceae851
...
@@ -7,10 +7,12 @@ from .initializer_pipeline import Initializer_Pipeline
...
@@ -7,10 +7,12 @@ from .initializer_pipeline import Initializer_Pipeline
from
.initializer_sequence
import
Initializer_Sequence
from
.initializer_sequence
import
Initializer_Sequence
from
.initializer_tensor
import
Initializer_Tensor
from
.initializer_tensor
import
Initializer_Tensor
from
.initializer_model
import
Initializer_Model
from
.initializer_model
import
Initializer_Model
from
.initializer_moe
import
Initializer_Moe
from
.process_group_initializer
import
ProcessGroupInitializer
from
.process_group_initializer
import
ProcessGroupInitializer
__all__
=
[
__all__
=
[
'Initializer_Tensor'
,
'Initializer_Sequence'
,
'Initializer_Pipeline'
,
'Initializer_Tensor'
,
'Initializer_Sequence'
,
'Initializer_Pipeline'
,
'Initializer_Data'
,
'Initializer_2p5D'
,
'Initializer_2D'
,
'Initializer_3D'
,
'Initializer_Data'
,
'Initializer_2p5D'
,
'Initializer_2D'
,
'Initializer_3D'
,
'Initializer_1D'
,
'ProcessGroupInitializer'
,
'Initializer_Model'
'Initializer_1D'
,
'ProcessGroupInitializer'
,
'Initializer_Model'
,
'Initializer_Moe'
]
]
colossalai/context/process_group_initializer/initializer_moe.py
0 → 100644
View file @
dceae851
import
torch.distributed
as
dist
from
colossalai.registry
import
DIST_GROUP_INITIALIZER
from
colossalai.global_variables
import
moe_env
from
.process_group_initializer
import
ProcessGroupInitializer
from
..parallel_mode
import
ParallelMode
@
DIST_GROUP_INITIALIZER
.
register_module
class
Initializer_Moemodel
(
ProcessGroupInitializer
):
"""Model parallel initialization for MoE system.
"""
def
__init__
(
self
,
moe_model
,
moe_data
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
moe_model
=
moe_model
self
.
moe_data
=
moe_data
def
init_dist_group
(
self
):
"""Initialize model parallel groups in moe parallel environment,
and assign local_ranks and groups to each gpu.
"""
local_rank
=
None
ranks_in_group
=
None
process_group
=
None
group_world_size
=
None
mode
=
ParallelMode
.
MOE_MODEL
for
i
in
range
(
self
.
moe_data
):
ranks
=
[
i
*
self
.
moe_model
+
j
for
j
in
range
(
self
.
moe_model
)]
group
=
dist
.
new_group
(
ranks
)
if
self
.
rank
in
ranks
:
local_rank
=
ranks
.
index
(
self
.
rank
)
group_world_size
=
len
(
ranks
)
process_group
=
group
ranks_in_group
=
ranks
return
local_rank
,
group_world_size
,
process_group
,
ranks_in_group
,
mode
@
DIST_GROUP_INITIALIZER
.
register_module
class
Initializer_Moedata
(
ProcessGroupInitializer
):
"""Data parallel initialization for MoE system.
"""
def
__init__
(
self
,
moe_model
,
moe_data
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
moe_model
=
moe_model
self
.
moe_data
=
moe_data
def
init_dist_group
(
self
):
"""Initialize data parallel groups in moe parallel environment,
and assign local_ranks and groups to each gpu.
"""
local_rank
=
None
ranks_in_group
=
None
process_group
=
None
group_world_size
=
None
mode
=
ParallelMode
.
MOE_DATA
for
i
in
range
(
self
.
moe_model
):
ranks
=
[
i
+
j
*
self
.
moe_model
for
j
in
range
(
self
.
moe_data
)]
group
=
dist
.
new_group
(
ranks
)
if
self
.
rank
in
ranks
:
local_rank
=
ranks
.
index
(
self
.
rank
)
group_world_size
=
len
(
ranks
)
process_group
=
group
ranks_in_group
=
ranks
return
local_rank
,
group_world_size
,
process_group
,
ranks_in_group
,
mode
@
DIST_GROUP_INITIALIZER
.
register_module
class
Initializer_Moe
(
ProcessGroupInitializer
):
"""Serves as the single entry point to MoE parallel initialization.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
moe_model
=
moe_env
.
model_parallel_size
self
.
moe_data
=
moe_env
.
data_parallel_size
self
.
model_initializer
=
Initializer_Moemodel
(
self
.
moe_model
,
self
.
moe_data
,
*
args
,
**
kwargs
)
self
.
data_initializer
=
Initializer_Moedata
(
self
.
moe_model
,
self
.
moe_data
,
*
args
,
**
kwargs
)
def
init_dist_group
(
self
):
"""Initializes MoE parallel communication groups.
"""
parallel_setting
=
[
self
.
model_initializer
.
init_dist_group
(),
self
.
data_initializer
.
init_dist_group
()]
return
parallel_setting
colossalai/context/random/__init__.py
View file @
dceae851
from
._helper
import
(
seed
,
set_mode
,
with_seed
,
add_seed
,
from
._helper
import
(
seed
,
set_mode
,
with_seed
,
add_seed
,
get_seeds
,
get_states
,
get_current_mode
,
get_seeds
,
get_states
,
get_current_mode
,
set_seed_states
,
sync_states
)
set_seed_states
,
sync_states
,
moe_set_seed
)
__all__
=
[
__all__
=
[
'seed'
,
'set_mode'
,
'with_seed'
,
'add_seed'
,
'get_seeds'
,
'seed'
,
'set_mode'
,
'with_seed'
,
'add_seed'
,
'get_seeds'
,
'get_states'
,
'get_current_mode'
,
'set_seed_states'
,
'sync_states'
'get_states'
,
'get_current_mode'
,
'set_seed_states'
,
'sync_states'
,
'moe_set_seed'
]
]
colossalai/context/random/_helper.py
View file @
dceae851
...
@@ -49,7 +49,7 @@ def get_current_mode():
...
@@ -49,7 +49,7 @@ def get_current_mode():
return
_SEED_MANAGER
.
current_mode
return
_SEED_MANAGER
.
current_mode
def
add_seed
(
parallel_mode
:
ParallelMode
,
seed
:
int
):
def
add_seed
(
parallel_mode
:
ParallelMode
,
seed
:
int
,
overwrite
:
bool
=
False
):
"""Adds a seed to the seed manager for `parallel_mode`.
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:param parallel_mode: The chosen parallel mode
...
@@ -59,7 +59,7 @@ def add_seed(parallel_mode: ParallelMode, seed: int):
...
@@ -59,7 +59,7 @@ def add_seed(parallel_mode: ParallelMode, seed: int):
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
"""
"""
_SEED_MANAGER
.
add_seed
(
parallel_mode
,
seed
)
_SEED_MANAGER
.
add_seed
(
parallel_mode
,
seed
,
overwrite
)
def
set_mode
(
parallel_mode
:
ParallelMode
):
def
set_mode
(
parallel_mode
:
ParallelMode
):
...
@@ -142,3 +142,16 @@ def with_seed(func, parallel_mode: ParallelMode):
...
@@ -142,3 +142,16 @@ def with_seed(func, parallel_mode: ParallelMode):
return
out
return
out
return
wrapper
return
wrapper
def
moe_set_seed
(
seed
):
if
torch
.
cuda
.
is_available
():
from
colossalai.core
import
global_context
as
gpc
moe_mp_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
MOE_MODEL
)
moe_mp_seed
=
seed
+
moe_mp_rank
add_seed
(
ParallelMode
.
MOE_MODEL
,
moe_mp_seed
)
global_rank
=
gpc
.
get_global_rank
()
add_seed
(
ParallelMode
.
TENSOR
,
global_rank
,
True
)
print
(
f
"moe seed condition:
{
global_rank
}
with moe seed
{
moe_mp_seed
}
, "
,
f
"tensor seed
{
global_rank
}
"
,
flush
=
True
)
colossalai/context/random/seed_manager.py
View file @
dceae851
...
@@ -54,7 +54,7 @@ class SeedManager:
...
@@ -54,7 +54,7 @@ class SeedManager:
self
.
_current_mode
=
parallel_mode
self
.
_current_mode
=
parallel_mode
torch
.
cuda
.
set_rng_state
(
self
.
_seed_states
[
parallel_mode
])
torch
.
cuda
.
set_rng_state
(
self
.
_seed_states
[
parallel_mode
])
def
add_seed
(
self
,
parallel_mode
:
ParallelMode
,
seed
:
int
):
def
add_seed
(
self
,
parallel_mode
:
ParallelMode
,
seed
:
int
,
overwrtie
:
bool
=
False
):
"""Adds a seed to the seed manager for `parallel_mode`.
"""Adds a seed to the seed manager for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:param parallel_mode: The chosen parallel mode
...
@@ -66,7 +66,11 @@ class SeedManager:
...
@@ -66,7 +66,11 @@ class SeedManager:
"""
"""
assert
isinstance
(
assert
isinstance
(
parallel_mode
,
ParallelMode
),
'A valid ParallelMode must be provided'
parallel_mode
,
ParallelMode
),
'A valid ParallelMode must be provided'
if
overwrtie
is
False
:
assert
parallel_mode
not
in
self
.
_seed_states
,
f
'The seed for
{
parallel_mode
}
has been added'
assert
parallel_mode
not
in
self
.
_seed_states
,
f
'The seed for
{
parallel_mode
}
has been added'
elif
parallel_mode
in
self
.
_seed_states
:
print
(
f
"Warnning:
{
parallel_mode
}
seed has been overwritten."
,
flush
=
True
)
current_state
=
torch
.
cuda
.
get_rng_state
()
current_state
=
torch
.
cuda
.
get_rng_state
()
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
self
.
_seed_states
[
parallel_mode
]
=
torch
.
cuda
.
get_rng_state
()
self
.
_seed_states
[
parallel_mode
]
=
torch
.
cuda
.
get_rng_state
()
...
...
colossalai/engine/gradient_handler/__init__.py
View file @
dceae851
...
@@ -2,6 +2,8 @@ from ._base_gradient_handler import BaseGradientHandler
...
@@ -2,6 +2,8 @@ from ._base_gradient_handler import BaseGradientHandler
from
._data_parallel_gradient_handler
import
DataParallelGradientHandler
from
._data_parallel_gradient_handler
import
DataParallelGradientHandler
from
._zero_gradient_handler
import
ZeROGradientHandler
from
._zero_gradient_handler
import
ZeROGradientHandler
from
._pipeline_parallel_gradient_handler
import
PipelineSharedModuleGradientHandler
from
._pipeline_parallel_gradient_handler
import
PipelineSharedModuleGradientHandler
from
._moe_gradient_handler
import
MoeGradientHandler
__all__
=
[
'BaseGradientHandler'
,
'DataParallelGradientHandler'
,
__all__
=
[
'BaseGradientHandler'
,
'DataParallelGradientHandler'
,
'ZeROGradientHandler'
,
'PipelineSharedModuleGradientHandler'
]
'ZeROGradientHandler'
,
'PipelineSharedModuleGradientHandler'
,
'MoeGradientHandler'
]
colossalai/engine/gradient_handler/_moe_gradient_handler.py
0 → 100644
View file @
dceae851
import
torch.distributed
as
dist
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
GRADIENT_HANDLER
from
colossalai.global_variables
import
moe_env
from
._base_gradient_handler
import
BaseGradientHandler
from
...context.parallel_mode
import
ParallelMode
@
GRADIENT_HANDLER
.
register_module
class
MoeGradientHandler
(
BaseGradientHandler
):
"""A helper class to handle all-reduce operations in a data parallel group and
moe tensor parallel. A all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
"""
def
handle_gradient
(
self
):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
across moe tensor parallel group
"""
moe_data
=
moe_env
.
data_parallel_size
global_data
=
gpc
.
data_parallel_size
if
global_data
>
1
:
# bucketize and all-reduce
buckets
=
{}
# Pack the buckets.
for
param
in
self
.
_model
.
parameters
():
if
param
.
requires_grad
and
\
param
.
grad
is
not
None
and
\
not
hasattr
(
param
,
'moe_param'
):
tp
=
param
.
data
.
type
()
if
tp
not
in
buckets
:
buckets
[
tp
]
=
[]
buckets
[
tp
].
append
(
param
)
# param.main_grad = param.grad
# For each bucket, all-reduce and copy all-reduced grads.
for
tp
in
buckets
:
bucket
=
buckets
[
tp
]
grads
=
[
param
.
grad
.
data
for
param
in
bucket
]
coalesced
=
_flatten_dense_tensors
(
grads
)
coalesced
/=
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
dist
.
all_reduce
(
coalesced
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
for
buf
,
synced
in
zip
(
grads
,
_unflatten_dense_tensors
(
coalesced
,
grads
)):
buf
.
copy_
(
synced
)
if
global_data
>
1
:
for
param
in
self
.
_model
.
parameters
():
if
not
param
.
requires_grad
or
param
.
grad
is
None
:
continue
if
moe_data
>
1
and
hasattr
(
param
,
'moe_param'
):
param
.
grad
.
data
/=
moe_data
dist
.
all_reduce
(
param
.
grad
.
data
,
group
=
gpc
.
get_group
(
ParallelMode
.
MOE_DATA
))
colossalai/engine/schedule/_base_schedule.py
View file @
dceae851
...
@@ -38,8 +38,9 @@ class BaseSchedule(ABC):
...
@@ -38,8 +38,9 @@ class BaseSchedule(ABC):
return
data
return
data
@
staticmethod
@
staticmethod
def
_check_sanity
(
data
,
tag
):
def
_check_sanity
(
data
,
tag
:
str
):
assert
isinstance
(
data
,
(
torch
.
Tensor
,
dict
)),
f
'
{
tag
}
must be torch.Tensor or dict'
assert
isinstance
(
data
,
(
torch
.
Tensor
,
dict
)),
\
f
'
{
tag
}
must be torch.Tensor or dict'
def
load_batch
(
self
,
data_iter
,
to_gpu
=
True
):
def
load_batch
(
self
,
data_iter
,
to_gpu
=
True
):
"""Loads a batch from data iterator. It returns the data and labels which are
"""Loads a batch from data iterator. It returns the data and labels which are
...
...
colossalai/global_variables.py
0 → 100644
View file @
dceae851
class
MoeEnv
:
"""Moe enviroment variable.
"""
def
__init__
(
self
):
self
.
data_parallel_size
=
None
self
.
model_parallel_size
=
None
self
.
aux_loss
=
None
def
setup
(
self
,
moe_model_size
):
from
.core
import
global_context
as
gpc
if
gpc
.
tensor_parallel_size
>
1
or
gpc
.
pipeline_parallel_size
>
1
:
raise
NotImplementedError
(
"Moe is not compatible with tensor or pipeline parallel"
)
assert
gpc
.
data_parallel_size
%
moe_model_size
==
0
,
\
"The size of data parallel needs to be divided by moe model parallel size"
self
.
data_parallel_size
=
gpc
.
data_parallel_size
//
moe_model_size
self
.
model_parallel_size
=
moe_model_size
def
is_initialized
(
self
):
return
self
.
model_parallel_size
is
not
None
def
reset_loss
(
self
):
self
.
aux_loss
=
0
def
add_loss
(
self
,
loss
):
self
.
aux_loss
+=
loss
def
get_loss
(
self
):
return
self
.
aux_loss
moe_env
=
MoeEnv
()
colossalai/initialize.py
View file @
dceae851
...
@@ -5,7 +5,6 @@ import argparse
...
@@ -5,7 +5,6 @@ import argparse
import
pprint
import
pprint
import
os
import
os
from
colossalai.nn.optimizer.colossalai_optimizer
import
ColossalaiOptimizer
from
colossalai.nn.optimizer.colossalai_optimizer
import
ColossalaiOptimizer
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -26,6 +25,7 @@ from torch.optim.lr_scheduler import _LRScheduler
...
@@ -26,6 +25,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torch.nn.modules.loss
import
_Loss
from
torch.nn.modules.loss
import
_Loss
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
colossalai.global_variables
import
moe_env
def
get_default_parser
():
def
get_default_parser
():
...
@@ -224,7 +224,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
...
@@ -224,7 +224,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
test_dataloader
:
Optional
[
Union
[
Iterable
,
List
[
Iterable
]]]
=
None
,
test_dataloader
:
Optional
[
Union
[
Iterable
,
List
[
Iterable
]]]
=
None
,
lr_scheduler
:
_LRScheduler
=
None
,
lr_scheduler
:
_LRScheduler
=
None
,
verbose
:
bool
=
True
verbose
:
bool
=
True
)
->
Tuple
[
Engine
,
DataLoader
,
DataLoader
]:
)
->
Tuple
[
Engine
,
DataLoader
,
DataLoader
,
_LRScheduler
]:
''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
:param model: your model instance
:param model: your model instance
...
@@ -269,8 +269,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
...
@@ -269,8 +269,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# first sync model across dp ranks
# first sync model across dp ranks
model
.
to
(
get_current_device
())
model
.
to
(
get_current_device
())
use_zero3
=
hasattr
(
gpc
.
config
,
'zero'
)
and
gpc
.
config
.
zero
.
level
==
3
use_zero3
=
hasattr
(
gpc
.
config
,
'zero'
)
and
gpc
.
config
.
zero
.
level
==
3
if
not
use_zero3
:
if
not
moe_env
.
is_initialized
()
and
not
use_zero3
:
sync_model_param_in_dp
(
model
)
sync_model_param_in_dp
(
model
)
else
:
print
(
"Warning: The parameters of models is not automatically synchronized.
\n
"
"Please make sure that all parameters are the same in data parallel group."
,
flush
=
True
)
# check amp and zero
# check amp and zero
fp16_cfg
=
gpc
.
config
.
get
(
'fp16'
,
None
)
fp16_cfg
=
gpc
.
config
.
get
(
'fp16'
,
None
)
...
@@ -327,6 +332,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
...
@@ -327,6 +332,13 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
"Training with zero is detected, ZeROGradientHandler is automatically "
"Training with zero is detected, ZeROGradientHandler is automatically "
"added even though not specified in the configuration"
,
"added even though not specified in the configuration"
,
ranks
=
[
0
])
ranks
=
[
0
])
elif
is_using_ddp
()
and
moe_env
.
is_initialized
():
gradient_handler_cfg
=
[
dict
(
type
=
'MoeGradientHandler'
)]
if
verbose
:
logger
.
info
(
"Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
"added even though not specified in the configuration"
,
ranks
=
[
0
])
elif
is_using_ddp
()
and
not
is_using_pp
()
and
amp_mode
!=
AMP_TYPE
.
NAIVE
:
elif
is_using_ddp
()
and
not
is_using_pp
()
and
amp_mode
!=
AMP_TYPE
.
NAIVE
:
model
=
DDP
(
model
,
process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
model
=
DDP
(
model
,
process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
if
verbose
:
if
verbose
:
...
...
colossalai/nn/layer/moe/__init__.py
0 → 100644
View file @
dceae851
from
._operation
import
AllToAll
from
.layers
import
Experts
,
MoeLayer
,
\
NormalNoiseGenerator
,
Top1Router
,
Top2Router
__all__
=
[
'AllToAll'
,
'Experts'
,
'Top1Router'
,
'Top2Router'
,
'MoeLayer'
,
'NormalNoiseGenerator'
]
\ No newline at end of file
colossalai/nn/layer/moe/_operation.py
0 → 100644
View file @
dceae851
import
torch
import
torch.distributed
as
dist
from
torch
import
Tensor
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
typing
import
Any
,
Tuple
class
AllToAll
(
torch
.
autograd
.
Function
):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@
staticmethod
def
forward
(
ctx
:
Any
,
inputs
:
Tensor
,
parallel_mode
:
ParallelMode
)
->
Tensor
:
ctx
.
parallel_mode
=
parallel_mode
if
not
inputs
.
is_contiguous
():
inputs
=
inputs
.
contiguous
()
output
=
torch
.
empty_like
(
inputs
)
dist
.
all_to_all_single
(
output
,
inputs
,
group
=
gpc
.
get_group
(
parallel_mode
))
return
output
@
staticmethod
def
backward
(
ctx
:
Any
,
*
grad_outputs
:
Tensor
)
->
Tuple
[
Tensor
,
None
]:
return
AllToAll
.
apply
(
*
grad_outputs
,
ctx
.
parallel_mode
),
None
colossalai/nn/layer/moe/layers.py
0 → 100644
View file @
dceae851
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.cuda.amp
import
autocast
from
colossalai.global_variables
import
moe_env
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.utils
import
get_current_device
from
._operation
import
AllToAll
class
NormalNoiseGenerator
:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
"""
def
__init__
(
self
,
num_experts
:
int
):
self
.
normal
=
torch
.
distributions
.
normal
.
Normal
(
loc
=
torch
.
tensor
(
0.0
,
device
=
get_current_device
()),
scale
=
torch
.
tensor
(
1.0
/
num_experts
**
2
,
device
=
get_current_device
())
).
rsample
def
__call__
(
self
,
inputs
:
torch
.
Tensor
):
noisy
=
self
.
normal
(
inputs
.
shape
)
return
inputs
+
noisy
class
Experts
(
nn
.
Module
):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instence of the class, 'expert' in initialization parameters.
"""
def
__init__
(
self
,
expert
,
num_experts
,
**
expert_args
):
super
().
__init__
()
assert
num_experts
%
moe_env
.
model_parallel_size
==
0
,
\
"The number of experts should be divied by moe model size"
num_local_experts
=
num_experts
//
moe_env
.
model_parallel_size
with
seed
(
ParallelMode
.
MOE_MODEL
):
self
.
experts
=
nn
.
ModuleList
([
expert
(
**
expert_args
)
for
_
in
range
(
num_local_experts
)])
self
.
num_local_experts
=
num_local_experts
for
exp
in
self
.
experts
:
for
param
in
exp
.
parameters
():
param
.
__setattr__
(
'moe_param'
,
1
)
def
forward
(
self
,
inputs
):
expert_input
=
torch
.
chunk
(
inputs
,
self
.
num_local_experts
,
dim
=
0
)
expert_output
=
[]
for
i
in
range
(
self
.
num_local_experts
):
expert_output
.
append
(
self
.
experts
[
i
](
expert_input
[
i
]))
output
=
torch
.
cat
(
expert_output
,
dim
=
0
)
return
output
class
Top1Router
(
nn
.
Module
):
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More deailted function can be found in the paper about Switch Transformer
of Google.
"""
def
__init__
(
self
,
capacity_factor
:
float
,
min_capacity
:
int
,
noisy_func
=
None
):
super
().
__init__
()
self
.
capacity_factor
=
capacity_factor
self
.
min_capacity
=
min_capacity
self
.
noisy_func
=
noisy_func
self
.
uniform
=
torch
.
distributions
.
uniform
.
Uniform
(
low
=
torch
.
tensor
(
0.0
,
device
=
get_current_device
()),
high
=
torch
.
tensor
(
1.0
,
device
=
get_current_device
())).
rsample
def
get_capacity
(
self
,
logits_shape
):
capacity
=
math
.
ceil
(
self
.
capacity_factor
*
logits_shape
[
0
]
/
logits_shape
[
1
])
if
capacity
<
self
.
min_capacity
:
capacity
=
self
.
min_capacity
return
capacity
def
forward
(
self
,
inputs
):
if
self
.
noisy_func
is
not
None
:
inputs_noisy
=
self
.
noisy_func
(
inputs
)
else
:
inputs_noisy
=
inputs
logits
=
F
.
softmax
(
inputs
,
dim
=
1
)
num_experts
=
logits
.
shape
[
1
]
capacity
=
self
.
get_capacity
(
logits
.
shape
)
expert_idx
=
torch
.
argmax
(
inputs_noisy
,
dim
=
1
)
expert_mask
=
F
.
one_hot
(
expert_idx
,
num_classes
=
num_experts
)
expert_mask_f
=
expert_mask
.
float
()
exp_counts
=
torch
.
sum
(
expert_mask
,
dim
=
0
).
detach
().
to
(
'cpu'
)
me
=
torch
.
mean
(
logits
,
dim
=
0
)
ce
=
torch
.
mean
(
expert_mask_f
,
dim
=
0
)
l_aux
=
torch
.
sum
(
me
*
ce
)
*
num_experts
moe_env
.
add_loss
(
l_aux
)
rand_mask
=
expert_mask
*
self
.
uniform
(
logits
.
shape
)
_
,
dispatch_idx
=
torch
.
topk
(
rand_mask
,
k
=
capacity
,
dim
=
0
)
dispatch_mask
=
\
expert_mask
*
torch
.
zeros_like
(
expert_mask
).
scatter_
(
0
,
dispatch_idx
,
1
)
locations
=
torch
.
cumsum
(
dispatch_mask
,
dim
=
0
)
-
1
locations
=
torch
.
sum
(
dispatch_mask
*
locations
,
dim
=
1
)
locations
=
F
.
one_hot
(
locations
,
num_classes
=
capacity
)
logits
=
logits
*
dispatch_mask
combine_weights
=
logits
.
unsqueeze
(
2
)
*
locations
.
unsqueeze
(
1
)
sec_mask
=
combine_weights
.
bool
()
return
combine_weights
,
sec_mask
,
exp_counts
class
Top2Router
(
nn
.
Module
):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More deailted function can be found in the paper about ViT-MoE.
"""
def
__init__
(
self
,
capacity_factor
:
float
,
noisy_func
=
None
):
super
().
__init__
()
self
.
capacity_factor
=
capacity_factor
self
.
noisy_func
=
noisy_func
def
get_capacity
(
self
,
logits_shape
):
capacity
=
math
.
ceil
(
2
*
self
.
capacity_factor
*
logits_shape
[
0
]
/
logits_shape
[
1
])
return
capacity
def
forward
(
self
,
inputs
):
if
self
.
noisy_func
is
not
None
:
inputs
=
self
.
noisy_func
(
inputs
)
logits
=
F
.
softmax
(
inputs
,
dim
=-
1
)
num_experts
=
logits
.
size
(
-
1
)
capacity
=
self
.
get_capacity
(
logits
.
shape
)
_
,
expert_idx
=
torch
.
topk
(
logits
,
k
=
2
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)
top1_idx
=
expert_idx
[:,
0
]
top2_idx
=
expert_idx
[:,
1
]
mask1
=
F
.
one_hot
(
top1_idx
,
num_classes
=
num_experts
)
mask2
=
F
.
one_hot
(
top2_idx
,
num_classes
=
num_experts
)
loss_mask
=
(
mask1
+
mask2
)
exp_counts
=
torch
.
sum
(
loss_mask
,
dim
=
0
).
detach
().
to
(
'cpu'
)
me
=
torch
.
mean
(
logits
,
dim
=
0
)
ce
=
torch
.
mean
(
loss_mask
.
float
(),
dim
=
0
)
l_aux
=
num_experts
*
torch
.
sum
(
me
*
ce
)
/
2.0
moe_env
.
add_loss
(
l_aux
)
locations1
=
torch
.
cumsum
(
mask1
,
dim
=
0
)
-
1
locations2
=
torch
.
cumsum
(
mask2
,
dim
=
0
)
-
1
locations2
+=
torch
.
sum
(
mask1
,
dim
=
0
,
keepdim
=
True
)
mask1
*=
torch
.
lt
(
locations1
,
capacity
)
mask2
*=
torch
.
lt
(
locations2
,
capacity
)
weight1
=
mask1
*
logits
weight2
=
mask2
*
logits
locations1
=
torch
.
sum
(
mask1
*
locations1
,
dim
=
1
)
locations2
=
torch
.
sum
(
mask2
*
locations2
,
dim
=
1
)
locations1_sc
=
F
.
one_hot
(
locations1
,
num_classes
=
capacity
)
locations2_sc
=
F
.
one_hot
(
locations2
,
num_classes
=
capacity
)
combine_weights1
=
weight1
.
unsqueeze
(
2
)
*
locations1_sc
.
unsqueeze
(
1
)
combine_weights2
=
weight2
.
unsqueeze
(
2
)
*
locations2_sc
.
unsqueeze
(
1
)
combine_weights
=
combine_weights1
+
combine_weights2
sec_mask
=
combine_weights
.
bool
()
return
combine_weights
,
sec_mask
,
exp_counts
class
MoeLayer
(
nn
.
Module
):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits
to router all tokens, is mainly used to exchange all tokens for every expert across
the moe tensor group by all to all comunication. Then it will get the output of all
experts and exchange the output. At last returns the output of the moe system.
"""
def
__init__
(
self
,
dim_model
:
int
,
num_experts
:
int
,
router
:
nn
.
Module
,
experts
:
nn
.
Module
):
super
().
__init__
()
self
.
d_model
=
dim_model
self
.
num_experts
=
num_experts
self
.
gate
=
nn
.
Linear
(
dim_model
,
num_experts
,
device
=
get_current_device
())
self
.
router
=
router
self
.
experts
=
experts
def
_router_part
(
self
,
tokens
:
torch
.
Tensor
):
gate_output
=
self
.
gate
(
tokens
)
return
self
.
router
(
gate_output
)
def
router_part
(
self
,
tokens
:
torch
.
Tensor
):
autocast_context
=
torch
.
is_autocast_enabled
()
if
not
autocast_context
:
return
self
.
_router_part
(
tokens
)
else
:
with
autocast
(
enabled
=
False
):
if
tokens
.
dtype
==
torch
.
float16
:
input_tokens
=
tokens
.
float
()
else
:
input_tokens
=
tokens
return
self
.
_router_part
(
input_tokens
)
def
forward
(
self
,
inputs
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tokens
=
inputs
.
reshape
(
-
1
,
self
.
d_model
)
combine_weights
,
sec_mask
,
exp_counts
=
self
.
router_part
(
tokens
)
sec_mask_f
=
sec_mask
.
type_as
(
inputs
)
dispatch_data
=
torch
.
matmul
(
sec_mask_f
.
permute
(
1
,
2
,
0
),
tokens
)
dispatch_data
=
AllToAll
.
apply
(
dispatch_data
,
ParallelMode
.
MOE_MODEL
)
expert_output
=
self
.
experts
(
dispatch_data
)
expert_output
=
AllToAll
.
apply
(
expert_output
,
ParallelMode
.
MOE_MODEL
)
combine_weights
=
combine_weights
.
view
(
combine_weights
.
shape
[
0
],
-
1
)
expert_output
=
expert_output
.
view
(
-
1
,
expert_output
.
shape
[
-
1
])
ret
=
torch
.
matmul
(
combine_weights
,
expert_output
)
ret
=
ret
.
reshape
(
inputs
.
shape
)
return
ret
colossalai/nn/layer/vanilla/__init__.py
View file @
dceae851
from
.layers
import
DropPath
,
VanillaClassifier
,
VanillaPatchEmbedding
from
.layers
import
DropPath
,
VanillaClassifier
,
VanillaPatchEmbedding
,
\
WrappedDropout
,
WrappedDropPath
__all__
=
[
'VanillaPatchEmbedding'
,
'VanillaClassifier'
,
'DropPath'
]
__all__
=
[
'VanillaPatchEmbedding'
,
'VanillaClassifier'
,
'DropPath'
,
'WrappedDropout'
,
'WrappedDropPath'
]
colossalai/nn/layer/vanilla/layers.py
View file @
dceae851
...
@@ -10,6 +10,7 @@ from torch import Tensor, dtype
...
@@ -10,6 +10,7 @@ from torch import Tensor, dtype
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
from
..utils
import
to_2tuple
from
..utils
import
to_2tuple
from
colossalai.context
import
seed
def
drop_path
(
x
,
drop_prob
:
float
=
0.
,
training
:
bool
=
False
):
def
drop_path
(
x
,
drop_prob
:
float
=
0.
,
training
:
bool
=
False
):
...
@@ -42,6 +43,58 @@ class DropPath(nn.Module):
...
@@ -42,6 +43,58 @@ class DropPath(nn.Module):
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
class WrappedDropout(nn.Module):
    """torch.nn.Dropout executed inside the parallel seed manager's RNG
    context.

    :param p: drop probability, in [0, 1].
    :param inplace: forwarded to ``F.dropout``.
    :param mode: parallel-seed mode; when ``None`` the ambient RNG is used.
    :raises ValueError: if ``p`` is outside [0, 1].
    """

    def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p = p
        self.inplace = inplace
        # Pick the dispatch target once, at construction time.
        self.func = self.nonefunc if mode is None else self.normalfunc
        self.mode = mode

    def nonefunc(self, inputs):
        # Plain dropout using the ambient RNG state.
        return F.dropout(inputs, self.p, self.training, self.inplace)

    def normalfunc(self, inputs):
        # Dropout under the seed manager's RNG state for self.mode.
        with seed(self.mode):
            return F.dropout(inputs, self.p, self.training, self.inplace)

    def forward(self, inputs):
        return self.func(inputs)
class WrappedDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path
    of residual blocks), wrapped with the context of the seed manager.

    :param p: probability of dropping the residual path, defaults to 0.
    :param mode: parallel-seed mode; when ``None`` the ambient RNG is used.
    """

    def __init__(self, p: float = 0., mode=None):
        super().__init__()
        self.p = p
        # fix: the original assigned self.mode = mode twice; keep one.
        self.mode = mode
        if self.mode is None:
            self.func = self.nonefunc
        else:
            self.func = self.normalfunc

    def nonefunc(self, inputs):
        # Stochastic depth using the ambient RNG state.
        return drop_path(inputs, self.p, self.training)

    def normalfunc(self, inputs):
        # Stochastic depth under the seed manager's RNG for self.mode.
        with seed(self.mode):
            return drop_path(inputs, self.p, self.training)

    def forward(self, inputs):
        return self.func(inputs)
@
LAYERS
.
register_module
@
LAYERS
.
register_module
class
VanillaPatchEmbedding
(
nn
.
Module
):
class
VanillaPatchEmbedding
(
nn
.
Module
):
""" 2D Image to Patch Embedding
""" 2D Image to Patch Embedding
...
...
colossalai/nn/loss/__init__.py
View file @
dceae851
...
@@ -6,6 +6,7 @@ from colossalai.nn.layer.utils import get_tensor_parallel_mode
...
@@ -6,6 +6,7 @@ from colossalai.nn.layer.utils import get_tensor_parallel_mode
from
.loss_2d
import
CrossEntropyLoss2D
from
.loss_2d
import
CrossEntropyLoss2D
from
.loss_2p5d
import
CrossEntropyLoss2p5D
from
.loss_2p5d
import
CrossEntropyLoss2p5D
from
.loss_3d
import
CrossEntropyLoss3D
from
.loss_3d
import
CrossEntropyLoss3D
from
.loss_moe
import
MoeCrossEntropyLoss
,
MoeLoss
_parallel_cross_entropy
=
{
_parallel_cross_entropy
=
{
'2d'
:
CrossEntropyLoss2D
,
'2d'
:
CrossEntropyLoss2D
,
...
...
colossalai/nn/loss/loss_moe.py
0 → 100644
View file @
dceae851
import
torch.nn
as
nn
from
colossalai.registry
import
LOSSES
from
torch.nn.modules.loss
import
_Loss
from
colossalai.global_variables
import
moe_env
@LOSSES.register_module
class MoeCrossEntropyLoss(_Loss):
    """Cross-entropy loss augmented with the MoE auxiliary loss.

    :param aux_weight: weight of the auxiliary loss, defaults to 0.01.
    Remaining positional/keyword arguments are forwarded to
    ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
        super().__init__()
        self.loss = nn.CrossEntropyLoss(*args, **kwargs)
        self.aux_weight = aux_weight

    def forward(self, *args):
        # Main task loss first, then the balancing loss the MoE router
        # accumulated into the global environment.
        main_loss = self.loss(*args)
        return main_loss + self.aux_weight * moe_env.get_loss()
@LOSSES.register_module
class MoeLoss(_Loss):
    """Wrap an arbitrary loss module and add the MoE auxiliary loss to it.

    :param aux_weight: weight of the auxiliary loss.
    :param loss_fn: loss-module class, instantiated with the remaining
        positional/keyword arguments.
    """

    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
        super().__init__()
        self.loss_fn = loss_fn(*args, **kwargs)
        self.aux_weight = aux_weight

    def forward(self, *args, **kwargs):
        # Main task loss first, then the balancing loss the MoE router
        # accumulated into the global environment.
        main_loss = self.loss_fn(*args, **kwargs)
        return main_loss + self.aux_weight * moe_env.get_loss()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment