OpenDAS / ColossalAI · Commits · aa0f6686
"...metrics/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "b92834c8c575dbcc24763d302117b47bf725ae46"
Unverified commit aa0f6686, authored Jan 29, 2023 by YuliangLiu0306; committed by GitHub on Jan 29, 2023.
[autoparallel] accelerate gpt2 training (#2495)
Parent: a360b9bc
Showing 5 changed files with 21 additions and 17 deletions:
- colossalai/auto_parallel/passes/runtime_preparation_pass.py (+8 −6)
- colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py (+2 −0)
- colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py (+7 −7)
- colossalai/device/device_mesh.py (+1 −1)
- colossalai/tensor/comm_spec.py (+3 −3)
colossalai/auto_parallel/passes/runtime_preparation_pass.py

```diff
@@ -387,14 +387,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 # register hook to the parameters
                 if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
 
-                    def wrapper(param, comm_spec):
+                    def wrapper(param, comm_spec, stream):
 
                         def hook_fn(grad):
-                            _all_reduce(grad, comm_spec, async_op=False)
+                            with torch.cuda.stream(stream):
+                                _all_reduce(grad, comm_spec, async_op=True)
 
                         param.register_hook(hook_fn)
 
-                    wrapper(param, comm_spec_to_use)
+                    wrapper(param, comm_spec_to_use, reduction_stream)
 
     sharded_buffer_dict = {}
     # apply the sharding spec of buffers
@@ -440,14 +441,15 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 # register hook to the parameters
                 if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
 
-                    def wrapper(param, comm_spec):
+                    def wrapper(param, comm_spec, stream):
 
                         def hook_fn(grad):
-                            _all_reduce(grad, comm_spec, async_op=False)
+                            with torch.cuda.stream(stream):
+                                _all_reduce(grad, comm_spec, async_op=True)
 
                         param.register_hook(hook_fn)
 
-                    wrapper(target, comm_spec_to_use)
+                    wrapper(target, comm_spec_to_use, reduction_stream)
 
     return gm
```
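The substantive change in this file is that the gradient hook now launches the all-reduce asynchronously on a dedicated CUDA stream (`reduction_stream`) instead of blocking the default stream, so gradient communication can overlap with the rest of the backward pass. Below is a minimal, self-contained sketch of that pattern; the helper name, process group, and the explicit stream synchronization are illustrative assumptions rather than ColossalAI's actual API.

```python
import torch
import torch.distributed as dist

# Sketch only: assumes torch.distributed is already initialized and CUDA is available.
reduction_stream = torch.cuda.Stream()


def register_async_allreduce_hook(param: torch.nn.Parameter,
                                  stream: torch.cuda.Stream,
                                  group=None):
    """Register a backward hook that all-reduces the gradient on a side stream."""

    def hook_fn(grad):
        # The side stream must wait until the default stream has produced `grad`.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            # Prevent the caching allocator from reusing `grad`'s memory while
            # the side stream still reads it.
            grad.record_stream(stream)
            dist.all_reduce(grad, group=group, async_op=True)
        return grad

    param.register_hook(hook_fn)
```

With this pattern, something must make the default stream wait on `reduction_stream` before `optimizer.step()` consumes the gradients, e.g. `torch.cuda.current_stream().wait_stream(reduction_stream)`; otherwise the parameter update could race with an in-flight reduction.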
colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py

```diff
@@ -483,4 +483,6 @@ class MatMulHandler(NodeHandler):
             raise TypeError(
                 f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
         strategies = recovered_stragies
+        for index, strategies in enumerate(strategies):
+            strategies.name = f'{strategies.name}_{index}'
         return strategies
```
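The two added lines suffix each recovered strategy's name with its index so that names stay unique. Note that the committed loop reuses `strategies` as its loop variable, shadowing the list it iterates over, so the following `return strategies` appears to hand back the last element rather than the whole list. A non-shadowing sketch of the same renaming, with a hypothetical `Strategy` stand-in for ColossalAI's sharding-strategy objects:

```python
from dataclasses import dataclass


@dataclass
class Strategy:
    name: str  # stand-in for the `name` field on ColossalAI's sharding strategies


def dedup_strategy_names(strategies: list) -> list:
    # Suffix each strategy name with its index, without shadowing the list.
    for index, strategy in enumerate(strategies):
        strategy.name = f'{strategy.name}_{index}'
    return strategies  # still the full list


print([s.name for s in dedup_strategy_names([Strategy('bmm'), Strategy('bmm')])])
# ['bmm_0', 'bmm_1']
```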
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py

```diff
@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         strategies.append(self.split_rhs_space_both_contract(1, 0))
 
         # RR= RS x SR
-        strategies.append(self.recompute_split_both_contract(0))
-        strategies.append(self.recompute_split_both_contract(1))
+        # strategies.append(self.recompute_split_both_contract(0))
+        # strategies.append(self.recompute_split_both_contract(1))
 
-        # RS = RR x RS
-        strategies.append(self.split_rhs_space_only(0))
-        strategies.append(self.split_rhs_space_only(1))
+        # # RS = RR x RS
+        # strategies.append(self.split_rhs_space_only(0))
+        # strategies.append(self.split_rhs_space_only(1))
 
         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
 
-        # RR = RR x RR
-        strategies.append(self.non_split())
+        # # RR = RR x RR
+        # strategies.append(self.non_split())
 
         return strategies
```
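This change removes candidate strategies rather than modifying any: the RS x SR recompute pair, the RR x RS right-hand splits, and the RR x RR non-split are all commented out, shrinking the strategy set the auto-parallel solver must search for every linear projection. A purely hypothetical back-of-the-envelope illustration of why per-node pruning pays off:

```python
# All numbers are assumptions for illustration; the real counts depend on the
# traced GPT-2 graph and the full strategy generators.
num_linear_nodes = 48    # assumed linear-projection nodes in a GPT-2 graph
before, after = 10, 7    # assumed candidate strategies per node before/after pruning

# The global search space grows roughly with the product of per-node candidate
# counts, so trimming candidates shrinks it multiplicatively.
print((before / after) ** num_linear_nodes)   # ~2.7e7x fewer raw combinations
```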
colossalai/device/device_mesh.py

```diff
@@ -98,7 +98,7 @@ class DeviceMesh:
         return DeviceMesh(self.physical_mesh_id,
                           tuple(flatten_mesh_shape),
                           mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
-                          mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+                          mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
                           init_process_group=self.init_process_group,
                           need_flatten=False)
```
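The one-line fix replaces `min` with `max` when deriving `mesh_beta` for the flattened mesh, making it consistent with how `mesh_alpha` is derived on the preceding line. Under the usual alpha-beta communication model (cost ≈ alpha + beta · message size, with alpha a latency term and beta a per-byte term), taking the per-axis maximum is the conservative choice when several physical axes collapse into one logical axis. A toy illustration with made-up coefficients:

```python
# Made-up coefficients for a two-axis mesh being flattened into one logical axis.
mesh_alpha = [1e-4, 3e-4]      # per-axis latency terms
mesh_beta = [1e-10, 5e-10]     # per-axis per-byte (inverse-bandwidth) terms
flatten_mesh_shape_size = 2    # assumed value for this toy mesh

flat_alpha = [max(mesh_alpha)] * (flatten_mesh_shape_size - 1)   # [3e-4]
flat_beta = [max(mesh_beta)] * (flatten_mesh_shape_size - 1)     # [5e-10]
# Before the fix, flat_beta used min(...), i.e. [1e-10], which underestimated
# the cost of traffic that must cross the slower axis.
```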
colossalai/tensor/comm_spec.py

```diff
@@ -463,7 +463,7 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
             forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
             # give a tiny cost to shard
-            backward_communication_cost = 10
+            backward_communication_cost = 100
 
         if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
             forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +481,13 @@ class CommSpec:
         if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
             # give a tiny cost to shard
-            forward_communication_cost = 10
+            forward_communication_cost = 100
             backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
 
         if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
             # no need for axis because all devices are used in mix_gather
             forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
-            backward_communication_cost = 10
+            backward_communication_cost = 100
 
         if self.forward_only:
             cost_dict["forward"] = forward_communication_cost
```
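All three hunks raise the placeholder "tiny cost" charged for the split side of a communication pattern from 10 to 100. The diff does not state the motivation, but one plausible reading is that a larger penalty keeps the strategy solver from favoring plans that chain many nominally cheap reshards. A toy comparison with assumed numbers:

```python
# Toy numbers only; real values come from DeviceMesh.all_gather_cost and friends.
all_gather_cost = 5000    # assumed modeled cost of one all-gather
num_reshards = 8          # assumed shard/split steps in a candidate plan

old_total = all_gather_cost + 10 * num_reshards     # 5080 with the old penalty
new_total = all_gather_cost + 100 * num_reshards    # 5800 with the new penalty
print(old_total, new_total)
# Reshard-heavy plans now compare less favorably against plans that avoid them.
```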