OpenDAS / ColossalAI / Commits

Commit d2fc0672 (unverified), authored Oct 24, 2022 by YuliangLiu0306, committed by GitHub on Oct 24, 2022

[autoparallel] fix param hook issue in transform pass (#1755)

parent 262652c8
Showing 2 changed files with 85 additions and 11 deletions:

- colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py (+7, -3)
- tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py (+78, -8)
colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py

```diff
@@ -70,10 +70,14 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                     comm_spec_to_use = comm_action.comm_spec
                     if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
 
-                        def hook_fn(grad):
-                            _all_reduce(grad, comm_spec_to_use)
+                        def wrapper(param, comm_spec):
 
-                        param_sharded.register_hook(hook_fn)
+                            def hook_fn(grad):
+                                _all_reduce(grad, comm_spec)
+
+                            param.register_hook(hook_fn)
+
+                        wrapper(param_sharded, comm_spec_to_use)
     sharded_buffer_dict = {}
     for name, buffer in target_module.named_buffers():
```
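The change works around Python's late-binding closure semantics: the original `hook_fn` closed over the loop variable `comm_spec_to_use`, so every hook registered during the loop read whatever value that variable held when the hook eventually fired (typically the last iteration's spec), not the value at registration time. Routing registration through `wrapper` rebinds the spec as a function argument, freezing one value per parameter. Below is a minimal, self-contained sketch of the failure mode and the fix; the `specs` list and the tuple-returning hooks are illustrative only, not part of the pass.

```python
# Sketch of the late-binding pitfall this commit fixes.
# Only the capture pattern mirrors the pass; all names are illustrative.

def make_hooks_buggy(specs):
    hooks = []
    for spec in specs:
        def hook_fn(grad):
            return (grad, spec)        # `spec` is looked up when hook_fn runs
        hooks.append(hook_fn)
    return hooks

def make_hooks_fixed(specs):
    hooks = []
    for spec in specs:
        def wrapper(spec):             # rebind `spec` as an argument...
            def hook_fn(grad):
                return (grad, spec)    # ...so each hook keeps its own value
            return hook_fn
        hooks.append(wrapper(spec))
    return hooks

buggy = make_hooks_buggy(["spec_a", "spec_b"])
fixed = make_hooks_fixed(["spec_a", "spec_b"])
assert [h(0)[1] for h in buggy] == ["spec_b", "spec_b"]   # all hooks see the last spec
assert [h(0)[1] for h in fixed] == ["spec_a", "spec_b"]   # each hook sees its own spec
```

A default argument (`def hook_fn(grad, comm_spec=comm_spec_to_use)`) would freeze the value just as well, since `register_hook` only ever passes `grad`; the commit opts for the explicit wrapper instead.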
tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py

```diff
@@ -171,22 +171,92 @@ def check_apply_bottleneck(rank, world_size, port):
     torch.cuda.set_rng_state(cuda_rng_state)
     origin_output.sum().backward()
     if rank == 0:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        print(f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum()}")
+        print(f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}")
+        print(f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}")
+        print(f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}")
+        print(f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 0, 1)).abs().sum()}")
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
         assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
     if rank == 1:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        print(f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum()}")
+        print(f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}")
+        print(f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}")
+        print(f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}")
+        print(f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 1, 1)).abs().sum()}")
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
         assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
     if rank == 2:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        print(f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum()}")
+        print(f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}")
+        print(f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}")
+        print(f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}")
+        print(f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 2, 1)).abs().sum()}")
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
         assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
     if rank == 3:
-        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
-        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        print(f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum()}")
+        print(f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}")
+        print(f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}")
+        print(f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}")
+        print(f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 3, 1)).abs().sum()}")
+        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
         assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
+        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
 
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
```
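Each rank holds only its shard of a weight along dim 0, so the test compares a rank's gradient against the matching slice of the unsharded reference model via `Tensor.narrow(dim, start, length)`; for example, rank 1 checks `bn3` against `narrow(0, 4, 4)` of the full 16-channel gradient. A small sketch of that slice arithmetic follows; the world size and channel count are illustrative, not taken from the test fixture.

```python
# Sketch of the per-rank slice arithmetic behind the test's narrow() calls.
# Shapes and world size are made up for illustration.
import torch

world_size = 4
full = torch.arange(16.0)                   # stands in for an unsharded 16-channel grad

for rank in range(world_size):
    shard_len = full.size(0) // world_size  # 4 channels per rank
    start = rank * shard_len                # rank 1 -> narrow(0, 4, 4)
    shard = full.narrow(0, start, shard_len)   # view of this rank's slice
    assert torch.equal(shard, full[start:start + shard_len])
```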