Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cb3d1bef
Unverified
Commit
cb3d1bef
authored
Feb 08, 2023
by
YuliangLiu0306
Committed by
GitHub
Feb 08, 2023
Browse files
[autoparallel] adapt autoparallel tests with latest api (#2626)
parent
c3755636
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
62 additions
and
586 deletions
+62
-586
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
..._shard/node_handler/strategy/matmul_strategy_generator.py
+7
-7
colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
+0
-3
tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
..._parallel/test_tensor_shard/test_bias_addition_forward.py
+8
-89
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py
...allel/test_tensor_shard/test_gpt/test_gpt2_performance.py
+0
-131
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
...st_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
+19
-37
tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
...st_auto_parallel/test_tensor_shard/test_metainfo/utils.py
+1
-1
tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py
...o_parallel/test_tensor_shard/test_resnet_block_runtime.py
+0
-270
tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
...parallel/test_tensor_shard/test_shape_consistency_pass.py
+27
-48
No files found.
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
View file @
cb3d1bef
...
@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
strategies
.
append
(
self
.
split_rhs_space_both_contract
(
1
,
0
))
strategies
.
append
(
self
.
split_rhs_space_both_contract
(
1
,
0
))
# RR= RS x SR
# RR= RS x SR
#
strategies.append(self.recompute_split_both_contract(0))
strategies
.
append
(
self
.
recompute_split_both_contract
(
0
))
#
strategies.append(self.recompute_split_both_contract(1))
strategies
.
append
(
self
.
recompute_split_both_contract
(
1
))
#
#
RS = RR x RS
# RS = RR x RS
#
strategies.append(self.split_rhs_space_only(0))
strategies
.
append
(
self
.
split_rhs_space_only
(
0
))
#
strategies.append(self.split_rhs_space_only(1))
strategies
.
append
(
self
.
split_rhs_space_only
(
1
))
# S01R = S01R x RR
# S01R = S01R x RR
strategies
.
append
(
self
.
split_lhs_1st_dim_1d
(
0
,
1
))
strategies
.
append
(
self
.
split_lhs_1st_dim_1d
(
0
,
1
))
...
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# RS01 = RR x RS01
# RS01 = RR x RS01
strategies
.
append
(
self
.
split_rhs_2nd_dim_1d
(
0
,
1
))
strategies
.
append
(
self
.
split_rhs_2nd_dim_1d
(
0
,
1
))
#
#
RR = RR x RR
# RR = RR x RR
#
strategies.append(self.non_split())
strategies
.
append
(
self
.
non_split
())
return
strategies
return
strategies
...
...
colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
View file @
cb3d1bef
...
@@ -62,9 +62,6 @@ class CostGraph:
...
@@ -62,9 +62,6 @@ class CostGraph:
else
:
else
:
edge_cost
[(
j
,
i
)]
=
resharding_cost_item
.
total
edge_cost
[(
j
,
i
)]
=
resharding_cost_item
.
total
self
.
edge_costs
[
node_pair
]
=
edge_cost
self
.
edge_costs
[
node_pair
]
=
edge_cost
# add parents and children attribute to node
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
# children_nodes = [node for node in strategies_vector.successor_nodes]
parent_nodes
=
[]
parent_nodes
=
[]
children_nodes
=
[]
children_nodes
=
[]
...
...
tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
View file @
cb3d1bef
...
@@ -4,21 +4,11 @@ import pytest
...
@@ -4,21 +4,11 @@ import pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply_pass
from
colossalai.auto_parallel.tensor_shard.initialize
import
initialize_model
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
OperationDataType
from
colossalai.auto_parallel.tensor_shard.solver
import
(
CostGraph
,
GraphAnalyser
,
Solver
,
SolverOptions
,
StrategiesConstructor
,
)
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx
import
ColoGraphModule
,
ColoTracer
from
colossalai.initialize
import
launch
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.testing
import
assert_close
,
assert_close_loose
,
rerun_if_address_is_in_use
from
colossalai.testing
import
assert_close
,
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
...
@@ -63,42 +53,9 @@ def check_linear_module(rank, world_size, port):
...
@@ -63,42 +53,9 @@ def check_linear_module(rank, world_size, port):
# [[0, 1]
# [[0, 1]
# [2, 3]]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
tracer
=
ColoTracer
()
meta_args
=
{
'x'
:
torch
.
rand
(
4
,
4
).
to
(
'meta'
)}
# graph():
gm
=
initialize_model
(
model
,
meta_args
=
meta_args
,
device_mesh
=
device_mesh
)
# %x : torch.Tensor [#users=1] = placeholder[target=x]
output
=
gm
(
input
)
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
{
'x'
:
torch
.
rand
(
4
,
4
).
to
(
'meta'
)})
# def forward(self, x : torch.Tensor):
# linear_weight = self.linear.weight
# linear_bias = self.linear.bias
# linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
# add = linear + linear_bias; linear = linear_bias = None
# mul = add * 2; add = None
# return mul
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
node_list
=
list
(
graph
.
nodes
)
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
linear_node
=
node_list
[
3
]
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
graph_analyser
=
GraphAnalyser
(
gm
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
)
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
solution
,
device_mesh
)
gm
=
runtime_apply_pass
(
gm
)
gm
.
recompile
()
output
=
gm
(
input
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
)
assert_close
(
output
,
output_compare
)
assert_close
(
output
,
output_compare
)
...
@@ -113,47 +70,9 @@ def check_conv_module(rank, world_size, port):
...
@@ -113,47 +70,9 @@ def check_conv_module(rank, world_size, port):
# [[0, 1]
# [[0, 1]
# [2, 3]]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
tracer
=
ColoTracer
()
meta_args
=
{
'x'
:
torch
.
rand
(
4
,
3
,
64
,
64
).
to
(
'meta'
)}
# graph():
gm
=
initialize_model
(
model
,
meta_args
=
meta_args
,
device_mesh
=
device_mesh
)
# %x : torch.Tensor [#users=1] = placeholder[target=x]
output
=
gm
(
input
)
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
# return mul
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
{
'x'
:
torch
.
rand
(
4
,
3
,
64
,
64
).
to
(
'meta'
)})
# def forward(self, x : torch.Tensor):
# conv_weight = self.conv.weight
# conv_bias = self.conv.bias
# conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
# view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
# add = conv2d + view; conv2d = view = None
# mul = add * 2; add = None
# return mul
gm
=
ColoGraphModule
(
model
,
graph
)
gm
.
recompile
()
node_list
=
list
(
graph
.
nodes
)
conv_node
=
node_list
[
3
]
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
graph_analyser
=
GraphAnalyser
(
gm
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
)
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
solution
,
device_mesh
)
gm
=
runtime_apply_pass
(
gm
)
gm
.
recompile
()
output
=
gm
(
input
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
)
assert_close
(
output
,
output_compare
)
assert_close
(
output
,
output_compare
)
...
...
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py
deleted
100644 → 0
View file @
c3755636
import
copy
import
random
from
functools
import
partial
from
time
import
time
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
psutil
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
transformers
from
torch.fx
import
GraphModule
from
torch.profiler
import
ProfilerActivity
,
profile
,
record_function
,
schedule
,
tensorboard_trace_handler
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply_pass
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.tensor_shard.constants
import
BATCHNORM_MODULE_OP
from
colossalai.auto_parallel.tensor_shard.initialize
import
autoparallelize
,
initialize_model
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
ShardingSpec
from
colossalai.auto_parallel.tensor_shard.solver
import
(
CostGraph
,
GraphAnalyser
,
Solver
,
SolverOptions
,
StrategiesConstructor
,
)
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.initialize
import
launch
,
launch_from_torch
from
colossalai.logging
import
disable_existing_loggers
,
get_dist_logger
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
,
to_global
from
colossalai.testing
import
assert_close
,
assert_close_loose
,
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
from
tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules
import
GPT2LMHeadModel
,
GPTLMLoss
BATCH_SIZE
=
32
SEQ_LENGTH
=
256
HIDDEN_DIM
=
16384
NUM_HEADS
=
128
NUM_LAYERS
=
4
VOCAB_SIZE
=
50257
NUM_STEPS
=
10
FP16
=
True
def
get_cpu_mem
():
return
psutil
.
Process
().
memory_info
().
rss
/
1024
**
2
def
get_gpu_mem
():
return
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
def
get_mem_info
(
prefix
=
''
):
return
f
'
{
prefix
}
GPU memory usage:
{
get_gpu_mem
():.
2
f
}
MB, CPU memory usage:
{
get_cpu_mem
():.
2
f
}
MB'
def
get_tflops
(
model_numel
,
batch_size
,
seq_len
,
step_time
):
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
return
model_numel
*
batch_size
*
seq_len
*
8
/
1e12
/
(
step_time
+
1e-12
)
/
4
# Randomly Generated Data
def
get_data
(
batch_size
,
seq_len
,
vocab_size
):
input_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
batch_size
,
seq_len
),
device
=
torch
.
cuda
.
current_device
())
attention_mask
=
torch
.
ones_like
(
input_ids
)
return
input_ids
,
attention_mask
def
main
():
disable_existing_loggers
()
launch_from_torch
(
config
=
{})
logger
=
get_dist_logger
()
config
=
transformers
.
GPT2Config
(
n_position
=
SEQ_LENGTH
,
n_layer
=
NUM_LAYERS
,
n_head
=
NUM_HEADS
,
n_embd
=
HIDDEN_DIM
)
if
FP16
:
model
=
GPT2LMHeadModel
(
config
=
config
).
half
().
to
(
'cuda'
)
else
:
model
=
GPT2LMHeadModel
(
config
=
config
).
to
(
'cuda'
)
global_numel
=
sum
([
p
.
numel
()
for
p
in
model
.
parameters
()])
meta_input_sample
=
{
'input_ids'
:
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGTH
),
dtype
=
torch
.
int64
).
to
(
'meta'
),
'attention_mask'
:
torch
.
zeros
((
BATCH_SIZE
,
SEQ_LENGTH
),
dtype
=
torch
.
int64
).
to
(
'meta'
),
}
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
# [[0, 1]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
gm
=
initialize_model
(
model
,
meta_input_sample
,
device_mesh
)
# build criterion
criterion
=
GPTLMLoss
()
optimizer
=
torch
.
optim
.
Adam
(
gm
.
parameters
(),
lr
=
0.01
)
logger
.
info
(
get_mem_info
(
prefix
=
'After init model, '
),
ranks
=
[
0
])
get_tflops_func
=
partial
(
get_tflops
,
global_numel
,
BATCH_SIZE
,
SEQ_LENGTH
)
torch
.
cuda
.
synchronize
()
model
.
train
()
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
# schedule=schedule(wait=1, warmup=2, active=2),
# on_trace_ready=tensorboard_trace_handler(f'log/dummy_data/bs128_seq128_new'),
# record_shapes=True,
# profile_memory=True) as prof:
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as prof:
for
n
in
range
(
10
):
# we just use randomly generated data here
input_ids
,
attn_mask
=
get_data
(
BATCH_SIZE
,
SEQ_LENGTH
,
VOCAB_SIZE
)
optimizer
.
zero_grad
()
start
=
time
()
outputs
=
gm
(
input_ids
,
attn_mask
)
loss
=
criterion
(
outputs
,
input_ids
)
loss
.
backward
()
optimizer
.
step
()
# prof.step()
torch
.
cuda
.
synchronize
()
step_time
=
time
()
-
start
logger
.
info
(
f
'[
{
n
+
1
}
/
{
NUM_STEPS
}
] Loss:
{
loss
.
item
():.
3
f
}
, Step time:
{
step_time
:.
3
f
}
s, TFLOPS:
{
get_tflops_func
(
step_time
):.
3
f
}
'
,
ranks
=
[
0
])
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
torch
.
cuda
.
synchronize
()
if
__name__
==
'__main__'
:
main
()
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
View file @
cb3d1bef
import
copy
import
copy
import
random
import
random
from
functools
import
partial
from
functools
import
partial
from
typing
import
Dict
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
transformers
import
transformers
from
torch.fx
import
GraphModule
from
torch.fx
import
GraphModule
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply_pass
from
colossalai.auto_parallel.tensor_shard.initialize
import
(
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
ModuleWrapper
,
from
colossalai.auto_parallel.tensor_shard.constants
import
BATCHNORM_MODULE_OP
build_strategy_constructor
,
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
ShardingSpec
solve_solution
,
from
colossalai.auto_parallel.tensor_shard.solver
import
(
transform_to_sharded_model
,
CostGraph
,
GraphAnalyser
,
Solver
,
SolverOptions
,
StrategiesConstructor
,
)
)
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
ShardingSpec
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.initialize
import
launch
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
,
to_global
from
colossalai.tensor.shape_consistency
import
to_global
from
colossalai.testing
import
assert_close
,
assert_close_loose
,
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing
import
assert_close
,
assert_close_loose
,
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
...
@@ -49,6 +44,7 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor
...
@@ -49,6 +44,7 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor
best_sharding_spec_dict
:
Dict
[
str
,
ShardingSpec
]):
best_sharding_spec_dict
:
Dict
[
str
,
ShardingSpec
]):
for
name
,
param
in
module
.
named_parameters
():
for
name
,
param
in
module
.
named_parameters
():
param_grad
=
param
.
grad
param_grad
=
param
.
grad
name
=
name
.
replace
(
'module.'
,
''
)
origin_param_grad
=
origin_param_dict
[
name
].
grad
origin_param_grad
=
origin_param_dict
[
name
].
grad
atoms
=
name
.
split
(
'.'
)
atoms
=
name
.
split
(
'.'
)
new_name
=
'_'
.
join
(
atoms
)
new_name
=
'_'
.
join
(
atoms
)
...
@@ -115,30 +111,17 @@ def check_attention_layer(rank, model_cls, world_size, port):
...
@@ -115,30 +111,17 @@ def check_attention_layer(rank, model_cls, world_size, port):
# [[0, 1]
# [[0, 1]
# [2, 3]]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
shape_consistency_manager
=
ShapeConsistencyManager
()
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_input_sample
)
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_input_sample
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
.
recompile
()
graph_analyser
=
GraphAnalyser
(
gm
)
strategies_constructor
=
build_strategy_constructor
(
graph
,
device_mesh
)
liveness_list
=
graph_analyser
.
liveness_analysis
()
solution
=
solve_solution
(
gm
,
strategies_constructor
,
memory_budget
=-
1
)
solver_options
=
SolverOptions
()
gm
,
sharding_spec_dicts
=
transform_to_sharded_model
(
gm
,
solution
,
device_mesh
,
strategies_constructor
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
gm
=
ModuleWrapper
(
gm
,
*
sharding_spec_dicts
)
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
,
memory_budget
=-
1
)
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
solution
,
device_mesh
,
strategies_constructor
)
gm
=
runtime_apply_pass
(
gm
)
gm
.
recompile
()
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
best_sharding_spec_dict
=
{}
best_sharding_spec_dict
=
{}
for
index
,
node
in
enumerate
(
nodes
):
for
index
,
node
in
enumerate
(
nodes
):
...
@@ -149,7 +132,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
...
@@ -149,7 +132,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
origin_output
=
test_model
(
*
test_input_sample
)
origin_output
=
test_model
(
*
test_input_sample
)
torch
.
cuda
.
set_rng_state
(
cuda_rng_state
)
torch
.
cuda
.
set_rng_state
(
cuda_rng_state
)
torch
.
set_rng_state
(
cpu_rng_state
)
torch
.
set_rng_state
(
cpu_rng_state
)
output
=
gm
(
*
input_sample
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
)
output
=
gm
(
*
input_sample
)
assert_close
(
output
,
origin_output
,
rtol
=
1e-03
,
atol
=
1e-03
)
assert_close
(
output
,
origin_output
,
rtol
=
1e-03
,
atol
=
1e-03
)
#*******************backward starting*******************
#*******************backward starting*******************
...
@@ -174,16 +157,15 @@ def check_attention_layer(rank, model_cls, world_size, port):
...
@@ -174,16 +157,15 @@ def check_attention_layer(rank, model_cls, world_size, port):
#*******************strategy selected*******************
#*******************strategy selected*******************
if
rank
==
0
:
if
rank
==
0
:
print
(
"*******************strategy selected*******************"
)
print
(
"*******************strategy selected*******************"
)
strategies_list
=
solver
.
last_s_val
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
computation_cost
=
0
computation_cost
=
0
communication_cost
=
0
communication_cost
=
0
memory_cost
=
0
memory_cost
=
0
for
index
,
node
in
enumerate
(
nodes
):
for
index
,
node
in
enumerate
(
nodes
):
print
(
node
.
name
,
node
.
strategies_vector
[
s
trategies_list
[
index
]].
name
)
print
(
node
.
name
,
node
.
strategies_vector
[
s
olution
[
index
]].
name
)
computation_cost
+=
node
.
strategies_vector
[
s
trategies_list
[
index
]].
compute_cost
.
total
computation_cost
+=
node
.
strategies_vector
[
s
olution
[
index
]].
compute_cost
.
total
communication_cost
+=
node
.
strategies_vector
[
s
trategies_list
[
index
]].
communication_cost
.
total
communication_cost
+=
node
.
strategies_vector
[
s
olution
[
index
]].
communication_cost
.
total
node_memory_cost
=
node
.
strategies_vector
[
s
trategies_list
[
index
]].
memory_cost
.
total
node_memory_cost
=
node
.
strategies_vector
[
s
olution
[
index
]].
memory_cost
.
total
if
isinstance
(
node_memory_cost
,
tuple
):
if
isinstance
(
node_memory_cost
,
tuple
):
node_memory_cost
=
node_memory_cost
[
0
]
node_memory_cost
=
node_memory_cost
[
0
]
memory_cost
+=
node_memory_cost
.
activation
+
node_memory_cost
.
parameter
memory_cost
+=
node_memory_cost
.
activation
+
node_memory_cost
.
parameter
...
...
tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
View file @
cb3d1bef
...
@@ -57,7 +57,7 @@ def mem_test_for_node_strategy(rank: int,
...
@@ -57,7 +57,7 @@ def mem_test_for_node_strategy(rank: int,
output_key
]
output_key
]
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
solution
,
device_mesh
)
gm
,
solution
,
device_mesh
,
strategies_constructor
)
gm
=
runtime_apply_pass
(
gm
)
gm
=
runtime_apply_pass
(
gm
)
gm
.
recompile
()
gm
.
recompile
()
gm
:
GraphModule
gm
:
GraphModule
...
...
tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py
deleted
100644 → 0
View file @
c3755636
import
copy
from
copy
import
deepcopy
from
functools
import
partial
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
from
torch.fx
import
GraphModule
from
torchvision.models
import
resnet34
,
resnet50
from
colossalai
import
device
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply_pass
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.tensor_shard.constants
import
*
from
colossalai.auto_parallel.tensor_shard.solver.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.solver.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.solver.solver
import
Solver
from
colossalai.auto_parallel.tensor_shard.solver.strategies_constructor
import
StrategiesConstructor
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.testing
import
assert_close
,
assert_close_loose
,
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
seed
=
128
cudnn_benchmark
=
False
cudnn_deterministic
=
True
def
conv3x3
(
in_planes
:
int
,
out_planes
:
int
,
stride
:
int
=
1
,
groups
:
int
=
1
,
dilation
:
int
=
1
)
->
nn
.
Conv2d
:
"""3x3 convolution with padding"""
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
dilation
,
groups
=
groups
,
bias
=
False
,
dilation
=
dilation
,
)
def
conv1x1
(
in_planes
:
int
,
out_planes
:
int
,
stride
:
int
=
1
)
->
nn
.
Conv2d
:
"""1x1 convolution"""
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
)
class
Bottleneck
(
nn
.
Module
):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion
:
int
=
4
def
__init__
(
self
,
inplanes
:
int
,
planes
:
int
,
stride
:
int
=
1
,
downsample
=
None
,
groups
:
int
=
1
,
base_width
:
int
=
64
,
dilation
:
int
=
1
,
norm_layer
=
None
,
)
->
None
:
super
().
__init__
()
if
norm_layer
is
None
:
norm_layer
=
nn
.
BatchNorm2d
width
=
int
(
planes
*
(
base_width
/
64.0
))
*
groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self
.
conv1
=
conv1x1
(
inplanes
,
width
)
self
.
bn1
=
norm_layer
(
width
)
self
.
conv2
=
conv3x3
(
width
,
width
,
stride
,
groups
,
dilation
)
self
.
bn2
=
norm_layer
(
width
)
self
.
conv3
=
conv1x1
(
width
,
planes
*
self
.
expansion
)
self
.
bn3
=
norm_layer
(
planes
*
self
.
expansion
)
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
downsample
=
downsample
self
.
stride
=
stride
def
forward
(
self
,
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
=
self
.
relu
(
out
)
return
out
def
check_apply_bottleneck
(
rank
,
world_size
,
port
):
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
input
=
torch
.
rand
(
4
,
4
,
4
,
4
).
cuda
()
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
# [[0, 1]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
tracer
=
ColoTracer
()
model
=
Bottleneck
(
4
,
4
,
1
,
norm_layer
=
torch
.
nn
.
modules
.
batchnorm
.
BatchNorm2d
).
cuda
()
test_model
=
copy
.
deepcopy
(
model
)
test_input
=
copy
.
deepcopy
(
input
)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
# %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
# %conv2 : [#users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
# %bn2 : [#users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
# %relu_1 : [#users=1] = call_module[target=relu](args = (%bn2,), kwargs = {})
# %conv3 : [#users=1] = call_module[target=conv3](args = (%relu_1,), kwargs = {})
# %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
# %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
# return relu_2
input_sample
=
{
'x'
:
torch
.
rand
(
4
,
4
,
4
,
4
).
to
(
'meta'
)}
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
input_sample
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
graph_analyser
=
GraphAnalyser
(
gm
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
)
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
print
(
solution
)
for
index
,
node
in
enumerate
(
graph
.
nodes
):
print
(
node
.
name
,
node
.
strategies_vector
[
solution
[
index
]].
name
)
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
solution
,
device_mesh
)
gm
=
runtime_apply_pass
(
gm
)
gm
.
recompile
()
nodes
=
[
node
for
node
in
gm
.
graph
.
nodes
]
# TODO: wrap the gm to avoid the influence of the user training code
cuda_rng_state
=
torch
.
cuda
.
get_rng_state
()
origin_output
=
test_model
(
test_input
)
torch
.
cuda
.
set_rng_state
(
cuda_rng_state
)
output
=
gm
(
input
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
)
assert
output
.
shape
==
origin_output
.
shape
assert_close
(
output
,
origin_output
,
rtol
=
1e-03
,
atol
=
1e-05
)
print
(
"*******************backward starting*******************"
)
cuda_rng_state
=
torch
.
cuda
.
get_rng_state
()
output
.
sum
().
backward
()
torch
.
cuda
.
set_rng_state
(
cuda_rng_state
)
origin_output
.
sum
().
backward
()
if
rank
==
0
:
print
(
f
"bn3 diff sum in rank
{
rank
}
:
{
(
gm
.
bn3
.
weight
.
grad
-
test_model
.
bn3
.
weight
.
grad
.
narrow
(
0
,
0
,
4
)).
abs
().
sum
()
}
"
)
print
(
f
"conv3 diff sum in rank
{
rank
}
:
{
(
gm
.
conv3
.
weight
.
grad
-
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
0
,
8
)).
abs
().
sum
()
}
"
)
print
(
f
"bn2 diff sum in rank
{
rank
}
:
{
(
gm
.
bn2
.
weight
.
grad
-
test_model
.
bn2
.
weight
.
grad
.
narrow
(
0
,
0
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"conv2 diff sum in rank
{
rank
}
:
{
(
gm
.
conv2
.
weight
.
grad
-
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
0
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"bn1 diff sum in rank
{
rank
}
:
{
(
gm
.
bn1
.
weight
.
grad
-
test_model
.
bn1
.
weight
.
grad
.
narrow
(
0
,
0
,
1
)).
abs
().
sum
()
}
"
)
print
(
f
"conv1 diff sum in rank
{
rank
}
:
{
(
gm
.
conv1
.
weight
.
grad
-
test_model
.
conv1
.
weight
.
grad
).
sum
()
}
"
)
assert_close_loose
(
gm
.
conv3
.
weight
.
grad
.
sum
(),
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
0
,
8
).
sum
())
assert_close_loose
(
gm
.
conv2
.
weight
.
grad
.
sum
(),
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
0
,
2
).
sum
())
assert_close_loose
(
gm
.
conv1
.
weight
.
grad
.
sum
(),
test_model
.
conv1
.
weight
.
grad
.
sum
())
if
rank
==
1
:
print
(
f
"bn3 diff sum in rank
{
rank
}
:
{
(
gm
.
bn3
.
weight
.
grad
-
test_model
.
bn3
.
weight
.
grad
.
narrow
(
0
,
4
,
4
)).
abs
().
sum
()
}
"
)
print
(
f
"conv3 diff sum in rank
{
rank
}
:
{
(
gm
.
conv3
.
weight
.
grad
-
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
0
,
8
)).
abs
().
sum
()
}
"
)
print
(
f
"bn2 diff sum in rank
{
rank
}
:
{
(
gm
.
bn2
.
weight
.
grad
-
test_model
.
bn2
.
weight
.
grad
.
narrow
(
0
,
2
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"conv2 diff sum in rank
{
rank
}
:
{
(
gm
.
conv2
.
weight
.
grad
-
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
2
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"bn1 diff sum in rank
{
rank
}
:
{
(
gm
.
bn1
.
weight
.
grad
-
test_model
.
bn1
.
weight
.
grad
.
narrow
(
0
,
1
,
1
)).
abs
().
sum
()
}
"
)
print
(
f
"conv1 diff sum in rank
{
rank
}
:
{
(
gm
.
conv1
.
weight
.
grad
-
test_model
.
conv1
.
weight
.
grad
).
sum
()
}
"
)
assert_close_loose
(
gm
.
conv3
.
weight
.
grad
.
sum
(),
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
0
,
8
).
sum
())
assert_close_loose
(
gm
.
conv2
.
weight
.
grad
.
sum
(),
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
2
,
2
).
sum
())
assert_close_loose
(
gm
.
conv1
.
weight
.
grad
.
sum
(),
test_model
.
conv1
.
weight
.
grad
.
sum
())
if
rank
==
2
:
print
(
f
"bn3 diff sum in rank
{
rank
}
:
{
(
gm
.
bn3
.
weight
.
grad
-
test_model
.
bn3
.
weight
.
grad
.
narrow
(
0
,
8
,
4
)).
abs
().
sum
()
}
"
)
print
(
f
"conv3 diff sum in rank
{
rank
}
:
{
(
gm
.
conv3
.
weight
.
grad
-
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
8
,
8
)).
abs
().
sum
()
}
"
)
print
(
f
"bn2 diff sum in rank
{
rank
}
:
{
(
gm
.
bn2
.
weight
.
grad
-
test_model
.
bn2
.
weight
.
grad
.
narrow
(
0
,
0
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"conv2 diff sum in rank
{
rank
}
:
{
(
gm
.
conv2
.
weight
.
grad
-
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
0
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"bn1 diff sum in rank
{
rank
}
:
{
(
gm
.
bn1
.
weight
.
grad
-
test_model
.
bn1
.
weight
.
grad
.
narrow
(
0
,
2
,
1
)).
abs
().
sum
()
}
"
)
print
(
f
"conv1 diff sum in rank
{
rank
}
:
{
(
gm
.
conv1
.
weight
.
grad
-
test_model
.
conv1
.
weight
.
grad
).
sum
()
}
"
)
assert_close_loose
(
gm
.
conv3
.
weight
.
grad
.
sum
(),
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
8
,
8
).
sum
())
assert_close_loose
(
gm
.
conv2
.
weight
.
grad
.
sum
(),
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
0
,
2
).
sum
())
assert_close_loose
(
gm
.
conv1
.
weight
.
grad
.
sum
(),
test_model
.
conv1
.
weight
.
grad
.
sum
())
if
rank
==
3
:
print
(
f
"bn3 diff sum in rank
{
rank
}
:
{
(
gm
.
bn3
.
weight
.
grad
-
test_model
.
bn3
.
weight
.
grad
.
narrow
(
0
,
12
,
4
)).
abs
().
sum
()
}
"
)
print
(
f
"conv3 diff sum in rank
{
rank
}
:
{
(
gm
.
conv3
.
weight
.
grad
-
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
8
,
8
)).
abs
().
sum
()
}
"
)
print
(
f
"bn2 diff sum in rank
{
rank
}
:
{
(
gm
.
bn2
.
weight
.
grad
-
test_model
.
bn2
.
weight
.
grad
.
narrow
(
0
,
2
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"conv2 diff sum in rank
{
rank
}
:
{
(
gm
.
conv2
.
weight
.
grad
-
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
2
,
2
)).
abs
().
sum
()
}
"
)
print
(
f
"bn1 diff sum in rank
{
rank
}
:
{
(
gm
.
bn1
.
weight
.
grad
-
test_model
.
bn1
.
weight
.
grad
.
narrow
(
0
,
3
,
1
)).
abs
().
sum
()
}
"
)
print
(
f
"conv1 diff sum in rank
{
rank
}
:
{
(
gm
.
conv1
.
weight
.
grad
-
test_model
.
conv1
.
weight
.
grad
).
sum
()
}
"
)
assert_close_loose
(
gm
.
conv3
.
weight
.
grad
.
sum
(),
test_model
.
conv3
.
weight
.
grad
.
narrow
(
0
,
8
,
8
).
sum
())
assert_close_loose
(
gm
.
conv2
.
weight
.
grad
.
sum
(),
test_model
.
conv2
.
weight
.
grad
.
narrow
(
0
,
2
,
2
).
sum
())
assert_close_loose
(
gm
.
conv1
.
weight
.
grad
.
sum
(),
test_model
.
conv1
.
weight
.
grad
.
sum
())
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    """Run the bottleneck-block auto-parallel check on a 4-rank process group.

    Spawns one worker process per simulated device; each worker executes
    ``check_apply_bottleneck`` with a shared rendezvous port obtained from
    ``free_port()``. Gated behind the AUTO_PARALLEL environment flag and
    retried if the chosen address is already in use.
    """
    n_procs = 4
    # Bind everything except the rank (supplied by mp.spawn) up front.
    worker = partial(check_apply_bottleneck, world_size=n_procs, port=free_port())
    mp.spawn(worker, nprocs=n_procs)
# Allow running this distributed test directly as a script (outside pytest).
if __name__ == '__main__':
    test_apply()
tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
View file @
cb3d1bef
...
@@ -5,19 +5,9 @@ import pytest
...
@@ -5,19 +5,9 @@ import pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.fx
import
GraphModule
from
colossalai.auto_parallel.tensor_shard.initialize
import
initialize_model
from
colossalai.auto_parallel.passes.runtime_apply_pass
import
runtime_apply_pass
from
colossalai.auto_parallel.passes.runtime_preparation_pass
import
runtime_preparation_pass
from
colossalai.auto_parallel.tensor_shard.solver
import
(
CostGraph
,
GraphAnalyser
,
Solver
,
SolverOptions
,
StrategiesConstructor
,
)
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.initialize
import
launch
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.testing
import
assert_close
,
rerun_if_address_is_in_use
from
colossalai.testing
import
assert_close
,
rerun_if_address_is_in_use
...
@@ -41,41 +31,22 @@ def check_apply(rank, world_size, port):
...
@@ -41,41 +31,22 @@ def check_apply(rank, world_size, port):
disable_existing_loggers
()
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
input
=
torch
.
rand
(
4
,
4
,
4
,
4
).
cuda
()
input
=
torch
.
rand
(
4
,
4
,
4
,
4
).
cuda
()
test_input
=
copy
.
deepcopy
(
input
)
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
model
=
ConvModel
(
4
,
4
).
cuda
()
test_model
=
copy
.
deepcopy
(
model
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
mesh_shape
=
(
2
,
2
)
# [[0, 1]
# [[0, 1]
# [2, 3]]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
meta_args
=
{
'x'
:
torch
.
rand
(
4
,
4
,
4
,
4
).
to
(
'meta'
)}
gm
=
initialize_model
(
model
,
meta_args
,
device_mesh
)
tracer
=
ColoTracer
()
output
=
gm
(
input
)
model
=
ConvModel
(
4
,
4
).
cuda
()
test_model
=
copy
.
deepcopy
(
model
)
test_input
=
copy
.
deepcopy
(
input
)
input_sample
=
{
'x'
:
torch
.
rand
(
4
,
4
,
4
,
4
).
to
(
'meta'
)}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
input_sample
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
graph_analyser
=
GraphAnalyser
(
gm
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
)
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
solution
,
device_mesh
)
gm
=
runtime_apply_pass
(
gm
)
gm
.
recompile
()
nodes
=
[
node
for
node
in
gm
.
graph
.
nodes
]
# TODO: wrap the gm to avoid the influence of the user training code
output
=
gm
(
input
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
)
origin_output
=
test_model
(
test_input
)
origin_output
=
test_model
(
test_input
)
assert
output
.
equal
(
origin_output
)
assert
output
.
equal
(
origin_output
)
origin_loss
=
origin_output
.
sum
()
origin_loss
=
origin_output
.
sum
()
...
@@ -84,13 +55,21 @@ def check_apply(rank, world_size, port):
...
@@ -84,13 +55,21 @@ def check_apply(rank, world_size, port):
origin_loss
.
backward
()
origin_loss
.
backward
()
loss
.
backward
()
loss
.
backward
()
grad_0
=
test_model
.
conv
.
weight
.
grad
.
narrow
(
0
,
0
,
2
)
grad_0
=
test_model
.
conv
.
weight
.
grad
.
narrow
(
0
,
0
,
1
)
grad_1
=
test_model
.
conv
.
weight
.
grad
.
narrow
(
0
,
2
,
2
)
grad_1
=
test_model
.
conv
.
weight
.
grad
.
narrow
(
0
,
1
,
1
)
grad_2
=
test_model
.
conv
.
weight
.
grad
.
narrow
(
0
,
2
,
1
)
if
rank
in
(
0
,
1
):
grad_3
=
test_model
.
conv
.
weight
.
grad
.
narrow
(
0
,
3
,
1
)
assert_close
(
gm
.
conv
.
weight
.
grad
.
data
,
grad_0
.
data
)
elif
rank
in
(
2
,
3
):
if
rank
==
0
:
assert_close
(
gm
.
conv
.
weight
.
grad
.
data
,
grad_1
.
data
)
assert_close
(
gm
.
module
.
conv
.
weight
.
grad
.
data
,
grad_0
.
data
)
elif
rank
==
1
:
assert_close
(
gm
.
module
.
conv
.
weight
.
grad
.
data
,
grad_1
.
data
)
elif
rank
==
2
:
assert_close
(
gm
.
module
.
conv
.
weight
.
grad
.
data
,
grad_2
.
data
)
elif
rank
==
3
:
assert_close
(
gm
.
module
.
conv
.
weight
.
grad
.
data
,
grad_3
.
data
)
else
:
raise
ValueError
(
f
'rank
{
rank
}
does not exist.'
)
# skip this test due to pulp not installed in CI environment
# skip this test due to pulp not installed in CI environment
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment