OpenDAS / ColossalAI · Commits

Unverified commit ac279799, authored Sep 15, 2023 by Xuanlei Zhao, committed by GitHub on Sep 15, 2023
[shardformer] add custom policy in hybrid parallel plugin (#4718)
* add custom policy
* update assert
parent 451c3465
Showing 1 changed file with 10 additions and 4 deletions.
colossalai/booster/plugin/hybrid_parallel_plugin.py (+10 -4)
@@ -22,6 +22,7 @@ from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 
 from .pp_plugin_base import PipelinePluginBase
@@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 class HybridParallelModule(ModelWrapper):
 
     def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
-                 ddp_config: dict) -> None:
+                 ddp_config: dict, custom_policy: Policy) -> None:
 
         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group
 
         shardformer = ShardFormer(shard_config)
-        module, self.shared_params = shardformer.optimize(module)
+        if custom_policy is not None:
+            assert isinstance(custom_policy, object)
+        module, self.shared_params = shardformer.optimize(module, policy=custom_policy)
 
         # setting process groups for shared parameters
         self.shared_param_process_groups = []
@@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
     """
 
     def __init__(self,
@@ -302,7 +306,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                  zero_bucket_size_in_m: int = 12,
                  cpu_offload: bool = False,
                  communication_dtype: Optional[torch.dtype] = None,
-                 overlap_communication: bool = True) -> None:
+                 overlap_communication: bool = True,
+                 custom_policy: Policy = None) -> None:
 
         super().__init__()
         assert dist.get_world_size() % (
@@ -326,6 +331,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
+        self.custom_policy = custom_policy
        assert zero_stage in (0, 1, 2)
        if self.pp_size > 1:
            assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
@@ -405,7 +411,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
-                                         self.ddp_config)
+                                         self.ddp_config, self.custom_policy)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:
                 if self.precision in ['fp16', 'bf16']:
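
For context, a minimal usage sketch of the new custom_policy argument follows. It assumes a distributed environment already initialized (e.g. via colossalai.launch_from_torch) and the standard Booster workflow; the my_policy placeholder is hypothetical and not part of this commit, and the tp_size/pp_size values are arbitrary.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.policies.base_policy import Policy

# Hypothetical placeholder: an instance of a concrete Policy subclass that
# describes how Shardformer should shard the target model.
my_policy: Policy = ...

# custom_policy is the keyword introduced by this commit; it defaults to None,
# which preserves the previous behaviour.
plugin = HybridParallelPlugin(tp_size=2,
                              pp_size=1,
                              custom_policy=my_policy)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler)

A policy passed this way is forwarded through HybridParallelModule to shardformer.optimize(module, policy=custom_policy), as the diff above shows.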