OpenDAS / ColossalAI · Commit 291e22aa

Unverified commit 291e22aa, authored Jul 06, 2022 by XYE, committed by GitHub on Jul 06, 2022.

[fx] temporarily used (#1215)

parent ae7d3f49

Changes: 1 changed file with 26 additions and 81 deletions.

colossalai/fx/passes/shard_1d_pass.py  (+26 / -81)
@@ -4,75 +4,37 @@ from torch.fx.node import Node
 from torch.fx.passes.split_module import split_module
 import colossalai
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-
-
-def all_gather_function(input_):
-    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-    tensor_list[rank] = input_
-    group = gpc.get_group(ParallelMode.PARALLEL_1D)
-    torch.distributed.all_gather(tensor_list, input_, group=group)
-    output = torch.cat(tensor_list, dim=-1).contiguous()
-    return output
-
-
-def all_reduce_function(input_):
-    if gpc.get_world_size(ParallelMode.PARALLEL_1D) == 1:
-        return input_
-    torch.distributed.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-    return input_
-
-
-def weight_split(weight, dim):
-    # TODO: this function will be refactored by using ColoTensor dist_spec when a stable reshaper feature is ready to use.
-    num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-    shape = weight.shape
-    length = shape[dim] // num_partition
-    sharded_weight_list = []
-    for i in range(num_partition):
-        sharded_weight_list.append(weight.narrow(dim, i * length, length))
-    return sharded_weight_list[gpc.get_local_rank(ParallelMode.PARALLEL_1D)]
-
-
-def replace_all_uses_except_replaced(node, replace_node):
-    """
-    Replace all uses of ``node`` in the Graph with the Node ``replace_node``,
-    except when the user of ``node`` is ``replace_node`` itself.
-
-    Args:
-        replace_node (Node): The node to replace all uses of ``node`` with.
-
-    Returns:
-        The list of Nodes on which this change was made.
-    """
-    to_process = list(node.users)
-    for use_node in to_process:
-        if use_node == replace_node:
-            continue
-
-        def may_replace_node(n):
-            if n == node:
-                return replace_node
-            else:
-                return n
-
-        new_args = map_arg(use_node.args, may_replace_node)
-        new_kwargs = map_arg(use_node.kwargs, may_replace_node)
-        use_node._args = new_args
-        use_node._kwargs = new_kwargs
-        for old_use in use_node._input_nodes.keys():
-            old_use.users.pop(use_node)
-        use_node._input_nodes = {}
-        map_arg(use_node._args, lambda n: use_node._input_nodes.setdefault(n))
-        map_arg(use_node._kwargs, lambda n: use_node._input_nodes.setdefault(n))
-        for new_use in use_node._input_nodes.keys():
-            new_use.users.setdefault(use_node)
-    return to_process
+from colossalai.tensor import ColoTensor, TensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+
+
+def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.parameter.Parameter:
+    """weight_split
+    split a nn.Parameter
+
+    Args:
+        weight (torch.nn.parameter.Parameter): a torch Parameter instance
+        dim (int): the dimension to be sharded along
+
+    Returns:
+        _type_: _description_
+    """
+    # TODO: This func temporarily works with no materialization
+    # Append a Tensor spec to target_module.weight.shard
+    # Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
+    # assert isinstance(weight, torch.nn.parameter.Parameter), \
+    #     f'The type of the input tensor should be torch.nn.parameter' \
+    #     f'Your input tensor is {type(weight)}'
+
+    # FIXME() A ProcessGroup is initialized for this tensor with only a TP comm group;
+    # we only consider the TP-only case.
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+
+    spec = TensorSpec(distspec.shard(pg, [dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    # Since a spec has been constructed, the tensor could be converted directly to a ColoTensor here.
+    # setattr(weight, "fx_attr", spec)
+    weight.data = ColoTensor(data=weight.data, spec=spec)
+    return weight
 
 
 def column_shard_linear_pass(gm: torch.fx.GraphModule):
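Note: a minimal usage sketch (not part of this commit) of the new spec-based weight_split, assuming torch.distributed has already been initialized (e.g. via colossalai.launch) so that the ProcessGroup inside weight_split can be built:

    import torch
    from colossalai.fx.passes.shard_1d_pass import weight_split

    # Hypothetical single Linear layer; weight_split keeps the Parameter object but
    # replaces its .data with a ColoTensor carrying a dim-0 TP1D shard spec.
    linear = torch.nn.Linear(16, 16)
    linear.weight = weight_split(linear.weight, dim=0)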
@@ -81,14 +43,10 @@ def column_shard_linear_pass(gm: torch.fx.GraphModule):
         if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
-                target_module.weight.data = weight_split(target_module.weight.data, dim=0)
+                target_module.weight = weight_split(target_module.weight, dim=0)
                 if target_module.bias is not None:
                     target_module.bias.data = weight_split(target_module.bias.data, dim=0)
-
-                # inserting communication node after the sharded linear node
-                with mod_graph.inserting_after(node):
-                    new_node = mod_graph.create_node('call_function', all_gather_function, args=(node,))
-                replace_all_uses_except_replaced(node, new_node)
     gm.recompile()
     return gm
@@ -99,20 +57,7 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
         if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             if isinstance(target_module, torch.nn.Linear):
-                target_module.weight.data = weight_split(target_module.weight.data, dim=-1)
-
-                # insert input sharding node before the sharded linear node
-                with mod_graph.inserting_before(node):
-                    input_node_list = list(node._input_nodes.keys())
-                    assert len(input_node_list) == 1, 'linear forward must have and only have one input tensor.'
-                    input_node = input_node_list[0]
-                    new_input_node = mod_graph.create_node('call_function', weight_split, args=(input_node, -1))
-                    replace_all_uses_except_replaced(input_node, new_input_node)
-
-                # inserting communication node after the sharded linear node
-                with mod_graph.inserting_after(node):
-                    new_node = mod_graph.create_node('call_function', all_reduce_function, args=(node,))
-                replace_all_uses_except_replaced(node, new_node)
+                target_module.weight = weight_split(target_module.weight, dim=-1)
     gm.recompile()
     return gm
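For context, a rough end-to-end sketch (not from this commit) of how these passes could be applied to an fx-traced GraphModule; plain torch.fx tracing is used here for illustration, and a distributed context is again assumed to be initialized:

    import torch
    from colossalai.fx.passes.shard_1d_pass import column_shard_linear_pass

    # Toy model traced with vanilla torch.fx; ColossalAI's own tracer could be used instead.
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))
    gm = torch.fx.symbolic_trace(model)
    # After the pass, every torch.nn.Linear weight (and bias) in the graph carries a
    # dim-0 shard spec; row_shard_linear_pass is the row-parallel (dim=-1) counterpart.
    gm = column_shard_linear_pass(gm)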