OpenDAS / ColossalAI

Commit 28398f1c (unverified), authored Feb 08, 2023 by YuliangLiu0306, committed by GitHub on Feb 08, 2023
Parent: cb3d1bef

    add overlap option (#2613)

Showing 2 changed files with 32 additions and 16 deletions:

    colossalai/auto_parallel/passes/runtime_preparation_pass.py   +18  -11
    colossalai/auto_parallel/tensor_shard/initialize.py            +14  -5
colossalai/auto_parallel/passes/runtime_preparation_pass.py

@@ -352,7 +352,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     return gm
 
 
-def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
     """
     Apply the sharding action to the module parameters and buffers following the
     instructions of solver solution.
@@ -387,15 +387,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 # register hook to the parameters
                 if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
 
-                    def wrapper(param, comm_spec, stream):
+                    def wrapper(param, comm_spec, stream, overlap):
 
                         def hook_fn(grad):
-                            with torch.cuda.stream(stream):
-                                _all_reduce(grad, comm_spec, async_op=True)
+                            if overlap:
+                                with torch.cuda.stream(stream):
+                                    _all_reduce(grad, comm_spec, async_op=True)
+                            else:
+                                _all_reduce(grad, comm_spec, async_op=False)
 
                         param.register_hook(hook_fn)
 
-                    wrapper(param, comm_spec_to_use, reduction_stream)
+                    wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
 
     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -441,15 +444,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 # register hook to the parameters
                 if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
 
-                    def wrapper(param, comm_spec, stream):
+                    def wrapper(param, comm_spec, stream, overlap):
 
                         def hook_fn(grad):
-                            with torch.cuda.stream(stream):
-                                _all_reduce(grad, comm_spec, async_op=True)
+                            if overlap:
+                                with torch.cuda.stream(stream):
+                                    _all_reduce(grad, comm_spec, async_op=True)
+                            else:
+                                _all_reduce(grad, comm_spec, async_op=False)
 
                         param.register_hook(hook_fn)
 
-                    wrapper(target, comm_spec_to_use, reduction_stream)
+                    wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
 
     return gm
 
@@ -463,13 +469,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
 def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              solution: List[int],
                              device_mesh: DeviceMesh,
-                             strategies_constructor: StrategiesConstructor = None):
+                             strategies_constructor: StrategiesConstructor = None,
+                             overlap=False):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
         gm, solution, strategies_constructor)
     gm = _size_value_converting(gm, device_mesh)
     gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
-    gm = _module_params_sharding(gm, device_mesh)
+    gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
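The diff above gates the gradient all-reduce hook on the new `overlap` flag: with `overlap=True` the all-reduce is issued asynchronously on a dedicated reduction stream, so communication can run concurrently with the rest of the backward pass; with `overlap=False` it falls back to a blocking all-reduce on the current stream. Below is a minimal, self-contained sketch of the same hook pattern; it uses plain `torch.distributed.all_reduce` and the default process group instead of ColossalAI's internal `_all_reduce`/`comm_spec` machinery, and the helper name `register_allreduce_hook` plus the explicit `wait_stream` ordering call are illustrative additions, not part of the commit.

```python
import torch
import torch.distributed as dist


def register_allreduce_hook(param: torch.nn.Parameter,
                            stream: torch.cuda.Stream,
                            overlap: bool = False) -> None:
    """Attach a hook that all-reduces the parameter's gradient.

    Sketch only: uses the default process group rather than a
    comm_spec-driven collective as in the actual pass.
    """

    def hook_fn(grad: torch.Tensor) -> None:
        if overlap:
            # Make the side stream wait until the gradient is produced, then
            # launch the all-reduce asynchronously so it overlaps with the
            # remaining backward computation (ordering call added in this sketch).
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                dist.all_reduce(grad, async_op=True)
        else:
            # Blocking all-reduce on the current stream.
            dist.all_reduce(grad, async_op=False)

    param.register_hook(hook_fn)
```

In the commit itself, this role is played by the nested `wrapper`/`hook_fn` pair: `_module_params_sharding` installs the hook via `wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)`, and the flag is threaded in from `runtime_preparation_pass`.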
colossalai/auto_parallel/tensor_shard/initialize.py

@@ -98,16 +98,22 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
     return solution
 
 
-def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
-                               strategies_constructor: StrategiesConstructor):
+def transform_to_sharded_model(gm: ColoGraphModule,
+                               solution: List[int],
+                               device_mesh: DeviceMesh,
+                               strategies_constructor: StrategiesConstructor,
+                               overlap: bool = False):
     '''
     This method is used to transform the original graph to the sharded graph.
     The model parameters will be sharded according to the solution and the grad hooks
     will be added to the sharded graph using the runtime_preparation_pass.
     The communication node will be added into the graph using the runtime_apply_pass.
     '''
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh, strategies_constructor)
+    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
+                                                                                           solution,
+                                                                                           device_mesh,
+                                                                                           strategies_constructor,
+                                                                                           overlap=overlap)
     gm = runtime_apply_pass(gm)
     gm.recompile()
     sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
@@ -176,6 +182,7 @@ def initialize_model(model: nn.Module,
                      meta_args: Dict[str, torch.Tensor],
                      device_mesh: DeviceMesh,
                      memory_budget: float = -1.0,
+                     overlap: bool = False,
                      save_solver_solution: bool = False,
                      load_solver_solution: bool = False,
                      solution_path: str = None,
@@ -189,6 +196,8 @@ def initialize_model(model: nn.Module,
         device_mesh: the device mesh to execute the model.
         memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
             the memory budget will be infinity.
+        overlap(optional): the overlap is used to specify whether to overlap gradient communication and
+            backward computing.
         save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
             to the solution_path.
         load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -211,7 +220,7 @@ def initialize_model(model: nn.Module,
     if save_solver_solution:
         torch.save(solution, solution_path)
 
-    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
 
     model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
 
     if return_solution:
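Taken together, the two files thread the flag from the public entry point down to the hooks: `initialize_model(..., overlap=...)` passes it to `transform_to_sharded_model`, which forwards it to `runtime_preparation_pass` and ultimately to `_module_params_sharding`. A hedged usage sketch follows; the toy model, the `meta_args`, and the `DeviceMesh` construction (including its import path and constructor arguments) are assumptions about the surrounding auto-parallel setup, not something this commit specifies.

```python
# Usage sketch, assuming torch.distributed has already been initialized across
# 4 ranks (e.g. via colossalai.launch). The model, meta_args and DeviceMesh
# arguments below are illustrative placeholders.
import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh  # import path assumed

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda()

# A 2x2 logical mesh over 4 physical devices (constructor arguments assumed).
physical_mesh_id = torch.arange(0, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2), init_process_group=True)

# Meta tensors describing the forward inputs used for tracing.
meta_args = {'x': torch.rand(4, 8).to('meta')}

# overlap=True makes the gradient all-reduce hooks run asynchronously on a side
# stream, overlapping communication with backward computation.
model = initialize_model(model,
                         meta_args=meta_args,
                         device_mesh=device_mesh,
                         overlap=True)
```

With `overlap=False` (the default), behavior matches the pre-commit code path: each hook performs a synchronous all-reduce on the current stream.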