Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1ff7d5bf
Commit
1ff7d5bf
authored
Mar 27, 2023
by
LuGY
Committed by
binmakeswell
Mar 29, 2023
Browse files
[NFC] polish colossalai/engine/gradient_handler/_moe_gradient_handler.py (#3260)
parent
204ca2f0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
45 deletions
+46
-45
colossalai/engine/gradient_handler/_moe_gradient_handler.py
colossalai/engine/gradient_handler/_moe_gradient_handler.py
+46
-45
No files found.
colossalai/engine/gradient_handler/_moe_gradient_handler.py
View file @
1ff7d5bf
# NOTE(review): the scraped diff showed every import twice (old/new diff
# columns); this is the deduplicated set, grouped absolute-then-relative.
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict

from ...context.parallel_mode import ParallelMode
from ._base_gradient_handler import BaseGradientHandler
from .utils import bucket_allreduce
@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group and
    moe model parallel. A all-reduce collective communication will be operated in
    :func:`handle_gradient` among a data parallel group.
    For better performance, it bucketizes the gradients of all parameters that are
    the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer=None):
        super().__init__(model, optimizer)

    def handle_gradient(self):
        """A method running an all-reduce operation in a data parallel group.
        Then running an all-reduce operation for all parameters in experts
        across moe model parallel group
        """
        global_data = gpc.data_parallel_size

        if global_data > 1:
            # Bucket parameters by their expert-parallel size (epsize).
            epsize_param_dict = get_moe_epsize_param_dict(self._model)

            # epsize is 1, indicating the params are replicated among processes in data parallelism
            # use the ParallelMode.DATA to get data parallel group
            # reduce gradients for all parameters in data parallelism
            if 1 in epsize_param_dict:
                bucket_allreduce(param_list=epsize_param_dict[1],
                                 group=gpc.get_group(ParallelMode.DATA))

            # For expert parameters with epsize other than 1, all-reduce over the
            # matching moe data-parallel group. Buckets whose epsize equals the moe
            # world size are skipped — presumably such experts have no data-parallel
            # replicas to synchronize (TODO confirm against MOE_CONTEXT semantics).
            for ep_size in epsize_param_dict:
                if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
                    bucket_allreduce(param_list=epsize_param_dict[ep_size],
                                     group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment