OpenDAS / ColossalAI · Commits

Commit ca2d3f28 (unverified)
Authored Jul 15, 2022 by XYE; committed by GitHub on Jul 15, 2022
[fx] Add unit test and fix bugs for transform_mlp_pass (#1299)
* add test and fix bugs
* add functions back
* add comments
Parent: 1b416864

Showing 2 changed files with 114 additions and 24 deletions (+114 −24)
colossalai/fx/passes/shard_1d_pass.py    +55 −24
tests/test_fx/test_transform_mlp_pass.py    +59 −0
colossalai/fx/passes/shard_1d_pass.py (view file @ ca2d3f28)
import torch
from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec
import operator
import colossalai

ELEMENTWISE_MODULE_OP = [
    torch.nn.Dropout, torch.nn.ReLU, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.AvgPool1d, torch.nn.AvgPool2d
]
ELEMENTWISE_FUNC_OP = [
    torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul,
    operator.floordiv, operator.truediv, operator.neg, torch.multiply,
    torch.nn.functional.relu, torch.nn.functional.dropout,
    torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d,
    torch.nn.functional.avg_pool1d, torch.nn.functional.avg_pool2d, torch.nn.functional.avg_pool3d,
    torch.nn.functional.max_pool1d, torch.nn.functional.max_pool2d, torch.nn.functional.max_pool3d
]


def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter:
    """weight_split
    split a nn.Parameter

    Args:
        weight (torch.nn.parameter.Parameter): a torch Parameter instance
        dim (int): the dimension to be sharded along with
        col_normal (bool): col shard with gather or not

    Returns:
        torch.nn.parameter.Parameter: the input parameter, annotated with the sharding spec via its fx_attr attribute
    """
    if col_normal:
        setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_normal"))
    else:
        setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
    return weight


def column_shard_linear_pass(gm: torch.fx.GraphModule):
    # Split all the linear modules with column shard. Currently for testing only.
    mod_graph = gm.graph
    for node in mod_graph.nodes:
        if node.op == "call_module":
            target_module = node.graph.owning_module.get_submodule(node.target)
            if isinstance(target_module, torch.nn.Linear):
                target_module.weight = weight_split(target_module.weight, dim=0, col_normal=False)
                if target_module.bias is not None:
                    target_module.bias.data = weight_split(target_module.bias.data, dim=0, col_normal=False)

    gm.recompile()
    return gm


def row_shard_linear_pass(gm: torch.fx.GraphModule):
    # Split all the linear modules with row shard. Currently for testing only.
    mod_graph = gm.graph
    for node in mod_graph.nodes:
        if node.op == "call_module":
            target_module = node.graph.owning_module.get_submodule(node.target)
            if isinstance(target_module, torch.nn.Linear):
                target_module.weight = weight_split(target_module.weight, dim=-1, col_normal=False)

    gm.recompile()
    return gm


def transform_mlp_pass(gm: torch.fx.GraphModule):
    #TODO: Needs to handle special cases, like x = linear(x) + linear(x)
    mod_graph = gm.graph
    col_shard = True
    element_op = []
    all_linear_name = []
    linear_name = []
    # Get the names of element-wise modules (e.g. torch.nn.ReLU)
    # Get the names of all the linear modules and of repeated linear modules
    for name, func in gm.named_children():
        if not isinstance(func, torch.nn.Linear):
            for i in ELEMENTWISE_MODULE_OP:
                if isinstance(func, i):
                    element_op.append(name)
                    break
        else:
            if name in all_linear_name:
                if name in linear_name:
                    linear_name.remove(name)
            else:
                all_linear_name.append(name)
                linear_name.append(name)
    # If a linear module is called multiple times, set its dist spec as col shard
    # If the module is element-wise or the function/method is element-wise, col_shard is kept as-is
    for node in mod_graph.nodes:
        if node.target in linear_name:
            target_module = node.graph.owning_module.get_submodule(node.target)
            dim = 0 if col_shard else -1
            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=False)
            col_shard = not col_shard
        elif node.target in all_linear_name:
            target_module = node.graph.owning_module.get_submodule(node.target)
            dim = 0 if col_shard else -1
            target_module.weight = weight_split(target_module.weight, dim=dim, col_normal=True)
            col_shard = not col_shard
        else:
            if node.target not in element_op and all(node.target != i for i in ELEMENTWISE_FUNC_OP):
                col_shard = True

    gm.recompile()
    return gm
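With this change, weight_split no longer builds a ColoTensorSpec itself; it only records a (dim, "SHARD", "TP", mode) tuple on the parameter's fx_attr attribute, where mode distinguishes whether a column shard still needs its outputs gathered. A later pass is expected to turn that annotation into an actual tensor spec. Below is a minimal sketch of such a consumer, assuming torch.distributed is already initialized; the helper name annotation_to_spec is hypothetical and only mirrors the ColoTensorSpec construction that weight_split performed before this commit.

import torch
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec

# Hypothetical consumer of the fx_attr annotation written by weight_split.
# It rebuilds the 1D tensor-parallel spec that weight_split used to construct inline.
def annotation_to_spec(fx_attr):
    dim, shard, parallel, _mode = fx_attr    # e.g. (0, "SHARD", "TP", "col_needs_many_outputs")
    assert shard == "SHARD" and parallel == "TP"
    # TP-only process group covering all ranks, as in the pre-commit weight_split
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    return ColoTensorSpec(pg,
                          ShardSpec([dim], [pg.tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))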
tests/test_fx/test_transform_mlp_pass.py (new file, mode 100644; view file @ ca2d3f28)
import torch
import torch.nn as nn
import pytest

import colossalai
from colossalai.fx import ColoTracer
from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass

CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d')))


class MLP(torch.nn.Module):

    def __init__(self, dim: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)
        self.linear4 = torch.nn.Linear(dim, dim)
        self.dropout = torch.nn.Dropout()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.dropout(self.relu(self.linear2(x)))
        x = self.linear3(x)
        x = torch.nn.functional.relu(self.linear4(x))
        return x


def test_out_acc():
    model = MLP(16).cuda()
    model.eval()
    input_tensor = torch.rand(2, 16).cuda()
    output = model(input_tensor)

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
    splitted_gm = transform_mlp_pass(gm)

    new_output = splitted_gm(input_tensor)
    assert output.equal(new_output)


def test_linear_acc():
    input_tensor = torch.rand(2, 16).cuda()
    model = MLP(16).cuda()

    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")})
    gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
    splitted_gm = transform_mlp_pass(gm)

    col_shard = True
    for node in splitted_gm.graph.nodes:
        if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear):
            target_module = node.graph.owning_module.get_submodule(node.target)
            dim = 0 if col_shard else -1
            assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs")
            col_shard = not col_shard


if __name__ == "__main__":
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)
    # colossalai.launch_from_torch(config=CONFIG)
    test_out_acc()
    test_linear_acc()
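For the MLP above, each Linear appears exactly once in the traced graph, and the operations between them (relu, dropout, torch.nn.functional.relu) are all in the element-wise lists, so transform_mlp_pass simply alternates the shard dimension: linear1 and linear3 are annotated with dim=0 (column shard), linear2 and linear4 with dim=-1 (row shard), which is the alternation test_linear_acc asserts. Since the annotations do not yet change any tensor data, test_out_acc expects the transformed module to reproduce the original output on a single GPU; with colossalai.launch_from_torch(config=CONFIG) left commented out, the file can also be run directly via its __main__ block without a multi-process launch.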