OpenDAS / ColossalAI · Commit e532679c
Authored Jan 10, 2023 by oahzxl

Merge branch 'main' of https://github.com/oahzxl/ColossalAI into chunk

Parents: c1492e50, 7d5640b9
Changes: 441 files. Showing 20 changed files with 921 additions and 137 deletions (+921 −137).
colossalai/device/__init__.py                             +4    -0
colossalai/device/alpha_beta_profiler.py                  +386  -0
colossalai/device/calc_pipeline_strategy.py               +127  -0
colossalai/device/device_mesh.py                          +56   -4
colossalai/fx/__init__.py                                 +4    -4
colossalai/fx/_meta_registrations.py                      +104  -3
colossalai/fx/codegen/activation_checkpoint_codegen.py    +59   -48
colossalai/fx/passes/__init__.py                          +2    -2
colossalai/fx/passes/adding_split_node_pass.py            +42   -3
colossalai/fx/passes/algorithms/ckpt_solver_chen.py       +3    -1
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py      +11   -9
colossalai/fx/passes/concrete_info_prop.py                +10   -9
colossalai/fx/passes/meta_info_prop.py                    +53   -9
colossalai/fx/passes/split_module.py                      +36   -20
colossalai/fx/passes/utils.py                             +2    -9
colossalai/fx/profiler/__init__.py                        +8    -2
colossalai/fx/profiler/dataflow.py                        +5    -5
colossalai/fx/profiler/experimental/__init__.py           +1    -1
colossalai/fx/profiler/experimental/profiler.py           +8    -8
colossalai/fx/profiler/experimental/shard_utils.py        +0    -0
colossalai/device/__init__.py

from .alpha_beta_profiler import AlphaBetaProfiler
from .calc_pipeline_strategy import alpa_dp

__all__ = ['AlphaBetaProfiler', 'alpa_dp']
colossalai/device/alpha_beta_profiler.py (new file, mode 100644)

import math
import time
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist

from colossalai.logging import get_dist_logger

GB = int((1 << 30))
BYTE = 4
FRAMEWORK_LATENCY = 0


class AlphaBetaProfiler:
    '''
    Profile the alpha and beta values for a given device list.

    Usage:
        # Note: the execution environment is supposed to be
        # multi-process with multi-gpu in mpi style.
        >>> physical_devices = [0, 1, 4, 5]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> ab_dict = ab_profiler.alpha_beta_dict
        >>> print(ab_dict)
        {(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
         (1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
         (1, 0): (1.9641406834125518e-05, 4.74049549614719e-12), (4, 0): (1.9506998360157013e-05, 6.97421973297474e-11), (5, 0): (2.293858677148819e-05, 7.129930361393644e-11),
         (4, 1): (1.9010603427886962e-05, 7.077968863788975e-11), (5, 1): (1.9807778298854827e-05, 6.928845708992215e-11), (5, 4): (1.8681809306144713e-05, 4.7522367291330524e-12)}
    '''
    def __init__(self,
                 physical_devices: List[int],
                 alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
                 ctype: str = 'a',
                 warmup: int = 5,
                 repeat: int = 25,
                 latency_iters: int = 5,
                 homogeneous_tolerance: float = 0.1):
        '''
        Args:
            physical_devices: A list of device ids; each element is the global rank of that device.
            alpha_beta_dict: A dict which maps a process group to its alpha-beta value pair.
            ctype: 'a' for all-reduce, 'b' for broadcast.
            warmup: Number of warmup iterations.
            repeat: Number of iterations to measure.
            latency_iters: Number of iterations to measure latency.
            homogeneous_tolerance: Relative tolerance used to decide whether two beta values
                belong to the same homogeneous device group.
        '''
        self.physical_devices = physical_devices
        self.ctype = ctype
        self.world_size = len(physical_devices)
        self.warmup = warmup
        self.repeat = repeat
        self.latency_iters = latency_iters
        self.homogeneous_tolerance = homogeneous_tolerance
        self.process_group_dict = None
        self._init_profiling()
        if alpha_beta_dict is None:
            self.alpha_beta_dict = self.profile_ab()
        else:
            self.alpha_beta_dict = alpha_beta_dict
    def _init_profiling(self):
        # Create process group list based on its global rank
        process_group_list = []
        for f_index in range(self.world_size - 1):
            for b_index in range(f_index + 1, self.world_size):
                process_group_list.append((self.physical_devices[f_index], self.physical_devices[b_index]))

        # Create process group dict which maps process group to its handler
        process_group_dict = {}
        for process_group in process_group_list:
            pg_handler = dist.new_group(process_group)
            process_group_dict[process_group] = pg_handler

        self.process_group_dict = process_group_dict
    def _profile(self, process_group, pg_handler, nbytes):
        logger = get_dist_logger()
        rank = dist.get_rank()
        src_device_num = process_group[0]
        world_size = len(process_group)

        device = torch.cuda.current_device()
        buf = torch.randn(nbytes // 4).to(device)

        torch.cuda.synchronize()
        # warmup
        for _ in range(self.warmup):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()

        dist.barrier(group=pg_handler)
        begin = time.perf_counter()
        for _ in range(self.repeat):
            if self.ctype == "a":
                dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=pg_handler)
            elif self.ctype == "b":
                dist.broadcast(buf, src=src_device_num, group=pg_handler)
        torch.cuda.synchronize()
        end = time.perf_counter()
        dist.barrier(group=pg_handler)

        if rank == src_device_num:
            avg_time_s = (end - begin) / self.repeat - FRAMEWORK_LATENCY
            alg_band = nbytes / avg_time_s
            if self.ctype == "a":
                # convert the bandwidth of the all-reduce algorithm to the bandwidth of the hardware.
                bus_band = 2 * (world_size - 1) / world_size * alg_band
                bus_band = alg_band
            elif self.ctype == "b":
                bus_band = alg_band

            logger.info(
                f"GPU: {rank}, Bytes: {nbytes} B, Time: {round(avg_time_s * 1e6, 2)} us, Bus bandwidth: {round(bus_band / GB, 2)} GB/s"
            )
            return (avg_time_s, alg_band)
        else:
            # Just a placeholder
            return (None, None)
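    # A note on the bus-bandwidth conversion in _profile (assumed convention, matching
    # the one used by the NCCL performance tests): for ring all-reduce,
    # busbw = 2 * (world_size - 1) / world_size * algbw. The process groups profiled
    # by this class are always pairs, so world_size == 2, the correction factor is
    # exactly 1, and the `bus_band = alg_band` override leaves the reported value
    # unchanged.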
    def profile_latency(self, process_group, pg_handler):
        '''
        This function is used to profile the latency of the given process group with a series of bytes.

        Args:
            process_group: A tuple of global ranks of the process group.
            pg_handler: The handler of the process group.

        Returns:
            latency: None if the latency is not measured, otherwise the median of the latency_list.
        '''
        latency_list = []
        for i in range(self.latency_iters):
            nbytes = int(BYTE << i)
            (t, _) = self._profile(process_group, pg_handler, nbytes)
            latency_list.append(t)

        if latency_list[0] is None:
            latency = None
        else:
            median_index = math.floor(self.latency_iters / 2)
            latency = latency_list[median_index]

        return latency
    def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
        '''
        This function is used to profile the bandwidth of the given process group.

        Args:
            process_group: A tuple of global ranks of the process group.
            pg_handler: The handler of the process group.
            maxbytes: The number of bytes to transfer when measuring bandwidth.
        '''
        (_, bandwidth) = self._profile(process_group, pg_handler, maxbytes)
        return bandwidth
    def profile_ab(self):
        '''
        This method is used to profile the alpha and beta values for a given device list.

        Returns:
            alpha_beta_dict: A dict which maps each process group to its alpha and beta values.
        '''
        alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
        rank = dist.get_rank()
        global_pg_handler = dist.new_group(self.physical_devices)

        def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
            assert rank in process_group
            device = torch.cuda.current_device()
            rank_max_nbytes = torch.cuda.mem_get_info(device)[0]
            rank_max_nbytes = torch.tensor(rank_max_nbytes, device=device)
            dist.all_reduce(rank_max_nbytes, op=dist.ReduceOp.MIN, group=pg_handler)
            max_nbytes = min(int(1 * GB), int(GB << int(math.log2(rank_max_nbytes.item() / GB))))
            return max_nbytes

        for process_group, pg_handler in self.process_group_dict.items():
            if rank not in process_group:
                max_nbytes = None
                alpha = None
                bandwidth = None
            else:
                max_nbytes = get_max_nbytes(process_group, pg_handler)
                alpha = self.profile_latency(process_group, pg_handler)
                bandwidth = self.profile_bandwidth(process_group, pg_handler, maxbytes=max_nbytes)

            if bandwidth is None:
                beta = None
            else:
                beta = 1 / bandwidth

            broadcast_list = [alpha, beta]
            dist.broadcast_object_list(broadcast_list, src=process_group[0])
            alpha_beta_dict[process_group] = tuple(broadcast_list)

        # add symmetric pairs to the alpha_beta_dict
        symmetry_ab_dict = {}
        for process_group, alpha_beta_pair in alpha_beta_dict.items():
            symmetry_process_group = (process_group[1], process_group[0])
            symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair

        alpha_beta_dict.update(symmetry_ab_dict)

        return alpha_beta_dict
    def search_best_logical_mesh(self):
        '''
        This method is used to search the best logical mesh for the given device list.

        The best logical mesh is searched in the following steps:
            1. Detect homogeneous device groups: we assume that the devices in the alpha_beta_dict
               are homogeneous if their beta values are close enough.
            2. Find the best homogeneous device group that contains all the physical devices. The best
               homogeneous device group is the one with the lowest beta value among the groups which
               contain all the physical devices. We require the group to contain all the physical
               devices because devices outside the group would decrease the bandwidth of the group.
            3. If the best homogeneous device group is found, we construct the largest ring for each device
               based on it, and the best logical mesh is the union of all the rings. Otherwise, the best
               logical mesh is the balanced logical mesh, e.g. shape (2, 2) for 4 devices.

        Returns:
            best_logical_mesh: The best logical mesh for the given device list.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> best_logical_mesh = ab_profiler.search_best_logical_mesh()
            >>> print(best_logical_mesh)
            [[0, 1], [2, 3]]
        '''
        def _power_of_two(integer):
            return integer & (integer - 1) == 0
        def _detect_homogeneous_device(alpha_beta_dict):
            '''
            This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.

            Note: we assume that the devices in the alpha_beta_dict are homogeneous if their beta values
            are in the range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)] * base_beta.
            '''
            homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
            for process_group, (_, beta) in alpha_beta_dict.items():
                if homogeneous_device_dict is None:
                    homogeneous_device_dict[beta] = []
                    homogeneous_device_dict[beta].append(process_group)

                match_beta = None
                for beta_value in homogeneous_device_dict.keys():
                    if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
                            1 - self.homogeneous_tolerance):
                        match_beta = beta_value
                        break

                if match_beta is not None:
                    homogeneous_device_dict[match_beta].append(process_group)
                else:
                    homogeneous_device_dict[beta] = []
                    homogeneous_device_dict[beta].append(process_group)

            return homogeneous_device_dict
        def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
            '''
            This function is used to check whether the homogeneous_group contains all physical devices.
            '''
            flatten_mesh = []
            for process_group in homogeneous_group:
                flatten_mesh.extend(process_group)
            non_duplicated_flatten_mesh = set(flatten_mesh)
            return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
        def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
            '''
            This function is used to construct the largest ring in the homogeneous_group for each rank.
            '''
            # Construct the ring
            ring = []
            ranks_in_ring = []
            for rank in self.physical_devices:
                if rank in ranks_in_ring:
                    continue
                stable_status = False
                ring_for_rank = []
                ring_for_rank.append(rank)
                check_rank_list = [rank]
                rank_to_check_list = []

                while not stable_status:
                    stable_status = True
                    check_rank_list.extend(rank_to_check_list)
                    rank_to_check_list = []
                    for i in range(len(check_rank_list)):
                        check_rank = check_rank_list.pop()
                        for process_group in homogeneous_group:
                            if check_rank in process_group:
                                rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
                                if rank_to_append not in ring_for_rank:
                                    stable_status = False
                                    rank_to_check_list.append(rank_to_append)
                                    ring_for_rank.append(rank_to_append)

                ring.append(ring_for_rank)
                ranks_in_ring.extend(ring_for_rank)

            return ring
        assert _power_of_two(self.world_size)
        power_of_two = int(math.log2(self.world_size))
        median = power_of_two // 2
        balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))

        row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
        balanced_logical_mesh = []
        for row_index in range(row_size):
            balanced_logical_mesh.append([])
            for column_index in range(column_size):
                balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])

        homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
        beta_list = [b for b in homogeneous_device_dict.keys()]
        beta_list.sort()
        beta_list.reverse()
        homogeneous_types = len(beta_list)
        best_logical_mesh = None
        if homogeneous_types >= 2:
            for _ in range(homogeneous_types - 1):
                lowest_beta = beta_list.pop()
                best_homogeneous_group = homogeneous_device_dict[lowest_beta]
                # if the best homogeneous group contains all physical devices,
                # we will build the logical device mesh based on it. Otherwise,
                # we will check the next-level homogeneous group.
                if _check_contain_all_devices(best_homogeneous_group):
                    # We choose the largest ring for each rank to maximize bus utilization.
                    best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
                    break

        if homogeneous_types == 1 or best_logical_mesh is None:
            # in this case, we use the balanced logical mesh as the best logical mesh.
            best_logical_mesh = balanced_logical_mesh

        return best_logical_mesh
    def extract_alpha_beta_for_device_mesh(self):
        '''
        Extract the mesh_alpha list and mesh_beta list based on the
        best logical mesh, which will be used to initialize the device mesh.

        Usage:
            >>> physical_devices = [0, 1, 2, 3]
            >>> ab_profiler = AlphaBetaProfiler(physical_devices)
            >>> mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
            >>> print(mesh_alpha)
            [2.5917552411556242e-05, 0.00010312341153621673]
            >>> print(mesh_beta)
            [5.875573704655635e-11, 4.7361584445959614e-12]
        '''
        best_logical_mesh = self.search_best_logical_mesh()

        first_axis = [row[0] for row in best_logical_mesh]
        second_axis = best_logical_mesh[0]

        # init process group for both axes
        first_axis_process_group = dist.new_group(first_axis)
        second_axis_process_group = dist.new_group(second_axis)

        # extract alpha and beta for both axes
        def _extract_alpha_beta(pg, pg_handler):
            latency = self.profile_latency(pg, pg_handler)
            bandwidth = self.profile_bandwidth(pg, pg_handler)
            broadcast_object = [latency, bandwidth]
            dist.broadcast_object_list(broadcast_object, src=pg[0])
            return broadcast_object

        first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
        second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
        mesh_alpha = [first_latency, second_latency]
        mesh_beta = [1 / first_bandwidth, 1 / second_bandwidth]

        return mesh_alpha, mesh_beta
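Note: the profiled pair feeds the alpha-beta cost model used throughout this commit: transferring n bytes over a link is modeled as taking roughly alpha + beta * n seconds. A minimal sketch of how such a pair would be consumed, reusing the (0, 1) values from the class docstring above:

# Sketch only: predict point-to-point transfer time from a profiled (alpha, beta) pair.
alpha, beta = 1.9641406834125518e-05, 4.74049549614719e-12    # pair (0, 1) from the docstring

def estimated_comm_time(nbytes: int) -> float:
    # alpha is the fixed latency in seconds, beta is seconds per byte (1 / bandwidth)
    return alpha + beta * nbytes

print(estimated_comm_time(1 << 30))    # ~5.1e-03 s for 1 GiB on this link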
colossalai/device/calc_pipeline_strategy.py (new file, mode 100644)

from math import pow

import numpy as np


def get_submesh_choices(num_hosts, num_devices_per_host, mode="new"):
    submesh_choices = []
    i = 1
    p = -1
    while i <= num_devices_per_host:
        i *= 2
        p += 1
    assert pow(2, p) == num_devices_per_host, ("Only supports the cases where num_devices_per_host is power of two, "
                                               f"while now num_devices_per_host = {num_devices_per_host}")
    if mode == "alpa":
        for i in range(p + 1):
            submesh_choices.append((1, pow(2, i)))
        for i in range(2, num_hosts + 1):
            submesh_choices.append((i, num_devices_per_host))
    elif mode == "new":
        for i in range(p // 2 + 1):
            for j in range(i, p - i + 1):
                submesh_choices.append((pow(2, i), pow(2, j)))
    return submesh_choices
def alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, compute_cost, max_stage_cost,
                 best_configs):
    """Implementation of Alpa DP for pipeline strategy
    Paper reference: https://www.usenix.org/system/files/osdi22-zheng-lianmin.pdf

    Arguments:
        num_layers: K
        num_devices: N*M
        num_microbatches: B
        submesh_choices: List[(n_i, m_i)]
        compute_cost: t_intra
    """
    # For f, layer ID starts from 0
    # f[#pipeline stages, layer id that is currently being considered, number of devices used]
    f = np.full((num_layers + 1, num_layers + 1, num_devices + 1), np.inf, dtype=np.float32)
    f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1), 0.0, dtype=np.float32)
    f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3), -1, dtype=np.int32)
    f[0, num_layers, 0] = 0
    for s in range(1, num_layers + 1):
        for k in range(num_layers - 1, -1, -1):
            for d in range(1, num_devices + 1):
                for m, submesh in enumerate(submesh_choices):
                    n_submesh_devices = np.prod(np.array(submesh))
                    if n_submesh_devices <= d:
                        # TODO: [luzgh]: Why alpa needs max_n_succ_stages? Delete.
                        # if s - 1 <= max_n_succ_stages[i, k - 1, m, n_config]:
                        #     ...
                        for i in range(num_layers, k, -1):
                            stage_cost = compute_cost[k, i, m]
                            new_cost = f[s - 1, i, d - n_submesh_devices] + stage_cost
                            if (stage_cost <= max_stage_cost and new_cost < f[s, k, d]):
                                f[s, k, d] = new_cost
                                f_stage_max[s, k, d] = max(stage_cost, f_stage_max[s - 1, i, d - n_submesh_devices])
                                f_argmin[s, k, d] = (i, m, best_configs[k, i, m])

    best_s = -1
    best_total_cost = np.inf
    for s in range(1, num_layers + 1):
        if f[s, 0, num_devices] < best_total_cost:
            best_s = s
            best_total_cost = f[s, 0, num_devices]

    if np.isinf(best_total_cost):
        return np.inf, None

    total_cost = f[best_s, 0, num_devices] + (num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]
    current_s = best_s
    current_layer = 0
    current_devices = num_devices

    res = []
    while current_s > 0 and current_layer < num_layers and current_devices > 0:
        next_start_layer, submesh_choice, autosharding_choice = (f_argmin[current_s, current_layer, current_devices])
        assert next_start_layer != -1 and current_devices != -1
        res.append(((current_layer, next_start_layer), submesh_choice, autosharding_choice))
        current_s -= 1
        current_layer = next_start_layer
        current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))
    assert (current_s == 0 and current_layer == num_layers and current_devices == 0)

    return total_cost, res
def alpa_dp(num_layers, num_devices, num_microbatches, submesh_choices, num_autosharding_configs, compute_cost,
            gap=1e-6):
    """Alpa auto stage dynamic programming.
    Code reference: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/stage_construction.py

    Arguments:
        submesh_choices: List[(int, int)]
        num_autosharding_configs: Max number of t_intra(start_layer, end_layer, LogicalMesh)
        compute_cost: np.array(num_layers, num_layers, num_submesh_choices, num_autosharding_configs)
    """
    assert np.shape(compute_cost) == (num_layers, num_layers, len(submesh_choices),
                                      num_autosharding_configs), "Cost shape wrong."
    all_possible_stage_costs = np.sort(np.unique(compute_cost))
    best_cost = np.inf
    best_solution = None
    last_max_stage_cost = 0.0
    # TODO: [luzgh]: Why alpa needs the num_autosharding_configs dimension in compute_cost?
    # In dp_impl it seems the argmin n_config will be chosen. Just amin here.
    best_configs = np.argmin(compute_cost, axis=3)
    best_compute_cost = np.amin(compute_cost, axis=3)
    assert len(all_possible_stage_costs), "no solution in auto stage construction."
    for max_stage_cost in all_possible_stage_costs:
        if max_stage_cost * num_microbatches >= best_cost:
            break
        if max_stage_cost - last_max_stage_cost < gap:
            continue
        cost, solution = alpa_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices, best_compute_cost,
                                      max_stage_cost, best_configs)
        if cost < best_cost:
            best_cost = cost
            best_solution = solution
        last_max_stage_cost = max_stage_cost

    return best_cost, best_solution
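For context: the DP above follows Alpa's stage-construction objective. For a partition of the layers into stages with intra-stage costs t_i and B microbatches, it minimizes T = sum_i t_i + (B - 1) * max_i t_i (this is exactly what total_cost computes), while the outer loop in alpa_dp sweeps the allowed maximum stage cost. A quick, illustrative check of the submesh enumeration (hypothetical cluster sizes; note that math.pow yields floats, so the tuples print as floats):

# Illustrative only: enumerate power-of-two submeshes for 1 host with 8 GPUs.
# mode="new" yields all (2**i, 2**j) with i <= j <= p - i.
print(get_submesh_choices(num_hosts=1, num_devices_per_host=8))
# [(1.0, 1.0), (1.0, 2.0), (1.0, 4.0), (1.0, 8.0), (2.0, 2.0), (2.0, 4.0)]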
colossalai/device/device_mesh.py

-from functools import reduce
-import operator
+import operator
+from functools import reduce

 import torch
 import torch.distributed as dist

@@ -11,7 +12,7 @@ class DeviceMesh:
     can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
     own latency and bandwidth. We use alpha-beta model to model the
     communication cost.

     Arguments:
         physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
         mesh_shape (torch.Size): shape of logical view.
@@ -23,6 +24,7 @@ class DeviceMesh:
         during initializing the DeviceMesh instance if the init_process_group set to True.
         Otherwise, users need to call create_process_groups_for_logical_mesh manually to init logical process group.
         (default: False)
+        need_flatten(bool, optional): initialize flatten_device_mesh during initializing the DeviceMesh instance if the need_flatten set to True.
     """

     def __init__(self,
@@ -49,8 +51,11 @@ class DeviceMesh:
         self.need_flatten = need_flatten
         if self.init_process_group:
             self.process_groups_dict = self.create_process_groups_for_logical_mesh()
-        if self.need_flatten:
+        if self.need_flatten and self._logical_mesh_id.dim() > 1:
             self.flatten_device_mesh = self.flatten()
+            # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
+            self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
+                                                           self.mesh_beta)

     @property
     def shape(self):
@@ -64,6 +69,18 @@ class DeviceMesh:
     def logical_mesh_id(self):
         return self._logical_mesh_id

+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
+            if k != 'process_groups_dict':
+                setattr(result, k, __import__("copy").deepcopy(v, memo))
+            else:
+                setattr(result, k, v)
+
+        return result
+
     def flatten(self):
         """
         Flatten the logical mesh into an effective 1d logical mesh,
@@ -90,7 +107,7 @@ class DeviceMesh:
     def create_process_groups_for_logical_mesh(self):
         '''
         This method is used to initialize the logical process groups which will be used in communications
         among logical device mesh.

         Note: if init_process_group set to False, you have to call this method manually. Otherwise,
         the communication related function, such as ShapeConsistencyManager.apply will raise errors.
         '''
@@ -186,3 +203,38 @@ class DeviceMesh:
         penalty_factor = num_devices / 2.0
         return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
                 (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
+
+
+class FlattenDeviceMesh(DeviceMesh):
+
+    def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
+        super().__init__(physical_mesh_id,
+                         mesh_shape,
+                         mesh_alpha,
+                         mesh_beta,
+                         init_process_group=False,
+                         need_flatten=False)
+        # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars
+        self.mesh_alpha = max(self.mesh_alpha)
+        self.mesh_beta = min(self.mesh_beta)
+        # Different from original process_groups_dict, rank_list is not stored
+        self.process_number_dict = self.create_process_numbers_for_logical_mesh()
+
+    def create_process_numbers_for_logical_mesh(self):
+        '''
+        Build 1d DeviceMesh in column-major(0) and row-major(1)
+        for example:
+            mesh_shape = (2, 4)
+            # [[0, 1, 2, 3],
+            #  [4, 5, 6, 7]]
+            # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
+        '''
+        num_devices = reduce(operator.mul, self.mesh_shape, 1)
+        process_numbers_dict = {}
+        process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
+        process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
+        return process_numbers_dict
+
+    def mix_gather_cost(self, num_bytes):
+        num_devices = reduce(operator.mul, self.mesh_shape, 1)
+        return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
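A standalone sketch of the two orderings that create_process_numbers_for_logical_mesh builds, mirroring its docstring example (names below are local to the sketch):

import torch

mesh_shape = (2, 4)
num_devices = mesh_shape[0] * mesh_shape[1]
# column-major (key 0): walk the mesh down each column first
col_major = torch.arange(num_devices).reshape(mesh_shape).transpose(1, 0).flatten().tolist()
# row-major (key 1): walk the mesh along each row first
row_major = torch.arange(num_devices).reshape(mesh_shape).flatten().tolist()
print(col_major)    # [0, 4, 1, 5, 2, 6, 3, 7]
print(row_major)    # [0, 1, 2, 3, 4, 5, 6, 7]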
colossalai/fx/__init__.py

 from ._compatibility import compatibility, is_compatible_with_meta
 from .graph_module import ColoGraphModule
-from .passes import MetaInfoProp
+from .passes import MetaInfoProp, metainfo_trace
-from .tracer import ColoTracer, meta_trace
+from .tracer import ColoTracer, meta_trace, symbolic_trace
colossalai/fx/_meta_registrations.py

@@ -3,7 +3,7 @@
 # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 # for more meta_registrations

-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 from torch.utils._pytree import tree_map

@@ -163,6 +163,23 @@ def meta_conv(
     return out


+@register_meta(aten._convolution.default)
+def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
+                padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
+                *extra_args):
+    out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
+    return out
+
+
 @register_meta(aten.convolution_backward.default)
 def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
                        padding, dilation, transposed, output_padding, groups, output_mask):

@@ -179,6 +196,79 @@ def meta_adaptive_avg_pool2d_backward(
     return grad_input


+# ================================ RNN =============================================
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+@register_meta(aten._cudnn_rnn.default)
+def meta_cuda_rnn(
+    input,
+    weight,
+    weight_stride0,
+    weight_buf,
+    hx,
+    cx,
+    mode,
+    hidden_size,
+    proj_size,
+    num_layers,
+    batch_first,
+    dropout,
+    train,
+    bidirectional,
+    batch_sizes,
+    dropout_state,
+):
+    is_input_packed = len(batch_sizes) != 0
+    if is_input_packed:
+        seq_length = len(batch_sizes)
+        mini_batch = batch_sizes[0]
+        batch_sizes_sum = input.shape[0]
+    else:
+        seq_length = input.shape[1] if batch_first else input.shape[0]
+        mini_batch = input.shape[0] if batch_first else input.shape[1]
+        batch_sizes_sum = -1
+
+    num_directions = 2 if bidirectional else 1
+    out_size = proj_size if proj_size != 0 else hidden_size
+    if is_input_packed:
+        out_shape = [batch_sizes_sum, out_size * num_directions]
+    else:
+        out_shape = ([mini_batch, seq_length, out_size * num_directions]
+                     if batch_first else [seq_length, mini_batch, out_size * num_directions])
+    output = input.new_empty(out_shape)
+
+    cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
+    cy = torch.empty(0) if cx is None else cx.new_empty(cell_shape)
+
+    hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
+
+    # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
+    reserve_shape = 0 if train else 0
+    reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
+
+    return output, hy, cy, reserve, weight_buf
+
+
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+@register_meta(aten._cudnn_rnn_backward.default)
+def meta_cudnn_rnn_backward(input: torch.Tensor,
+                            weight: torch.Tensor,
+                            weight_stride0: int,
+                            hx: torch.Tensor,
+                            cx: Optional[torch.Tensor] = None,
+                            *args,
+                            **kwargs):
+    print(input, weight, hx, cx)
+    grad_input = torch.empty_like(input)
+    grad_weight = torch.empty_like(weight)
+    grad_hx = torch.empty_like(hx)
+    grad_cx = torch.empty_like(cx) if cx is not None else torch.empty((), device='meta')
+
+    return grad_input, grad_weight, grad_hx, grad_cx
+
+
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
 # ============================== Activations =======================================
 @register_meta(aten.relu.default)
@@ -186,6 +276,11 @@ def meta_relu(input: torch.Tensor):
     return torch.empty_like(input)


+@register_meta(aten.prelu.default)
+def meta_prelu(input: torch.Tensor, weight: torch.Tensor):
+    return torch.empty_like(input)
+
+
 @register_meta(aten.hardswish.default)
 def meta_hardswish(input: torch.Tensor):
     return torch.empty_like(input)

@@ -278,12 +373,18 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
 # ================================== Misc ==========================================
-#https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 @register_meta(aten.roll.default)
 def meta_roll(input: torch.Tensor, shifts, dims):
     return input


+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
+@register_meta(aten._local_scalar_dense.default)
+def meta_local_scalar_dense(self: torch.Tensor):
+    return 0
+
+
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
 @register_meta(aten.where.self)
 def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):

@@ -317,7 +418,7 @@ def meta_index_Tensor(self, indices):
         indices = result
     assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
     # expand_outplace
-    import torch._refs as refs
+    import torch._refs as refs    # avoid import cycle in mypy
     indices = list(refs._maybe_broadcast(*indices))
     # add missing null tensors
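For context, these registrations let the listed aten ops run on PyTorch's 'meta' device, where only shapes and dtypes are propagated and no real kernels execute. A minimal sketch of the effect, using a stock op that already has meta support:

import torch

# On the meta device no computation happens; only shape/dtype propagation runs,
# which is what tracing-time profiling in colossalai.fx relies on.
x = torch.empty(4, 4, device='meta')
y = torch.relu(x)
print(y.shape, y.device)    # torch.Size([4, 4]) meta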
colossalai/fx/codegen/activation_checkpoint_codegen.py

-import colossalai
-import torch
-from typing import List, Callable, Any, Tuple, Dict, Iterable
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+import torch
+
+import colossalai

 try:
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+    from torch.fx.graph import (
+        CodeGen,
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        inplace_methods,
+        magic_methods,
+    )
+    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
     CODEGEN_AVAILABLE = True
 except:
-    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
-    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    from torch.fx.graph import (
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_args,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        magic_methods,
+    )
+    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
     CODEGEN_AVAILABLE = False

 if CODEGEN_AVAILABLE:

@@ -27,7 +50,7 @@ def _gen_saved_tensors_hooks():
             return (x.device, x.cpu())
         else:
             return x

     def pack_hook_no_input(self, x):
         if getattr(x, "offload", True):
             return (x.device, x.cpu())

@@ -48,11 +71,9 @@ def pack_hook_no_input(self, x):
 def _gen_save_tensors_hooks_context(offload_input=True) -> str:
     """Generate customized saved_tensors_hooks

     Args:
         offload_input (bool, optional): whether we need offload input, if offload_input=False,
         we will use self.pack_hook_no_input instead. Defaults to True.

     Returns:
         str: generated context
     """

@@ -111,8 +132,8 @@ def _find_ckpt_regions(nodes: List[Node]):
     current_region = None

     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
+        if 'activation_checkpoint' in node.meta:
+            act_ckpt_label = node.meta['activation_checkpoint']

             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region

@@ -129,7 +150,7 @@ def _find_ckpt_regions(nodes: List[Node]):
             current_region = act_ckpt_label
             start = idx
             end = -1
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
+        elif current_region is not None and not 'activation_checkpoint' in node.meta:
             # used to check the case below
             # node ckpt states = [ckpt, ckpt, non-ckpt]
             end = idx - 1

@@ -144,7 +165,7 @@ def _find_ckpt_regions(nodes: List[Node]):
 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
     In pofo algorithm, during annotation, we will annotate the offload region with the
     list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
     region's index, offload_input is a bool type indicates whether we need to offload
     the input, offload_bar is a bool type indicates whether we need to offload all the

@@ -157,8 +178,8 @@ def _find_offload_regions(nodes: List[Node]):
     current_region = None

     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
-            act_offload_label = node.activation_offload
+        if 'activation_offload' in node.meta and isinstance(node.meta['activation_offload'], Iterable):
+            act_offload_label = node.meta['activation_offload']

             if current_region == None:
                 current_region = act_offload_label

@@ -212,18 +233,16 @@ def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reen
 def _end_of_ckpt(node: Node, check_idx: int) -> bool:
     """Check if the node could end the ckpt region

     Args:
         node (Node): torch.fx.Node
         check_idx (int): the index of checkpoint level for nested checkpoint

     Returns:
         bool
     """
-    if hasattr(node, "activation_checkpoint"):
-        if isinstance(node.activation_checkpoint, list):
-            return node.activation_checkpoint[check_idx] == None
+    if 'activation_checkpoint' in node.meta:
+        if isinstance(node.meta['activation_checkpoint'], list):
+            return node.meta['activation_checkpoint'][check_idx] == None
         else:
             return False
     else:

@@ -232,7 +251,7 @@ def _end_of_ckpt(node: Node, check_idx: int) -> bool:
 def _find_nested_ckpt_regions(nodes, check_idx=0):
     """
     Find the nested checkpoint regions given a list of consecutive nodes. The outputs
     will be list of tuples, each tuple is in the form of (start_index, end_index).
     """
     ckpt_regions = []

@@ -241,11 +260,11 @@ def _find_nested_ckpt_regions(nodes, check_idx=0):
     current_region = None

     for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            if isinstance(getattr(node, 'activation_checkpoint'), int):
-                act_ckpt_label = node.activation_checkpoint
+        if 'activation_checkpoint' in node.meta:
+            if isinstance(node.meta['activation_checkpoint'], int):
+                act_ckpt_label = node.meta['activation_checkpoint']
             else:
-                act_ckpt_label = node.activation_checkpoint[check_idx]
+                act_ckpt_label = node.meta['activation_checkpoint'][check_idx]

             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region

@@ -287,7 +306,6 @@ def emit_ckpt_func(body,
                    level=0,
                    in_ckpt=False):
     """Emit ckpt function in nested way

     Args:
         body: forward code, in recursive calls, this part will be checkpoint
         functions code

@@ -303,8 +321,8 @@ def emit_ckpt_func(body,
     inputs, outputs = _find_input_and_output_nodes(node_list)

     # if the current checkpoint function use int as label, using old generation method
-    if isinstance(node_list[0].activation_checkpoint, int):
-        label = node_list[0].activation_checkpoint
+    if isinstance(node_list[0].meta['activation_checkpoint'], int):
+        label = node_list[0].meta['activation_checkpoint']
         ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
         ckpt_func.append(f'{ckpt_fn_def}\n')
         for node in node_list:

@@ -313,7 +331,7 @@ def emit_ckpt_func(body,
             delete_unused_value_func(node, ckpt_func)

         ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-        activation_offload = getattr(node_list[0], "activation_offload", False)
+        activation_offload = node_list[0].meta.get('activation_offload', False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
         usage += "\n"
         body.append(usage)

@@ -322,12 +340,12 @@ def emit_ckpt_func(body,
     else:
         # label given by each layer, e.g. if you are currently at level [0, 1, 1]
         # the label will be '0_1_1'
-        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
+        label = "_".join([str(idx) for idx in node_list[0].meta['activation_checkpoint'][:level + 1]])
         ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
         ckpt_func.append(f'{ckpt_fn_def}\n')

         # if there is more level to fetch
-        if level + 1 < len(node_list[0].activation_checkpoint):
+        if level + 1 < len(node_list[0].meta['activation_checkpoint']):
             ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
             start_idx = [item[0] for item in ckpt_regions]
             end_idx = [item[1] for item in ckpt_regions]

@@ -354,7 +372,7 @@ def emit_ckpt_func(body,
             ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
             ckpt_func += ckpt_func_buffer
-            activation_offload = getattr(node_list[0], "activation_offload", False)
+            activation_offload = node_list[0].meta.get('activation_offload', False)
             usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
             if in_ckpt:
                 usage = '    ' + usage

@@ -368,7 +386,7 @@ def emit_ckpt_func(body,
                 delete_unused_value_func(node, ckpt_func)

             ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
-            activation_offload = getattr(node_list[0], "activation_offload", False)
+            activation_offload = node_list[0].meta.get('activation_offload', False)
             usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
             if in_ckpt:
                 usage = '    ' + usage

@@ -379,7 +397,6 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
     """Emit code with nested activation checkpoint
     When we detect some of the node.activation_checkpoint is a List, we will use
     this function to emit the activation checkpoint codes.

     Args:
         body: forward code
         ckpt_func: checkpoint functions code

@@ -564,8 +581,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
         # we need to check if the checkpoint need to offload the input
         start_node_idx = start_idx[label]
-        if hasattr(node_list[start_node_idx], 'activation_offload'):
-            activation_offload = node_list[start_node_idx].activation_offload
+        if 'activation_offload' in node_list[start_node_idx].meta:
+            activation_offload = node_list[start_node_idx].meta['activation_offload']
         else:
             activation_offload = False

@@ -577,8 +594,8 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             if input_node.op != "placeholder":
                 non_leaf_input = 1
             for user in input_node.users:
-                if hasattr(user, "activation_checkpoint"):
-                    if user.activation_checkpoint == label:
+                if 'activation_checkpoint' in user.meta:
+                    if user.meta['activation_checkpoint'] == label:
                         if user.op == "call_module":
                             if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
                                 use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace

@@ -616,10 +633,8 @@ if CODEGEN_AVAILABLE:
         def add_global(name_hint: str, obj: Any):
             """Add an obj to be tracked as a global.

             We call this for names that reference objects external to the
             Graph, like functions or types.

             Returns: the global name that should be used to reference 'obj' in generated source.
             """
             if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device

@@ -796,7 +811,7 @@ if CODEGEN_AVAILABLE:
         # if any node has a list of labels for activation_checkpoint, we
         # will use nested type of activation checkpoint codegen
-        if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
+        if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in nodes):
             emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
         else:
             emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

@@ -829,7 +844,6 @@ if CODEGEN_AVAILABLE:
         code = '\n'.join('    ' + line for line in code.split('\n'))
         fn_code = f"""
 {wrap_stmts}

 {prologue}
 {code}"""
         return PythonCode(fn_code, globals_)

@@ -851,10 +865,8 @@ else:
         def add_global(name_hint: str, obj: Any):
             """Add an obj to be tracked as a global.

             We call this for names that reference objects external to the
             Graph, like functions or types.

             Returns: the global name that should be used to reference 'obj' in generated source.
             """
             if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device

@@ -999,7 +1011,7 @@ else:
             # if any node has a list of labels for activation_checkpoint, we
             # will use nested type of activation checkpoint codegen
-            if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
+            if any(isinstance(node.meta.get('activation_checkpoint', None), Iterable) for node in self.nodes):
                 emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
             else:
                 emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)

@@ -1040,7 +1052,6 @@ else:
         # in forward function
         fn_code = f"""
 {wrap_stmts}

 {ckpt_func}
 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
 {code}"""
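The recurring change across this file moves checkpoint annotations from ad-hoc Node attributes into the torch.fx node.meta dict, which is torch.fx's designated carrier for per-node metadata. A minimal sketch of the resulting access pattern (the helper name is illustrative, not part of the commit):

from torch.fx import Node

def read_ckpt_annotations(node: Node):
    # Annotations stored in node.meta travel with the node through
    # graph copies and transforms, unlike plain Python attributes.
    ckpt_label = node.meta.get('activation_checkpoint', None)
    offload = node.meta.get('activation_offload', False)
    return ckpt_label, offload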
colossalai/fx/passes/__init__.py

-from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
-from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
-from .meta_info_prop import MetaInfoProp
-from .concrete_info_prop import ConcreteInfoProp
+from .adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass
+from .concrete_info_prop import ConcreteInfoProp
+from .meta_info_prop import MetaInfoProp, metainfo_trace
+from .shard_1d_pass import column_shard_linear_pass, row_shard_linear_pass
colossalai/fx/passes/adding_split_node_pass.py  View file @ e532679c

import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node

from colossalai.fx.passes.split_module import split_module
...
@@ -9,6 +9,30 @@ def pipe_split():
    pass


+def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
+    """
+    In avgnode_split_pass, we simply split the graph by node count.
+    """
+    mod_graph = gm.graph
+    avg_num_node = len(mod_graph.nodes) // pp_size
+    accumulate_num_node = 0
+    for node in mod_graph.nodes:
+        if pp_size <= 1:
+            break
+        accumulate_num_node += 1
+        if accumulate_num_node >= avg_num_node:
+            accumulate_num_node = 0
+            pp_size -= 1
+            if node.next.op == 'output':
+                with mod_graph.inserting_before(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+            else:
+                with mod_graph.inserting_after(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
+    gm.recompile()
+    return gm
+
+
def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    """
    In balanced_split_pass, we split the module by the size of its parameters (weights + bias).
...
@@ -37,6 +61,21 @@ def balanced_split_pass(gm: torch.fx.GraphModule, pp_size: int):
        else:
            with mod_graph.inserting_after(node):
                split_node = mod_graph.create_node('call_function', pipe_split)
+    if pp_size > 1:
+        node_counter = 0
+        for node in mod_graph.nodes:
+            if pp_size <= 1:
+                break
+            if node.op == 'placeholder':
+                continue
+            elif node_counter == 0:
+                node_counter += 1
+            else:
+                pp_size -= 1
+                node_counter = 0
+                with mod_graph.inserting_before(node):
+                    split_node = mod_graph.create_node('call_function', pipe_split)
    gm.recompile()
    return gm
...
@@ -102,7 +141,7 @@ def uniform_split_pass(gm: torch.fx.GraphModule, pp_size: int):
    return gm


-def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
+def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule, merge_output=False):
    # TODO(lyl): use partition IR to assign partition ID to each node.
    # Currently: analyzing graph -> annotate graph by inserting split node -> use split module pass to split graph
    # In future: graph to partitions -> analyzing partition IR -> recombining partitions to get best performance -> assign partition ID to each node
...
@@ -114,7 +153,7 @@ def split_with_split_nodes_pass(annotated_gm: torch.fx.GraphModule):
            part_idx += 1
        return part_idx

-    split_mod = split_module(annotated_gm, None, split_callback)
+    split_mod = split_module(annotated_gm, None, split_callback, merge_output)
    split_submodules = []
    for name, submodule in split_mod.named_modules():
        if isinstance(submodule, torch.fx.GraphModule):
...
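Taken together, the passes above first annotate the graph with ``pipe_split`` markers and then cut it into stage submodules. A hedged usage sketch (the toy model and ``pp_size`` are illustrative; the pass body above suggests ``split_with_split_nodes_pass`` returns both the split module and the list of stage submodules):

    import torch
    from torch.fx import symbolic_trace

    from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass

    model = torch.nn.Sequential(*[torch.nn.Linear(16, 16) for _ in range(8)])
    gm = symbolic_trace(model)

    gm = avgnode_split_pass(gm, pp_size=2)                    # insert one pipe_split per ~half of the nodes
    split_gm, parts = split_with_split_nodes_pass(gm)         # one GraphModule per pipeline stage
    print([name for name, _ in split_gm.named_children()])    # e.g. ['submod_0', 'submod_1']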
colossalai/fx/passes/algorithms/ckpt_solver_chen.py  View file @ e532679c

+import math
from typing import List, Set, Tuple

import torch
from torch.fx import GraphModule, Node
-import math

from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp

__all__ = ['chen_greedy']
...
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py  View file @ e532679c

+import math
import sys
from typing import List, Tuple

-from colossalai.fx.profiler.memory import calculate_fwd_in
from torch.fx import Node
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
-import math
-from .linearize import linearize
-from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger
+from .linearize import linearize
+from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence

# global variable to indicate whether the solver has failed
SOLVER_FAILED = False
...
@@ -18,7 +20,7 @@ SOLVER_FAILED = False
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
    """Returns the optimal table: a tuple containing:
    Opt[m][lmin][lmax] with lmin = 0...chain.length
    and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
    what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
...
@@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
    """Get the forward xbar of a node

    Args:
        node (List[Node]): List of torch.fx Node,
        indicates a node in the linearized graph

    Returns:
...
@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
    # build module if module not found
    except ModuleNotFoundError:
-        import subprocess
        import os
+        import subprocess
        logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
        this_dir = os.path.dirname(os.path.abspath(__file__))
        result = subprocess.Popen(
...
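The rotor solver linearizes the graph into a ``Chain`` and uses the DP table above to choose, per segment, between the ``ForwardEnable``, ``ForwardCheck`` and ``ForwardNograd`` strategies. A hypothetical invocation sketch — only ``gm`` as the first argument is confirmed by the signature fragment above; the export path, sample data and memory budget are assumptions for illustration:

    import torch
    from colossalai.fx import ColoTracer
    from colossalai.fx.graph_module import ColoGraphModule
    from colossalai.fx.passes.algorithms import solver_rotor    # assumed export path

    model = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(6)])
    gm = ColoGraphModule(model, ColoTracer().trace(model))
    data = torch.rand(8, 64)
    gm = solver_rotor(gm, data, 512 * 1024 * 1024)    # hypothetical 512 MB activation budget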
colossalai/fx/passes/concrete_info_prop.py  View file @ e532679c
...
@@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten

+from colossalai.fx._compatibility import compatibility
+from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module


@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):
...
@@ -22,17 +23,17 @@ class ConcreteInfoProp(torch.fx.Interpreter):
        DIM_HIDDEN = 16
        DIM_OUT = 16
        model = torch.nn.Sequential(
            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
            torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
        ).cuda()
        input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
        gm = symbolic_trace(model)
        interp = ConcreteInfoProp(gm)
        interp.run(input_sample)
        print(interp.summary(unit='kb'))

        output of above code is
        Op type      Op       Forward time    Backward time    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
        -----------  -------  --------------  ---------------  -------------  ---------  ---------  ---------  ---------
        placeholder  input_1  0.0 s           0.0 s            False          0.00 KB    0.00 KB    0.00 KB    0.00 KB
...
@@ -229,8 +230,8 @@ class ConcreteInfoProp(torch.fx.Interpreter):
    def summary(self, unit: str = 'MB') -> str:
        """
        Summarizes the memory and FLOPs statistics of the `GraphModule` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
...
colossalai/fx/passes/meta_info_prop.py  View file @ e532679c
...
@@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
-from colossalai.fx._compatibility import compatibility
-from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map

+from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
+from colossalai.fx.profiler import (
+    GraphInfo,
+    activation_size,
+    calculate_fwd_in,
+    calculate_fwd_out,
+    calculate_fwd_tmp,
+    profile_function,
+    profile_method,
+    profile_module,
+)


@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
...
@@ -52,7 +61,7 @@ class MetaInfoProp(torch.fx.Interpreter):
        DIM_HIDDEN = 16
        DIM_OUT = 16
        model = torch.nn.Sequential(
            torch.nn.Linear(DIM_IN, DIM_HIDDEN),
            torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
        )
        input_sample = torch.rand(BATCH_SIZE, DIM_IN)
...
@@ -60,9 +69,9 @@ class MetaInfoProp(torch.fx.Interpreter):
        interp = MetaInfoProp(gm)
        interp.run(input_sample)
        print(interp.summary(unit='kb'))    # don't panic if some statistics are 0.00 MB

        # output of above code is
        Op type      Op       Forward FLOPs    Backward FLOPs    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
        -----------  -------  ---------------  ----------------  ---------  ---------  ---------  ---------
        placeholder  input_1  0 FLOPs          0 FLOPs           0.00 KB    0.00 KB    0.00 KB    0.00 KB
...
@@ -248,8 +257,8 @@ class MetaInfoProp(torch.fx.Interpreter):
    def summary(self, unit: str = 'MB') -> str:
        """
        Summarizes the memory and FLOPs statistics of the `GraphModule` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
...
@@ -306,3 +315,38 @@ class MetaInfoProp(torch.fx.Interpreter):
        ]

        return tabulate(node_summaries, headers=headers, stralign='right')
+
+
+def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit: str = "MB", **kwargs) -> torch.fx.GraphModule:
+    """
+    MetaInfo tracing API
+
+    Given a ``GraphModule`` and a sample input, this API will trace the MetaInfo of a single training cycle
+    and annotate it on ``gm.graph``.
+
+    Usage:
+        >>> model = ...
+        >>> gm = symbolic_trace(model)
+        >>> args = ...    # sample input to the ``GraphModule``
+        >>> metainfo_trace(gm, *args)
+
+    Args:
+        gm (torch.fx.GraphModule): The ``GraphModule`` to be annotated with MetaInfo.
+        verbose (bool, optional): Whether to show ``MetaInfoProp.summary()``. Defaults to False.
+        unit (str, optional): The unit of memory. Defaults to "MB".
+
+    Returns:
+        torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
+    """
+    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+    interp = MetaInfoProp(gm.to(device))
+    if is_compatible_with_meta():
+        from colossalai.fx.profiler import MetaTensor
+        args = tree_map(lambda x: MetaTensor(x, fake_device=device), args)
+        kwargs = tree_map(lambda x: MetaTensor(x, fake_device=device), kwargs)
+    interp.propagate(*args, **kwargs)
+    if verbose:
+        interp.summary(unit)
+    gm.to('cpu')
+    del interp
+    return gm
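A short sketch of the new ``metainfo_trace`` entry point (the toy model and sample input are illustrative; annotations are assumed to land in ``node.meta``, as with ``MetaInfoProp``):

    import torch
    from torch.fx import symbolic_trace

    from colossalai.fx.passes.meta_info_prop import metainfo_trace

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
    gm = symbolic_trace(model)
    gm = metainfo_trace(gm, torch.rand(4, 16), verbose=True, unit="KB")

    for node in gm.graph.nodes:
        print(node.name, sorted(node.meta.keys()))    # per-node MetaInfo annotations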
colossalai/fx/passes/split_module.py  View file @ e532679c
...
@@ -38,11 +38,11 @@ def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[torch.fx.node.Node], int],
+    merge_output=False,
):
    """
    Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py
    Creates subgraphs out of main graph

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
...
@@ -52,52 +52,40 @@ def split_module(
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module.

    Returns:
        GraphModule: the module after split.

    Example:
        This is a sample setup:

            import torch
            from torch.fx.symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from colossalai.fx.passes.split_module import split_module

            class MyModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w

            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3

            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition

            # split module in module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        Output looks like this. Original graph is broken into partitions

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
...
@@ -108,7 +96,6 @@ def split_module(
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y);  x = param = y = None
...
@@ -119,10 +106,8 @@ def split_module(
                getitem_3 = submod_1[1];  submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
                return submod_2

        Output of split module is the same as output of input traced module.
        This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
...
@@ -147,6 +132,29 @@ def split_module(
                use_partition.inputs.setdefault(def_node.name)
                if def_partition_name is not None:
                    use_partition.partitions_dependent_on.setdefault(def_partition_name)

+    def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]):    # noqa: B950
+        def_partition_name = getattr(def_node, "_fx_partition", None)
+        use_partition_name = getattr(use_node, "_fx_partition", None)
+        if def_partition_name != use_partition_name:
+            if def_partition_name is not None:
+                def_partition = partitions[def_partition_name]
+                def_partition.outputs.setdefault(def_node.name)
+                if use_partition_name is not None:
+                    def_partition.partition_dependents.setdefault(use_partition_name)
+
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.inputs.setdefault(def_node.name)
+                if def_partition_name is not None:
+                    use_partition.partitions_dependent_on.setdefault(def_partition_name)
+                use_partition.outputs.setdefault(def_node.name)
+        else:
+            if use_partition_name is not None:
+                use_partition = partitions[use_partition_name]
+                use_partition.outputs.setdefault(def_node.name)
+
    # split nodes into partitions
    for node in m.graph.nodes:
...
@@ -155,7 +163,10 @@ def split_module(
        if node.op in ["placeholder"]:
            continue
        if node.op == 'output':
-            torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
+            if merge_output:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_output(n, node.prev))
+            else:
+                torch.fx.graph.map_arg(node.args[0], lambda n: record_cross_partition_use(n, None))
            continue
        partition_name = str(split_callback(node))
...
@@ -235,10 +246,10 @@ def split_module(
    for node in m.graph.nodes:
        if node.op == 'placeholder':
            if version.parse(torch.__version__) < version.parse('1.11.0'):
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target, type_expr=node.type)
            else:
                default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
-                base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
-                                                                     type_expr=node.type,
-                                                                     default_value=default_value)
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.target,
+                                                                     type_expr=node.type,
+                                                                     default_value=default_value)
            base_mod_env[node.name].meta = node.meta.copy()
...
@@ -278,4 +289,9 @@ def split_module(
        if node.op == 'output':
            base_mod_graph.output(torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]))    # noqa: B950

-    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+    for partition_name in sorted_partitions:
+        partition = partitions[partition_name]
+
+    new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
+
+    return new_gm
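The new ``merge_output`` flag only changes how the output node's arguments are recorded: ``record_output`` re-exports values produced in an earlier partition through the final partition, rather than treating them as plain cross-partition uses. A hedged sketch on a minimal setup (the toy module and partition policy are illustrative):

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.node import Node

    from colossalai.fx.passes.split_module import split_module

    traced = symbolic_trace(torch.nn.Sequential(
        torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 4)))

    counter = 0
    def mod_partition(node: Node) -> int:
        # first two non-placeholder nodes -> partition 0, the rest -> partition 1
        global counter
        counter += 1
        return 0 if counter <= 2 else 1

    merged = split_module(traced, None, mod_partition, merge_output=True)
    print(merged.graph)    # graph outputs are routed through the last submodule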
colossalai/fx/passes/utils.py  View file @ e532679c

import torch
-from typing import Dict, Set
+from typing import Dict
from torch.fx.node import Node, map_arg
from torch.fx.graph import Graph


def get_comm_size(prev_partition, next_partition):
    """
    Given two partitions (parent and child),
...
@@ -32,7 +31,6 @@ def get_comm_size(prev_partition, next_partition):
def get_leaf(graph: Graph):
    """
    Given a graph, return the leaf nodes of this graph.

    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from the fx graph,
    we will get a normal DAG. Leaf nodes in this context means leaf nodes in that DAG.
    """
...
@@ -57,7 +55,6 @@ def is_leaf(graph: Graph, node: Node):
def get_top(graph: Graph):
    """
    Given a graph, return the top nodes of this graph.

    Note: If we remove ``root`` nodes, ``placeholder`` nodes, and ``output`` nodes from the fx graph,
    we will get a normal DAG. Top nodes in this context means nodes with BFS level 0 in that DAG.
    """
...
@@ -100,7 +97,6 @@ def get_all_consumers(graph: Graph, node: Node):
def assign_bfs_level_to_nodes(graph: Graph):
    """
    Given a graph, assign a BFS level to each node of this graph excluding ``placeholder`` and ``output`` nodes.

    Example:
        class MLP(torch.nn.Module):
            def __init__(self, dim: int):
...
@@ -110,8 +106,6 @@ def assign_bfs_level_to_nodes(graph: Graph):
                self.linear3 = torch.nn.Linear(dim, dim)
                self.linear4 = torch.nn.Linear(dim, dim)
                self.linear5 = torch.nn.Linear(dim, dim)

            def forward(self, x):
                l1 = self.linear1(x)
                l2 = self.linear2(x)
...
@@ -165,10 +159,8 @@ def assign_bfs_level_to_nodes(graph: Graph):
def get_node_module(node) -> torch.nn.Module:
    """
    Find the module associated with the given node.

    Args:
        node (torch.fx.Node): a torch.fx.Node object in the fx computation graph

    Returns:
        torch.nn.Module: the module associated with the given node
    """
...
@@ -177,3 +169,4 @@ def get_node_module(node) -> torch.nn.Module:
    assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
    module = node.graph.owning_module.get_submodule(node.target)
    return module
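A minimal sketch of ``get_node_module`` on a traced module (the toy model is illustrative):

    import torch
    from torch.fx import symbolic_trace

    from colossalai.fx.passes.utils import get_node_module

    gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU()))
    for node in gm.graph.nodes:
        if node.op == 'call_module':
            # resolves node.target against the owning module, e.g. '0' -> Linear
            print(node.target, '->', type(get_node_module(node)).__name__)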
colossalai/fx/profiler/__init__.py  View file @ e532679c

from .._compatibility import is_compatible_with_meta

if is_compatible_with_meta():
-    from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
    from .opcount import flop_mapping
    from .profiler import profile_function, profile_method, profile_module
+    from .shard_utils import (
+        calculate_bwd_time,
+        calculate_fwd_in,
+        calculate_fwd_out,
+        calculate_fwd_time,
+        calculate_fwd_tmp,
+    )
    from .tensor import MetaTensor
else:
    from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out

from .dataflow import GraphInfo
-from .memory import activation_size, is_inplace, parameter_size
+from .memory_utils import activation_size, is_inplace, parameter_size
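The rename from ``.memory`` to ``.memory_utils`` (and the move of the per-phase helpers into ``.shard_utils``) leaves the public surface intact: the byte-size helpers are still importable from ``colossalai.fx.profiler``. A quick sketch (sizes assume fp32 tensors, 4 bytes per element):

    import torch
    from colossalai.fx.profiler import activation_size, parameter_size

    print(activation_size(torch.rand(4, 4)))        # 4 * 4 * 4 = 64 bytes
    print(parameter_size(torch.nn.Linear(4, 4)))    # (4*4 + 4) * 4 = 80 bytes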
colossalai/fx/profiler/dataflow.py  View file @ e532679c
...
@@ -6,7 +6,7 @@ from typing import Dict, List
from torch.fx import Graph, Node

from .._compatibility import compatibility
-from .memory import activation_size, is_inplace
+from .memory_utils import activation_size, is_inplace


class Phase(Enum):
...
@@ -29,7 +29,7 @@ class GraphInfo:
    placeholders saved for  |   |     \__________       |   |
    backward.               |   |                \      |   |
                            | [fwd_tmp] ------> [bwd_tmp] |      <-----
                            |   |      \_________        |   |   [bwd_tmp] marks the peak memory
                            |  / \               \       |   |   in backward pass.
    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |      <-----
    in [fwd_tmp] because    |   |       \_____          |   |
...
@@ -80,18 +80,18 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
    Nodes should have attribute `out` indicating the output of each node.
    ============================================================================
    Placeholder ---->  p           o     <---- We need to keep track of grad out
                       |\________  |
                       ↓         ↘|
                       f --------> b
                       |\ \_____   ↑
                       | \      ↘ /
                       f  f ----> b     <---- Not every forward result needs to be saved for backward
                       |   \____   ↑
                       ↘        ↘  |
                         f ----> b     <---- Backward can be freed as soon as it is required no more.
                              ↘ ↗
                               l
    =============================================================================
    Args:
        graph (Graph): The autograd graph with nodes marked for keyword `phase`.
...
colossalai/fx/profiler/experimental/__init__.py  View file @ e532679c

-from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .registry import meta_profiler_function, meta_profiler_module
+from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
colossalai/fx/profiler/experimental/profiler.py  View file @ e532679c
...
@@ -5,7 +5,7 @@ import torch
from torch.fx.node import Argument, Target

from ..._compatibility import compatibility
-from ..memory import activation_size
+from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module
...
@@ -27,7 +27,7 @@ class GraphInfo:
    placeholders saved for  |   |     \__________       |   |
    backward.               |   |                \      |   |
                            | [fwd_tmp] ------> [bwd_tmp] |      <-----
                            |   |      \_________        |   |   [bwd_tmp] marks the peak memory
                            |  / \               \       |   |   in backward pass.
    [x] is not counted ---> | [x]  [fwd_tmp] -> [bwd_tmp] |      <-----
    in [fwd_tmp] because    |   |   |   \_____          |   |
...
@@ -76,14 +76,14 @@ def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int
@compatibility(is_backward_compatible=True)
def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.

    Unfortunately, backward memory cost and FLOPs are estimated results.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Examples:
        >>> input = torch.rand(100, 100, 100, 100, device='meta')
        >>> func = torch.nn.functional.relu
...
@@ -142,13 +142,13 @@ def profile_method(target: 'Target') -> Callable:
@compatibility(is_backward_compatible=True)
def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Example:
        >>> input = torch.rand(4, 3, 224, 224, device='meta')
        >>> mod = torch.nn.Conv2d(3, 128, 3)
...
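Both wrappers require meta-device tensors, which carry only shape and dtype. Completing the docstrings' truncated examples under that assumption (plain meta-tensor execution is shown; the wrapped profilers' exact return values are not part of this diff):

    import torch

    inp = torch.rand(4, 3, 224, 224, device='meta')    # no real memory is allocated
    mod = torch.nn.Conv2d(3, 128, 3).to('meta')
    out = mod(inp)                                      # shape/dtype propagation only
    print(out.shape, out.device)                        # torch.Size([4, 128, 222, 222]) meta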
colossalai/fx/profiler/experimental/memory.py → colossalai/fx/profiler/experimental/shard_utils.py  View file @ e532679c
File moved