OpenDAS / ColossalAI · Commits

Commit 7fa6be49 (unverified)
Authored Feb 15, 2023 by YuliangLiu0306; committed by GitHub on Feb 15, 2023

[autoparallel] test compatibility for gemini and auto parallel (#2700)

Parent: d701ef81
Showing 3 changed files with 212 additions and 4 deletions (+212 / -4):

  colossalai/auto_parallel/passes/runtime_preparation_pass.py (+6 / -4)
  tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py (+98 / -0)
  tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py (+108 / -0)
colossalai/auto_parallel/passes/runtime_preparation_pass.py
...
@@ -377,8 +377,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
                 # TODO: build a ColoParamter class to manager the distributed parameters
                 # we could use .data here, because all the operations just happen before the real training
                 # loop, so we don't need to track these operations in the autograd graph.
-                param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-                    param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+                param = torch.nn.Parameter(
+                    shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                             target_sharding_spec).detach().clone())
                 setattr(target_module, name, param)
                 comm_actions = node.best_strategy.communication_actions
...
@@ -432,8 +433,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
                 # TODO: build a ColoParamter class to manager the distributed parameters
                 # we could use .data here, because all the operations just happen before the real training
                 # loop, so we don't need to track these operations in the autograd graph.
-                target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-                    target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+                target = torch.nn.Parameter(
+                    shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                             target_sharding_spec).detach().clone())
                 assert hasattr(target_module, atoms[-1])
                 setattr(target_module, atoms[-1], target)
...
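Note on the change above: the pass no longer writes the sharded tensor into the existing parameter's .data; it now builds a fresh torch.nn.Parameter around the sharded data and registers it on the module via setattr, presumably so that downstream wrappers (DDP, Gemini) pick up a plain leaf parameter that already carries the sharded shape. The following is a minimal sketch of that registration pattern in plain PyTorch, not ColossalAI code; shard_last_dim is a hypothetical stand-in for shape_consistency_manager.apply_for_autoparallel_runtime.

import torch


def shard_last_dim(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # hypothetical stand-in for shape_consistency_manager.apply_for_autoparallel_runtime
    return tensor.chunk(world_size, dim=-1)[rank]


linear = torch.nn.Linear(4, 4, bias=False)

# pattern adopted by this commit: wrap the sharded data in a new Parameter and
# re-register it on the module, so module.parameters() yields the sharded tensor
sharded = torch.nn.Parameter(shard_last_dim(linear.weight.data, rank=0, world_size=2).detach().clone())
setattr(linear, 'weight', sharded)

assert next(linear.parameters()).shape == (4, 2)    # parameters() now reflects the shard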
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py (new file, mode 100644)
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port


class MLP(torch.nn.Module):

    def __init__(self, in_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x


def check_compatibility_with_ddp(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = MLP(4).cuda()
    input = torch.rand(4, 4).cuda()
    output_compare = model(input)
    loss_compare = output_compare.sum()
    loss_compare.backward()
    grad_compare = copy.deepcopy(model.linear_1.weight.grad)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    meta_args = {'x': torch.rand(4, 4).to('meta')}
    gm, solution = initialize_model(model,
                                    meta_args=meta_args,
                                    device_mesh=device_mesh,
                                    return_solution=True,
                                    solver_preference='tp',
                                    shard_option='shard_last_axis')

    msg = '| TP strategy combination chosen by auto-parallel solver |'
    msg_length = len(msg)
    if rank == 0:
        print('=' * msg_length)
        print(msg)
        print('=' * msg_length)
        for strategy in solution:
            print(strategy)
        print('=' * msg_length)

    dp_process_group = None
    for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]:
        if rank in ranks:
            dp_process_group = process_group_handle
    assert dp_process_group is not None

    gm = DDP(gm, process_group=dp_process_group)
    output = gm(input)
    assert_close(output, output_compare)
    print(f'output on rank {rank} is correct')

    loss = output.sum()
    loss.backward()

    if rank in (0, 2):
        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 0, 8))
    if rank in (1, 3):
        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8))
    print(f'gradient on rank {rank} is correct')


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():
    world_size = 4
    run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_compatibility_with_ddp()
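For reference, a hedged sketch of the layout the gradient assertions above assume (my reading, not part of the test): MLP(4).linear_1 is Linear(4, 16), so its weight gradient has shape (16, 4); splitting it along dim 0 gives rows 0-7 to one tensor-parallel shard and rows 8-15 to the other, which is what grad_compare.narrow(0, 0, 8) on ranks {0, 2} and narrow(0, 8, 8) on ranks {1, 3} compare against.

import torch

# stand-in for the full (16, 4) gradient held by the unsharded reference model
grad_compare = torch.nn.Linear(4, 16, bias=False).weight.detach()
shard_a, shard_b = grad_compare.chunk(2, dim=0)               # two (8, 4) row shards
assert torch.equal(shard_a, grad_compare.narrow(0, 0, 8))     # what ranks 0 and 2 should hold
assert torch.equal(shard_b, grad_compare.narrow(0, 8, 8))     # what ranks 1 and 3 should hold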
tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py (new file, mode 100644)
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx


class MLP(torch.nn.Module):

    def __init__(self, in_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)
        return x


def check_auto_parallel_with_gemini(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = MLP(4).half().cuda()
    input = torch.rand(4, 4).half().cuda()
    output_compare = model(input)
    loss_compare = output_compare.sum()
    loss_compare.backward()
    grad_compare = copy.deepcopy(model.linear_1.weight.grad)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    meta_args = {'x': torch.rand(4, 4).half().to('meta')}
    gm, solution = initialize_model(model,
                                    meta_args=meta_args,
                                    device_mesh=device_mesh,
                                    return_solution=True,
                                    solver_preference='tp',
                                    shard_option='shard_last_axis')

    if rank == 0:
        msg = '| TP strategy combination chosen by auto-parallel solver |'
        msg_length = len(msg)
        print('=' * msg_length)
        print(msg)
        print('=' * msg_length)
        for strategy in solution:
            print(strategy)
        print('=' * msg_length)

    dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)
    gemini_config = dict(strict_ddp_mode=False,
                         device=get_current_device(),
                         placement_policy='cpu',
                         pin_memory=True,
                         search_range_mb=128)

    post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
    gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
    optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
    optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)

    output = gm(input)
    assert_close(output, output_compare)
    print(f'output on rank {rank} is correct')

    loss = output.sum()
    optimizer.zero_grad()
    optimizer.backward(loss)
    optimizer.step()

    if rank in (0, 2):
        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten())
    if rank in (1, 3):
        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten())
    print(f'gradient on rank {rank} is correct')


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():
    world_size = 4
    run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_auto_parallel_with_gemini()
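One detail of the Gemini test worth spelling out: HybridAdam is constructed with betas=(0, 0), so after a single step the optimizer's first moment equals the last gradient (exp_avg = beta1 * exp_avg + (1 - beta1) * grad reduces to exp_avg = grad when beta1 = 0), which is why the assertions compare exp_avg against the reference gradient shards. A minimal sketch with plain torch.optim.Adam, assuming it mirrors HybridAdam's exp_avg bookkeeping:

import torch

p = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.Adam([p], betas=(0.0, 0.0))
p.grad = torch.full((4,), 2.0)
opt.step()

# with beta1 = 0, the first moment is exactly the last gradient
assert torch.equal(opt.state[p]['exp_avg'], p.grad)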