OpenDAS / ColossalAI · Commits

Commit 819e25d8 (Unverified)
Authored Feb 23, 2023 by YuliangLiu0306, committed by GitHub on Feb 23, 2023
[hotfix] fix autoparallel compatibility test issues (#2754)
parent 0f392d74

Changes: 3 changed files with 23 additions and 9 deletions (+23, -9)

  colossalai/auto_parallel/tensor_shard/initialize.py (+1, -0)
  tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py (+11, -4)
  tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py (+11, -5)
colossalai/auto_parallel/tensor_shard/initialize.py

@@ -330,6 +330,7 @@ def autoparallelize(model: nn.Module,
                          device_mesh,
                          solver_preference=solver_preference,
                          dataloader_option=dataloader_option,
+                         shard_option=shard_option,
                          save_solver_solution=save_solver_solution,
                          load_solver_solution=load_solver_solution,
                          solution_path=solver_solution_path,
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py

@@ -33,11 +33,15 @@ def check_compatibility_with_ddp(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).cuda()
-    input = torch.rand(4, 4).cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32, dtype=torch.float).reshape(4, 4).cuda()
+    input_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4).cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)

@@ -70,7 +74,10 @@ def check_compatibility_with_ddp(rank, world_size, port):
     gm = DDP(gm, process_group=dp_process_group)
     output = gm(input)
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank {rank} is correct')
     loss = output.sum()
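The new per-rank assertions slice the 8-row comparison output with Tensor.narrow(dim, start, length): ranks 0 and 1 are checked against rows 0..3, ranks 2 and 3 against rows 4..7. A minimal standalone sketch of that slicing, using the same shapes as the test (illustrative only, not part of the commit):

    import torch

    # Global comparison batch: 8 rows of 4 features, as in the updated tests.
    output_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4)

    # Slice expected on ranks 0 and 1 (rows 0..3) and on ranks 2 and 3 (rows 4..7).
    first_half = output_compare.narrow(0, 0, 4)
    second_half = output_compare.narrow(0, 4, 4)

    # The two slices tile the full comparison batch exactly.
    assert torch.equal(torch.cat([first_half, second_half], dim=0), output_compare)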
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py

@@ -37,12 +37,15 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = MLP(4).half().cuda()
-    input = torch.rand(4, 4).half().cuda()
-    output_compare = model(input)
+    if rank in [0, 1]:
+        input = torch.arange(0, 16).reshape(4, 4).half().cuda()
+    elif rank in [2, 3]:
+        input = torch.arange(16, 32).reshape(4, 4).half().cuda()
+    input_compare = torch.arange(0, 32).reshape(8, 4).half().cuda()
+    output_compare = model(input_compare)
     loss_compare = output_compare.sum()
     loss_compare.backward()
-    grad_compare = copy.deepcopy(model.linear_1.weight.grad)
+    grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)

@@ -79,7 +82,10 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
     optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
     optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
     output = gm(input)
-    assert_close(output, output_compare)
+    if rank in (0, 1):
+        assert_close(output, output_compare.narrow(0, 0, 4))
+    else:
+        assert_close(output, output_compare.narrow(0, 4, 4))
     print(f'output on rank {rank} is correct')
     loss = output.sum()
     optimizer.zero_grad()
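Both updated tests also halve the reference gradient (grad / 2). A plausible reading, stated here as an assumption rather than anything the commit documents, is that the comparison model runs on the full 8-row batch while each of the two data-parallel replicas sees only 4 rows, so averaging the replica gradients yields half of the full-batch gradient when the loss is a plain sum. A small standalone check of that arithmetic (illustrative only; shapes match the tests, the linear map is made up):

    import torch

    # Assumed setup for illustration: a sum loss over a plain linear map.
    w = torch.ones(4, 4, requires_grad=True)
    full_batch = torch.arange(0, 32, dtype=torch.float).reshape(8, 4)

    # Gradient from the full comparison batch.
    (full_batch @ w).sum().backward()
    full_grad = w.grad.clone()

    # Gradients from the two half-batches, one per pair of data-parallel ranks.
    w.grad = None
    (full_batch.narrow(0, 0, 4) @ w).sum().backward()
    grad_ranks_01 = w.grad.clone()

    w.grad = None
    (full_batch.narrow(0, 4, 4) @ w).sum().backward()
    grad_ranks_23 = w.grad.clone()

    # Averaging across the two replicas reproduces full_grad / 2.
    assert torch.allclose((grad_ranks_01 + grad_ranks_23) / 2, full_grad / 2)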