Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ac279799
Unverified
Commit
ac279799
authored
Sep 15, 2023
by
Xuanlei Zhao
Committed by
GitHub
Sep 15, 2023
Browse files
[shardformer] add custom policy in hybrid parallel plugin (#4718)
* add custom policy * update assert
parent
451c3465
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
4 deletions
+10
-4
colossalai/booster/plugin/hybrid_parallel_plugin.py
colossalai/booster/plugin/hybrid_parallel_plugin.py
+10
-4
No files found.
colossalai/booster/plugin/hybrid_parallel_plugin.py
View file @
ac279799
...
@@ -22,6 +22,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
...
@@ -22,6 +22,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
from
colossalai.pipeline.schedule
import
OneForwardOneBackwardSchedule
from
colossalai.pipeline.schedule
import
OneForwardOneBackwardSchedule
from
colossalai.pipeline.stage_manager
import
PipelineStageManager
from
colossalai.pipeline.stage_manager
import
PipelineStageManager
from
colossalai.shardformer
import
ShardConfig
,
ShardFormer
from
colossalai.shardformer
import
ShardConfig
,
ShardFormer
from
colossalai.shardformer.policies.base_policy
import
Policy
from
colossalai.zero.low_level
import
LowLevelZeroOptimizer
from
colossalai.zero.low_level
import
LowLevelZeroOptimizer
from
.pp_plugin_base
import
PipelinePluginBase
from
.pp_plugin_base
import
PipelinePluginBase
...
@@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
...
@@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
class
HybridParallelModule
(
ModelWrapper
):
class
HybridParallelModule
(
ModelWrapper
):
def
__init__
(
self
,
module
:
Module
,
precision
:
str
,
shard_config
:
ShardConfig
,
dp_group
:
ProcessGroup
,
use_ddp
:
bool
,
def
__init__
(
self
,
module
:
Module
,
precision
:
str
,
shard_config
:
ShardConfig
,
dp_group
:
ProcessGroup
,
use_ddp
:
bool
,
ddp_config
:
dict
)
->
None
:
ddp_config
:
dict
,
custom_policy
:
Policy
)
->
None
:
self
.
stage_manager
=
shard_config
.
pipeline_stage_manager
self
.
stage_manager
=
shard_config
.
pipeline_stage_manager
self
.
dp_group
=
dp_group
self
.
dp_group
=
dp_group
shardformer
=
ShardFormer
(
shard_config
)
shardformer
=
ShardFormer
(
shard_config
)
module
,
self
.
shared_params
=
shardformer
.
optimize
(
module
)
if
custom_policy
is
not
None
:
assert
isinstance
(
custom_policy
,
object
)
module
,
self
.
shared_params
=
shardformer
.
optimize
(
module
,
policy
=
custom_policy
)
# setting process groups for shared parameters
# setting process groups for shared parameters
self
.
shared_param_process_groups
=
[]
self
.
shared_param_process_groups
=
[]
...
@@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase):
...
@@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase):
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -302,7 +306,8 @@ class HybridParallelPlugin(PipelinePluginBase):
...
@@ -302,7 +306,8 @@ class HybridParallelPlugin(PipelinePluginBase):
zero_bucket_size_in_m
:
int
=
12
,
zero_bucket_size_in_m
:
int
=
12
,
cpu_offload
:
bool
=
False
,
cpu_offload
:
bool
=
False
,
communication_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
communication_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
overlap_communication
:
bool
=
True
)
->
None
:
overlap_communication
:
bool
=
True
,
custom_policy
:
Policy
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
assert
dist
.
get_world_size
()
%
(
assert
dist
.
get_world_size
()
%
(
...
@@ -326,6 +331,7 @@ class HybridParallelPlugin(PipelinePluginBase):
...
@@ -326,6 +331,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self
.
pg_mesh
=
ProcessGroupMesh
(
self
.
dp_size
,
self
.
pp_size
,
self
.
tp_size
)
self
.
pg_mesh
=
ProcessGroupMesh
(
self
.
dp_size
,
self
.
pp_size
,
self
.
tp_size
)
self
.
stage_manager
=
None
self
.
stage_manager
=
None
self
.
schedule
=
None
self
.
schedule
=
None
self
.
custom_policy
=
custom_policy
assert
zero_stage
in
(
0
,
1
,
2
)
assert
zero_stage
in
(
0
,
1
,
2
)
if
self
.
pp_size
>
1
:
if
self
.
pp_size
>
1
:
assert
num_microbatches
is
not
None
or
microbatch_size
is
not
None
,
'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
assert
num_microbatches
is
not
None
or
microbatch_size
is
not
None
,
'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
...
@@ -405,7 +411,7 @@ class HybridParallelPlugin(PipelinePluginBase):
...
@@ -405,7 +411,7 @@ class HybridParallelPlugin(PipelinePluginBase):
if
not
isinstance
(
model
,
ModelWrapper
):
if
not
isinstance
(
model
,
ModelWrapper
):
use_ddp
=
self
.
dp_size
>
1
and
self
.
pp_size
==
1
and
self
.
zero_stage
==
0
use_ddp
=
self
.
dp_size
>
1
and
self
.
pp_size
==
1
and
self
.
zero_stage
==
0
model
=
HybridParallelModule
(
model
,
self
.
precision
,
self
.
shard_config
,
self
.
dp_group
,
use_ddp
,
model
=
HybridParallelModule
(
model
,
self
.
precision
,
self
.
shard_config
,
self
.
dp_group
,
use_ddp
,
self
.
ddp_config
)
self
.
ddp_config
,
self
.
custom_policy
)
if
optimizer
is
not
None
and
not
isinstance
(
optimizer
,
OptimizerWrapper
):
if
optimizer
is
not
None
and
not
isinstance
(
optimizer
,
OptimizerWrapper
):
if
self
.
zero_stage
==
0
:
if
self
.
zero_stage
==
0
:
if
self
.
precision
in
[
'fp16'
,
'bf16'
]:
if
self
.
precision
in
[
'fp16'
,
'bf16'
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment