OpenDAS / ColossalAI, commit 9e768b59
Authored Oct 10, 2023 by zhuwenwen

    Merge branch 'main' of https://github.com/hpcaitech/ColossalAI

Parents: 7bc5a8e3, 8aed02b9
Showing 20 of 442 changed files, with 581 additions and 607 deletions.
colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py   +1   -1
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py       +56  -46
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py    +8   -6
colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py     +28  -15
colossalai/auto_parallel/meta_profiler/meta_registry/where.py      +16  -11
colossalai/auto_parallel/meta_profiler/registry.py                 +3   -5
colossalai/auto_parallel/meta_profiler/shard_metainfo.py           +14  -18
colossalai/auto_parallel/offload/amp_optimizer.py                  +38  -35
colossalai/auto_parallel/offload/base_offload_module.py            +6   -8
colossalai/auto_parallel/offload/mem_optimize.py                   +6   -8
colossalai/auto_parallel/offload/region.py                         +5   -5
colossalai/auto_parallel/offload/region_manager.py                 +62  -75
colossalai/auto_parallel/offload/runtime.py                        +34  -34
colossalai/auto_parallel/offload/solver.py                         +40  -65
colossalai/auto_parallel/offload/training_simulator.py             +48  -82
colossalai/auto_parallel/offload/util.py                           +10  -12
colossalai/auto_parallel/passes/comm_metainfo_pass.py              +34  -23
colossalai/auto_parallel/passes/meta_info_prop.py                  +12  -12
colossalai/auto_parallel/passes/runtime_apply_pass.py              +67  -64
colossalai/auto_parallel/passes/runtime_preparation_pass.py        +93  -82
colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py

@@ -3,7 +3,7 @@ from typing import List, Tuple
 import torch

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

 from ..registry import meta_register
...
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py

-from typing import Callable, Dict, List, Tuple, Union
+from typing import List, Tuple

 import torch

 from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
 from colossalai._analyzer.fx.node_util import compute_size_in_bytes
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
-from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem

 from ..registry import meta_register

-__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
+__all__ = ["batchnormnd_meta_info", "layernorm_meta_info"]


 @meta_register.register(torch.nn.BatchNorm1d)
...
@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
     # saved inv std and some other args indicating the status of the module
     # the bwd outputs are input grad, weight grad and bias grad
-    bwd_in_args = [output_tensor, output_tensor, weight_tensor, mean_tensor, var_tensor, mean_tensor, var_tensor, 1e-5, num_batch]
+    bwd_in_args = [
+        output_tensor,
+        output_tensor,
+        weight_tensor,
+        mean_tensor,
+        var_tensor,
+        mean_tensor,
+        var_tensor,
+        1e-5,
+        num_batch,
+    ]
     bwd_out_args = [input_tensor, weight_tensor, bias_tensor]
...
@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
     # calculate memory cost
     # the fwd activation cost is output plus saved mean and saved inv std
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
-                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
-                                 temp=0,
-                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+    fwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor, output_tensor, mean_tensor, var_tensor]),
+        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+        temp=0,
+        buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+    )

     # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
     # and saved inv std during backward phase
-    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]),
-                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
-                                 temp=compute_size_in_bytes([mean_tensor, var_tensor]),
-                                 buffer=compute_size_in_bytes([mean_tensor, var_tensor]))
+    bwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor]),
+        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+        temp=compute_size_in_bytes([mean_tensor, var_tensor]),
+        buffer=compute_size_in_bytes([mean_tensor, var_tensor]),
+    )

     # total cost is the sum of forward and backward cost
-    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
-                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
+    total_cost = MemoryCost(
+        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+    )

     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
-    fwd_buffer = [torch.zeros_like(mean_tensor, device='meta'), torch.zeros_like(var_tensor, device='meta')]
-    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+    fwd_buffer = [torch.zeros_like(mean_tensor, device="meta"), torch.zeros_like(var_tensor, device="meta")]
+    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

     return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
...
@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
     weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
     bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
-    running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
-    running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
+    running_mean = torch.rand(input_tensor.shape[0], 1, device="meta")
+    running_var = torch.rand(input_tensor.shape[0], 1, device="meta")

     # construct args
     fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
...
@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
     # memory cost
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
-                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
-                                 temp=0,
-                                 buffer=compute_size_in_bytes([running_mean, running_var]))
+    fwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor, output_tensor, weight_tensor, bias_tensor]),
+        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+        temp=0,
+        buffer=compute_size_in_bytes([running_mean, running_var]),
+    )

-    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
-                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
-                                 temp=compute_size_in_bytes([running_mean, running_var]),
-                                 buffer=compute_size_in_bytes([running_mean, running_var]))
+    bwd_memory_cost = MemoryCost(
+        activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
+        parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
+        temp=compute_size_in_bytes([running_mean, running_var]),
+        buffer=compute_size_in_bytes([running_mean, running_var]),
+    )

-    total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
-                            parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
-                            temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
-                            buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+    total_cost = MemoryCost(
+        activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+        parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+        temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+        buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
+    )

     memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)

     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
-    fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
-    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+    fwd_buffer = [torch.zeros_like(running_mean, device="meta"), torch.zeros_like(running_var, device="meta")]
+    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

     return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
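
Aside: every meta-info function in these registry files follows the same recipe: build a MemoryCost for the forward phase and one for the backward phase, sum them field-wise into a total, and wrap all three in a TrainCycleItem. A minimal sketch of that composition, using simplified stand-in dataclasses rather than ColossalAI's actual classes:

    from dataclasses import dataclass

    @dataclass
    class MemoryCost:
        activation: int = 0
        parameter: int = 0
        temp: int = 0
        buffer: int = 0

    @dataclass
    class TrainCycleItem:
        fwd: MemoryCost
        bwd: MemoryCost
        total: MemoryCost

    def combine(fwd: MemoryCost, bwd: MemoryCost) -> TrainCycleItem:
        # total cost is the element-wise sum of the forward and backward costs,
        # mirroring how the meta-info functions above build `total_cost`
        total = MemoryCost(
            activation=fwd.activation + bwd.activation,
            parameter=fwd.parameter + bwd.parameter,
            temp=fwd.temp + bwd.temp,
            buffer=fwd.buffer + bwd.buffer,
        )
        return TrainCycleItem(fwd=fwd, bwd=bwd, total=total)

The solver can then read per-phase and total costs for an operator off a single object.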
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
...
@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     # store fwd_in, fwd_buffer, fwd_out
     fwd_in = []
     fwd_buffer = []
-    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

     return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
...
@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix]))

     # temp memory for backward is the index matrix to be discarded
-    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
-                              temp=compute_size_in_bytes(index_matrix))
+    bwd_mem_cost = MemoryCost(
+        activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix),
+        temp=compute_size_in_bytes(index_matrix),
+    )

     # total cost
     total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp)
...
@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
     mem_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)

     # store fwd_in, fwd_buffer, fwd_out
-    fwd_in = [torch.zeros_like(input_tensor, device='meta')]
-    fwd_buffer = [torch.zeros_like(index_matrix, device='meta')]
-    fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+    fwd_in = [torch.zeros_like(input_tensor, device="meta")]
+    fwd_buffer = [torch.zeros_like(index_matrix, device="meta")]
+    fwd_out = [torch.zeros_like(output_tensor, device="meta")]

     return compute_cost, mem_cost, fwd_in, fwd_buffer, fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
...
@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
 import torch

 from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
 from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
...
@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
     fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0)

-    bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
-                              parameter=0,
-                              temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
-                              buffer=0)
+    bwd_mem_cost = MemoryCost(
+        activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor,
+        parameter=0,
+        temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor,
+        buffer=0,
+    )

-    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
-                                parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
-                                temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
-                                buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+    total_mem_cost = MemoryCost(
+        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+        temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+        buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+    )

     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
...
@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
 # register torch.Tensor related metainfo
 # (0, 0)
-meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
-                        torch.arange])(tensor_related_metainfo(0, 0))
+meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze, torch.arange])(
+    tensor_related_metainfo(0, 0)
+)

 # (1, 0)
-meta_register.register([torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose,
-                        torch.Tensor.permute, torch.permute, torch.Tensor.split, torch.split,
-                        torch.Tensor.view])(tensor_related_metainfo(1, 0))
+meta_register.register(
+    [
+        torch.Tensor.flatten,
+        torch.flatten,
+        torch.Tensor.transpose,
+        torch.transpose,
+        torch.Tensor.permute,
+        torch.permute,
+        torch.Tensor.split,
+        torch.split,
+        torch.Tensor.view,
+    ]
+)(tensor_related_metainfo(1, 0))

 # (1, 1)
 meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
colossalai/auto_parallel/meta_profiler/meta_registry/where.py
...
@@ -4,7 +4,7 @@ import torch
 from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
 from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

 from ..registry import meta_register
...
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
     # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
     # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
     fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
-    bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
-                              parameter=0,
-                              temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
-                              activation_size([x_tensor, y_tensor]),
-                              buffer=0)
+    bwd_mem_cost = MemoryCost(
+        activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
+        parameter=0,
+        temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) - activation_size([x_tensor, y_tensor]),
+        buffer=0,
+    )

-    total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
-                                parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
-                                temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
-                                buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+    total_mem_cost = MemoryCost(
+        activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+        parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+        temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+        buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
+    )

     memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
...
colossalai/auto_parallel/meta_profiler/registry.py

-__all__ = ['Registry']
+__all__ = ["Registry"]


 class Registry:
     def __init__(self, name):
         self.name = name
         self.store = {}

     def register(self, source):
         def wrapper(func):
             if isinstance(source, (list, tuple)):
                 # support register a list of items for this func
...
@@ -21,7 +19,7 @@ class Registry:
         return wrapper

     def get(self, source):
-        assert source in self.store, f'{source} not found in the {self.name} registry'
+        assert source in self.store, f"{source} not found in the {self.name} registry"
         target = self.store[source]
         return target
...
@@ -29,4 +27,4 @@ class Registry:
         return source in self.store


-meta_register = Registry('meta')
+meta_register = Registry("meta")
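
Aside: meta_register is an instance of this Registry, and registration happens through the decorator that register() returns. A sketch of typical usage; relu_meta_info is a hypothetical example name, not part of this commit:

    import torch
    from colossalai.auto_parallel.meta_profiler.registry import meta_register

    @meta_register.register(torch.nn.ReLU)  # register() also accepts a list of targets
    def relu_meta_info(*args, **kwargs):
        # hypothetical meta-info function; the registered callables in this repo
        # return (compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out)
        ...

    meta_func = meta_register.get(torch.nn.ReLU)  # asserts if the target is unregistered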
colossalai/auto_parallel/meta_profiler/shard_metainfo.py
...
@@ -2,20 +2,13 @@ from typing import Callable, List
 import torch

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    MemoryCost,
-    OperationData,
-    OperationDataType,
-    ShardingStrategy,
-    StrategiesVector,
-    TrainCycleItem,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem
 from colossalai.tensor.sharding_spec import ShardingSpec

 from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION
 from .registry import meta_register

-__all__ = ['ShardMetaInfo']
+__all__ = ["ShardMetaInfo"]


 class ShardMetaInfo:
...
@@ -76,10 +69,12 @@ class ShardMetaInfo:
         """
         if isinstance(sharding_spec, ShardingSpec):
-            op_data = OperationData(name=operation_data.name,
-                                    data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
-                                    type=operation_data.type,
-                                    logical_shape=operation_data.logical_shape)
+            op_data = OperationData(
+                name=operation_data.name,
+                data=torch.zeros(sharding_spec.get_sharded_shape_per_device(), device="meta"),
+                type=operation_data.type,
+                logical_shape=operation_data.logical_shape,
+            )
         elif isinstance(sharding_spec, (list, tuple)):
             data = operation_data.data
             assert isinstance(data, (list, tuple)), f"Data Should be list or tuple, but got {type(data)}."
...
@@ -97,8 +92,9 @@ class ShardMetaInfo:
         """
         Compute meta info based on sharding strategy and the given target function.
         """
-        assert meta_register.has(self._target.__class__) or meta_register.has(self._target), \
-            f"Meta info for {self._target} is not registered."
+        assert meta_register.has(self._target.__class__) or meta_register.has(
+            self._target
+        ), f"Meta info for {self._target} is not registered."
         if meta_register.has(self._target.__class__):
             # module
             meta_func = meta_register.get(self._target.__class__)
...
@@ -117,11 +113,11 @@ class ShardMetaInfo:
         # construct kwargs
         if self.target in INPLACE_MODULE:
-            kwargs = {'inplace': self.target.inplace}
+            kwargs = {"inplace": self.target.inplace}
         elif self.target in INPLACE_OPS:
-            kwargs = {'inplace': True}
+            kwargs = {"inplace": True}
         else:
-            kwargs = {'inplace': False}
+            kwargs = {"inplace": False}

         # compute metainfo with meta_func
         self.compute_cost, self.memory_cost, self.fwd_in, self.fwd_buffer, self.fwd_out = meta_func(*args, **kwargs)
...
colossalai/auto_parallel/offload/amp_optimizer.py

-from typing import Dict, Tuple
 from enum import Enum
+from typing import Dict, Tuple

 import torch
 from torch.optim import Optimizer

-from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.interface import OptimizerWrapper
+from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device

 from .base_offload_module import BaseOffloadModule
-from .region_manager import RegionManager
 from .region import Region
+from .region_manager import RegionManager


 class OptimState(Enum):
     SCALED = 0
     UNSCALED = 1


-class AMPOptimizer(ColossalaiOptimizer):
+class AMPOptimizer(OptimizerWrapper):
     """
     A wrapper for Optimizer.
     Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
...
@@ -36,7 +37,8 @@ class AMPOptimizer(ColossalaiOptimizer):
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
     """

-    def __init__(self,
+    def __init__(
+        self,
         optimizer: Optimizer,
         module: BaseOffloadModule,
         initial_scale: float = 2**16,
...
@@ -47,8 +49,8 @@ class AMPOptimizer(ColossalaiOptimizer):
         min_scale: float = 1,
         max_scale: float = 2**32,
         clipping_norm: float = 0.0,
-        norm_type: float = 2.0):
+        norm_type: float = 2.0,
+    ):
         super().__init__(optimizer)
         self.module = module
...
@@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer):
         self.__init__optimizer()

         # Grad scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale)
+        self.grad_scaler = DynamicGradScaler(
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+        )
         self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
         self._logger = get_dist_logger()

     def _set_grad_ptr(self):
         for group in self.param_groups:
-            for fake_param in group['params']:
+            for fake_param in group["params"]:
                 region = self.param_to_region[fake_param]
                 begin, end = self.param_to_range[fake_param]
...
@@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer):
     def _update_fp16_params(self):
         none_tensor = torch.empty([0])
         for group in self.param_groups:
-            for fake_param in group['params']:
+            for fake_param in group["params"]:
                 assert fake_param.grad is None
                 fake_param.data = none_tensor
                 self.param_to_region[fake_param].cpu_grad = None
...
@@ -131,7 +135,7 @@ class AMPOptimizer(ColossalaiOptimizer):
         if found_inf:
             self.optim_state = OptimState.UNSCALED    # no need to unscale grad
             self.grad_scaler.update(found_inf)    # update gradient scaler
-            self._logger.info(f'Found overflow. Skip step')
+            self._logger.info(f"Found overflow. Skip step")
             self.zero_grad()    # reset all gradients
             self._update_fp16_params()
             return
...
@@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer):
         self.module.backward(loss)

     def __init__optimizer(self):
-
         for group in self.optim.param_groups:
             fake_params_list = list()

-            for param in group['params']:
+            for param in group["params"]:
                 region = self.region_manager.get_region(param)
                 fake_param = torch.nn.Parameter(torch.empty([0]))
                 self.param_to_range[fake_param] = region.param_to_range[param]
...
@@ -170,7 +173,7 @@ class AMPOptimizer(ColossalaiOptimizer):
                 if param in self.optim.state:
                     self.optim.state[fake_param] = self.optim.state.pop(param)

-            group['params'] = fake_params_list
+            group["params"] = fake_params_list

         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
...
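
Aside: the step flow this wrapper implements (scale the loss, detect overflow, skip the step and rescale on overflow) is the standard dynamic loss-scaling loop. A runnable sketch of the same flow using PyTorch's built-in GradScaler as a stand-in for ColossalAI's DynamicGradScaler (requires a CUDA device; the model and data are placeholders):

    import torch

    model = torch.nn.Linear(4, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(init_scale=2**16)

    for _ in range(3):
        optimizer.zero_grad()
        with torch.autocast("cuda", dtype=torch.float16):
            loss = model(torch.randn(8, 4, device="cuda")).sum()
        scaler.scale(loss).backward()   # scale the loss before backward
        scaler.step(optimizer)          # skips the step if an overflow was found
        scaler.update()                 # grow/backoff the scale dynamically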
colossalai/auto_parallel/offload/base_offload_module.py
...
@@ -4,7 +4,7 @@ from typing import Optional, Set
 import torch
 import torch.nn as nn

-from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.utils import _cast_float
 from colossalai.zero.legacy.gemini.tensor_utils import free_storage

 from .region_manager import RegionManager
...
@@ -22,7 +22,6 @@ class BaseOffloadModule:
     """

     def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):
-
         self.model = model
         self.region_manager = region_manager
         self.grad_hook_list = []
...
@@ -91,17 +90,16 @@ class BaseOffloadModule:
     def parameters(self, recurse: bool = True):
         return self.model.parameters(recurse)

-    def named_parameters(self, prefix: str = '', recurse: bool = True):
+    def named_parameters(self, prefix: str = "", recurse: bool = True):
         return self.model.named_parameters(prefix, recurse)

-    def named_buffers(self, prefix: str = '', recurse: bool = True):
+    def named_buffers(self, prefix: str = "", recurse: bool = True):
         return self.model.named_buffers(prefix, recurse)

     def named_children(self):
         return self.model.named_children()

-    def named_modules(self,
-                      memo: Optional[Set[torch.nn.Module]] = None,
-                      prefix: str = '',
-                      remove_duplicate: bool = True):
+    def named_modules(
+        self, memo: Optional[Set[torch.nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
+    ):
         return self.model.named_modules(memo, prefix, remove_duplicate)
colossalai/auto_parallel/offload/mem_optimize.py
...
@@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_
 from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem


-def memory_optimize(model: torch.nn.Module,
-                    inps: Dict[str, torch.Tensor],
-                    memory_budget: float = -1.0,
-                    solver_name: str = 'asyn'):
-
+def memory_optimize(
+    model: torch.nn.Module, inps: Dict[str, torch.Tensor], memory_budget: float = -1.0, solver_name: str = "asyn"
+):
     model = model.cpu().half()
     tracer = ColoTracer()
     assert is_compatible_with_meta()
...
@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
         f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")

-    if solver_name == 'syn':
+    if solver_name == "syn":
         gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
-    elif solver_name == 'asyn':
+    elif solver_name == "asyn":
         gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
     else:
         raise TypeError(f"Unknown solver name {solver_name}!")

     gm.recompile()
-    optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
+    optimized_model = BaseOffloadModule(gm, region_manager, solver_name == "syn")
     return optimized_model
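
Aside: a sketch of how this entry point would be driven, based only on the signatures visible in this diff; the toy model, and the assumption that the inps keys match the model's forward() argument names, are illustrative rather than verified:

    import torch
    from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
    from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer

    class MLP(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(1024, 1024)

        def forward(self, x):
            return self.fc(x).relu()

    model = MLP()
    inps = {"x": torch.randn(16, 1024)}  # assumed: keyed by forward() argument name

    # memory_budget is in bytes; "asyn" picks the prefetching solver, "syn" the synchronous one
    offload_model = memory_optimize(model, inps, memory_budget=1024**3, solver_name="asyn")
    optim = AMPOptimizer(torch.optim.Adam(offload_model.parameters(), lr=1e-3), offload_model)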
colossalai/auto_parallel/offload/region.py
...
@@ -55,13 +55,13 @@ class Region:
         Map the parameters in the region to a contiguous memory space.
         """

-        self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
+        self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device="cuda")
         offset = 0
         for param in self.fp16_params:
             param.data = param.data.cuda()
             p_num = param.data.numel()
-            self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
-            param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
+            self.fp16_data[offset : offset + p_num].copy_(param.data.flatten())
+            param.data = self.fp16_data[offset : offset + p_num].view(param.data.shape)
             self.param_to_range[param] = (offset, offset + p_num)
             offset += p_num
...
@@ -83,7 +83,7 @@ class Region:
         self.temp_fp32_data.record_stream(torch.cuda.current_stream())
         if not self.in_mem_pool_flag:
             alloc_storage(self.fp16_data)
-        self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
+        self.fp16_data[: self.param_num].copy_(self.temp_fp32_data)
         self.fp16_data.record_stream(torch.cuda.current_stream())

         self.__update_params_ptr()
...
@@ -94,7 +94,7 @@ class Region:
         """
         self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
-        self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
+        self.cpu_grad.copy_(self.fp16_data[: self.param_num], non_blocking=True)
         self.fp16_data.record_stream(torch.cuda.current_stream())
         if not self.in_mem_pool_flag:
             self.free_cuda_data()
...
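
Aside: the fp16_data mapping above is the flat-buffer trick: copy every parameter into one contiguous tensor, then re-point each param.data at a view of that buffer, so a single transfer or free covers the whole region. A standalone sketch of the idea (CPU tensors, so it runs anywhere):

    import torch

    params = [torch.nn.Parameter(torch.randn(3, 4)), torch.nn.Parameter(torch.randn(5))]
    total = sum(p.numel() for p in params)

    flat = torch.zeros(total)          # the contiguous backing buffer
    offset = 0
    for p in params:
        n = p.numel()
        flat[offset : offset + n].copy_(p.data.flatten())  # copy values in
        p.data = flat[offset : offset + n].view(p.shape)   # re-point at a view
        offset += n

    # mutating the flat buffer is now visible through every parameter
    flat.zero_()
    assert all(p.data.abs().sum() == 0 for p in params)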
colossalai/auto_parallel/offload/region_manager.py

-from typing import List, Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple

 import torch
 from torch.fx import Graph, Node

+from .region import Region
 from .solver import SolverFactory
 from .training_simulator import TrainingSimulator
-from .region import Region
 from .util import NodeInfo
...
@@ -19,14 +20,9 @@ class RegionManager:
         cnode (List[str], optional): Common node List, should be the subset of input.
     """

-    def __init__(self,
-                 graph: Graph,
-                 solver_name: str = 'asyn',
-                 memory_budget: float = -1.0,
-                 cnode: List[str] = None):
-
+    def __init__(self, graph: Graph, solver_name: str = "asyn", memory_budget: float = -1.0, cnode: List[str] = None):
         self.graph = graph
-        assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+        assert graph.owning_module is not None, "The given graph is not associated with a owning_module"
         self.root_module = self.graph.owning_module
         self.nodes = list(graph.nodes)
         self.cnode = cnode
...
@@ -39,7 +35,7 @@ class RegionManager:
         self.memory_budget = memory_budget
         self.solver_name = solver_name
-        self.require_pool: bool = solver_name == 'asyn'
+        self.require_pool: bool = solver_name == "asyn"
         self.reg_to_block: Dict[int, int] = dict()
...
@@ -61,22 +57,19 @@ class RegionManager:
         self._post_process(solver.best_ts)

     def _pre_process(self):
-
         init_region_list = self._linearize_graph()

         if len(self.shared_region_pairs) > 1:
-            raise NotImplementedError('The current version only considers at most one pair of parameter sharing.')
+            raise NotImplementedError("The current version only considers at most one pair of parameter sharing.")

         elif len(self.shared_region_pairs) == 1:
             shared_regs = self.shared_region_pairs[0]
-            assert shared_regs[0].shared_rid == shared_regs[1].r_id \
-                and shared_regs[1].shared_rid == shared_regs[0].r_id
+            assert shared_regs[0].shared_rid == shared_regs[1].r_id and shared_regs[1].shared_rid == shared_regs[0].r_id
             fst_id = shared_regs[0].r_id
             lst_id = shared_regs[1].r_id
-            regs_left_out = init_region_list[:fst_id + 1]
+            regs_left_out = init_region_list[: fst_id + 1]
             regs_right_out = init_region_list[lst_id:]
-            hold_regs = init_region_list[fst_id + 1:lst_id]
+            hold_regs = init_region_list[fst_id + 1 : lst_id]
         else:
             regs_left_out = []
             regs_right_out = []
...
@@ -122,12 +115,9 @@ class RegionManager:
         it may not find a suitable region placement strategy for the given execution flow.
         """

-        reg_flow = torch.cat(
-            [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
-        mem_block_num = torch.max(
-            torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
-        coexist_matrix = torch.logical_or(
-            ts.fwd_reg_flow, ts.bwd_reg_flow)
+        reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
+        mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
+        coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)

         block_to_regs = {}
         for block_idx in range(mem_block_num):
...
@@ -135,8 +125,7 @@ class RegionManager:
         for reg in self.region_list:
             if reg.r_id in self.rid_in_pool:
                 cur_reg_appears = coexist_matrix[:, reg.r_id]
-                cur_reg_coexists = torch.sum(
-                    coexist_matrix[cur_reg_appears], dim=0).bool()
+                cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
                 for block_idx in range(mem_block_num):
                     if not any(cur_reg_coexists[block_to_regs[block_idx]]):
                         block_to_regs[block_idx].append(reg.r_id)
...
@@ -145,9 +134,12 @@ class RegionManager:
                 if reg.r_id not in self.reg_to_block:
                     raise NotImplementedError(
-                        f'can not find a block from the memory pool to store parameters of the region')
-        self.memory_pool = torch.chunk(torch.zeros(int(mem_block_num * self.mem_block_size / 2),
-                                                   dtype=torch.half,
-                                                   device='cuda'),
-                                       chunks=int(mem_block_num))
+                        f"can not find a block from the memory pool to store parameters of the region"
+                    )
+        self.memory_pool = torch.chunk(
+            torch.zeros(int(mem_block_num * self.mem_block_size / 2), dtype=torch.half, device="cuda"),
+            chunks=int(mem_block_num),
+        )

     def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
         """
...
@@ -178,10 +170,9 @@ class RegionManager:
         return region_list

-    def _search_block_size(self,
-                           region_list: List[Region],
-                           search_interval_byte: int = 1024,
-                           search_range_byte: int = 128 * 1024**2) -> int:
+    def _search_block_size(
+        self, region_list: List[Region], search_interval_byte: int = 1024, search_range_byte: int = 128 * 1024**2
+    ) -> int:
         """
         Search for a suitable memory block size.
...
@@ -208,11 +199,10 @@ class RegionManager:
                     acc_wasted += blk_size - left
             return acc_wasted

-        param_size_list = [
-            region.param_size for region in region_list if region.r_id == region.shared_rid
-        ]
+        param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]

         start_size = max(param_size_list)
-        min_mem_waste = float('+inf')
+        min_mem_waste = float("+inf")
         best_block_size = start_size

         for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
...
@@ -229,7 +219,7 @@ class RegionManager:
         Initialize region data, which maps the parameters in the region to a contiguous memory space.
         """

-        self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
+        self.temp_fp32_data = torch.zeros(self.max_param_num, device="cuda", dtype=torch.float32)
         for region in self.region_list:
             pre_alloc_tensor = None
...
@@ -244,8 +234,7 @@ class RegionManager:
                 region.fp16_data = shared_region.fp16_data
                 region.fp32_data = shared_region.fp32_data
                 region.param_to_range = shared_region.param_to_range
-            region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
-            )
+            region.temp_fp32_data = self.temp_fp32_data[: region.param_num].detach()

         torch.cuda.empty_cache()
...
@@ -259,13 +248,14 @@ class RegionManager:
         former_reg, latter_reg = self.shared_region_pairs[0]
         assert latter_reg.param_num >= former_reg.param_num
         embedding_node = former_reg.nodes[-1]
-        assert embedding_node.op == 'call_module' and isinstance(
-            self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
+        assert embedding_node.op == "call_module" and isinstance(
+            self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding
+        )
         if latter_reg.param_num > former_reg.param_num:
             for idx, n in enumerate(latter_reg.nodes):
-                if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)) or \
-                        (n.op == 'call_function' and n.target is torch.nn.functional.linear):
+                if (n.op == "call_module" and isinstance(self.root_module.get_submodule(n.target), torch.nn.Linear)) or (
+                    n.op == "call_function" and n.target is torch.nn.functional.linear
+                ):
                     cut_node_idx = idx + 1
                     break
             assert len(latter_reg.fp16_params) == 2
...
@@ -273,7 +263,7 @@ class RegionManager:
             for p in new_reg.fp16_params:
                 self.param_region_map[p] = new_reg
             self.region_list.insert(new_reg.r_id, new_reg)
-            for reg in self.region_list[new_reg.r_id + 1:]:
+            for reg in self.region_list[new_reg.r_id + 1 :]:
                 reg.r_id += 1
             latter_reg.shared_rid = former_reg.r_id
             former_reg.shared_rid = latter_reg.r_id
...
@@ -362,14 +352,12 @@ class RegionManager:
         """

         def _is_inplace(n: Node):
-            """Get the inplace argument from ``torch.fx.Node``
-            """
+            """Get the inplace argument from ``torch.fx.Node``"""
             inplace = False
             if n.op == "call_function":
                 inplace = n.kwargs.get("inplace", False)
             elif n.op == "call_module":
-                inplace = getattr(n.graph.owning_module.get_submodule(
-                    n.target), "inplace", False)
+                inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
             return inplace

         label = False
...
@@ -385,21 +373,23 @@ class RegionManager:
             elif n.op == "call_function":
                 label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
-                    map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
+                    map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes)
+                )

             return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))

         def _exception_node_handling():
             # TODO meta info prop bug
-            if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
-                n.meta['fwd_out'] = []
+            if n.name.__contains__("transpose") and n.meta["fwd_out"][0].dim() <= 2:
+                n.meta["fwd_out"] = []

         # make sure that item in cnode is valid
         if self.cnode:
             for name in self.cnode:
                 try:
-                    assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
-                        f"Common node {name} is not an input of the model."
+                    assert (
+                        next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
+                    ), f"Common node {name} is not an input of the model."
                 except StopIteration:
                     raise ValueError(f"Common node name {name} not in graph.")
         else:
...
@@ -428,8 +418,8 @@ class RegionManager:
                     ns = []
                     border_n_idx = region.nodes.index(act_n)
                     if border_n_idx < len(region.nodes):
-                        ns = region.nodes[border_n_idx + 1:]
-                        region.nodes = region.nodes[:border_n_idx + 1]
+                        ns = region.nodes[border_n_idx + 1 :]
+                        region.nodes = region.nodes[: border_n_idx + 1]
                     region_list.append(region)
                     region_id += 1
                     region = Region(r_id=region_id)
...
@@ -448,19 +438,21 @@ class RegionManager:
                     region = Region(r_id=region_id)

                 # propagate common node attr if possible
-                if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
-                                                  ]) or _is_cop(n.target):
+                if len(n.all_input_nodes) == len(
+                    [node for node in n.all_input_nodes if node.name in self.cnode]
+                ) or _is_cop(n.target):
                     self.cnode.append(n.name)
                 else:
-                    deps[n] = len(
-                        [user for user in n.users if user.op != "output"])
+                    deps[n] = len([user for user in n.users if user.op != "output"])

                 # propagate param node attr if possible
-                if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
-                                                  ]) or n.op == "get_attr":
+                if (
+                    len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops])
+                    or n.op == "get_attr"
+                ):
                     self.only_param_ops.append(n.name)
-                    param_op_deps[n] = len(
-                        [user for user in n.users if user.op != "output"])
+                    param_op_deps[n] = len([user for user in n.users if user.op != "output"])

                 # record last activation node
                 if _is_act(n._meta_data):
...
@@ -472,19 +464,16 @@ class RegionManager:
         return region_list

     def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
         cur_n.node_info = NodeInfo(node_id)

-        if cur_n.op == 'call_module':
+        if cur_n.op == "call_module":
             target = cur_n.target
             submod = self.root_module.get_submodule(target)
             for p in list(submod.parameters(recurse=False)):
                 if p in self.param_region_map:
                     cur_reg.shared_rid = self.param_region_map[p].r_id
                     self.param_region_map[p].shared_rid = cur_reg.r_id
-                    self.shared_region_pairs.append(
-                        (self.param_region_map[p], cur_reg))
+                    self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
                 else:
                     self.param_region_map[p] = cur_reg
...
@@ -499,12 +488,10 @@ class RegionManager:
                 attr_itr = getattr(attr_itr, atom)
             if isinstance(attr_itr, torch.nn.Parameter):
                 if attr_itr in self.param_region_map:
                     cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
                     self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
-                    self.shared_region_pairs.append(
-                        (self.param_region_map[attr_itr], cur_reg))
+                    self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
                 else:
                     self.param_region_map[attr_itr] = cur_reg
...
colossalai/auto_parallel/offload/runtime.py
...
@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_, fwd_info, bwd_info):
         ctx.bwd_info = bwd_info
-        d2h_rid = fwd_info.get('d2h_rid', None)
+        d2h_rid = fwd_info.get("d2h_rid", None)
         if d2h_rid is not None:
             free_region = GlobalRuntimeInfo().region_list[d2h_rid]
             assert isinstance(free_region, Region)
             free_region.free_cuda_data()

-        h2d_rid = fwd_info.get('h2d_rid', None)
+        h2d_rid = fwd_info.get("h2d_rid", None)
         if h2d_rid is not None:
             h2d_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(h2d_region, Region)
...
@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-
-        h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+        h2d_rid = ctx.bwd_info.get("h2d_rid", None)
         if h2d_rid is not None:
             pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
...
@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
     def forward(ctx, input_, fwd_info, bwd_info):
         ctx.bwd_info = bwd_info

-        sync_rid = fwd_info.get('sync_rid', None)
+        sync_rid = fwd_info.get("sync_rid", None)
         if sync_rid is not None:
             prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None)
             if prefetch_event:
                 prefetch_event.wait()

-        h2d_rid = fwd_info.get('h2d_rid', None)
+        h2d_rid = fwd_info.get("h2d_rid", None)
         if h2d_rid is not None:
             pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
...
@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-
-        sync_rid = ctx.bwd_info.get('sync_rid', None)
+        sync_rid = ctx.bwd_info.get("sync_rid", None)
         if sync_rid is not None:
             wait_region = GlobalRuntimeInfo().region_list[sync_rid]
             assert isinstance(wait_region, Region)
...
@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
             else:
                 wait_region.move_param_to_cuda()

-        h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+        h2d_rid = ctx.bwd_info.get("h2d_rid", None)
         if h2d_rid is not None:
             pref_region = GlobalRuntimeInfo().region_list[h2d_rid]
             assert isinstance(pref_region, Region)
...
@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):

 def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
-    '''
+    """
     Convert Upload and Offload operation into runtime action.

     Argument:
...
@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
             that need to be uploaded, or freed during forward pass.
         bwd_info(dict): information dict, which contains region indices
             that need to be uploaded during backward pass.
-    '''
+    """
     with torch._C.DisableTorchFunction():
         ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
     return ret


 def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
-    '''
+    """
     Convert Prefetch and Offload operation into runtime action.

     Argument:
...
@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
             that need to be prefetched, waited, or freed during forward pass.
         bwd_info(dict): information dict, which contains region indices
             that need to be prefetched or waited during backward pass.
-    '''
+    """
     with torch._C.DisableTorchFunction():
         ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
     return ret
...
@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
             # forward upload
             fwd_info = {}
             if requires_upload_p_in_fwd(region_list[region.shared_rid]):
-                fwd_info['h2d_rid'] = region.r_id
+                fwd_info["h2d_rid"] = region.r_id

             # forward offload
             if r_idx > 0 and region_list[r_idx - 1].need_offload:
-                fwd_info['d2h_rid'] = r_idx - 1
+                fwd_info["d2h_rid"] = r_idx - 1

             bwd_info = {}
             # backward upload
             if r_idx > 0 and region_list[r_idx - 1].need_offload:
-                bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
+                bwd_info["h2d_rid"] = region_list[r_idx - 1].r_id

             if fwd_info or bwd_info:
                 with mod_graph.inserting_after(last_inp_node):
-                    new_node = mod_graph.create_node('call_function',
-                                                     convert_fwd_upload_bwd_offload_to_action,
-                                                     args=(last_inp_node, fwd_info, bwd_info))
+                    new_node = mod_graph.create_node(
+                        "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)
+                    )
                 replace_node_users(last_inp_node, new_node)

             last_inp_node = region.nodes[-1]
...
@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
     first_region_with_p = [region for region in region_list if region.param_size][0]
     fwd_info = {"h2d_rid": first_region_with_p.r_id}
     with mod_graph.inserting_after(last_inp_node):
-        upload_apply_node = mod_graph.create_node('call_function',
-                                                  convert_fwd_upload_bwd_offload_to_action,
-                                                  args=(last_inp_node, fwd_info, {}))
+        upload_apply_node = mod_graph.create_node(
+            "call_function", convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})
+        )
     replace_node_users(last_inp_node, upload_apply_node)
     last_inp_node = upload_apply_node
...
@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
         # forward prefetch
         fwd_info = {}
         if region.param_size:
-            fwd_info['sync_rid'] = region.r_id
+            fwd_info["sync_rid"] = region.r_id
         fwd_prefetch_region = region.fwd_prefetch_region
         if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
-            fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
+            fwd_info["h2d_rid"] = fwd_prefetch_region.r_id

         # forward offload
         if r_idx > 0 and region_list[r_idx - 1].need_offload:
-            fwd_info['d2h_rid'] = r_idx - 1
+            fwd_info["d2h_rid"] = r_idx - 1

         bwd_info = {}
         # backward prefetch
         if r_idx > 0 and region_list[r_idx - 1].need_offload:
-            bwd_info['sync_rid'] = r_idx - 1
+            bwd_info["sync_rid"] = r_idx - 1
         if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region:
-            bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id
+            bwd_info["h2d_rid"] = region_list[r_idx - 1].bwd_prefetch_region.r_id

         if fwd_info or bwd_info:
             with mod_graph.inserting_after(last_inp_node):
-                new_node = mod_graph.create_node('call_function',
-                                                 convert_fwd_prefetch_bwd_offload_to_action,
-                                                 args=(last_inp_node, fwd_info, bwd_info))
+                new_node = mod_graph.create_node(
+                    "call_function",
+                    convert_fwd_prefetch_bwd_offload_to_action,
+                    args=(last_inp_node, fwd_info, bwd_info),
+                )
             replace_node_users(last_inp_node, new_node)

         last_inp_node = region.nodes[-1]

     if region.bwd_prefetch_region:
-        bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
+        bwd_info = {"h2d_rid": region.bwd_prefetch_region.r_id}
         with mod_graph.inserting_after(last_inp_node):
-            new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
-                                             args=(last_inp_node, {}, bwd_info))
+            new_node = mod_graph.create_node(
+                "call_function", convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)
+            )
         replace_node_users(last_inp_node, new_node)
     # gm.graph.print_tabular()
     return gm
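
Aside: both runtime ops are identity autograd Functions whose side effects fire before the forward and after the backward of the node they wrap; that is how these passes schedule uploads, offloads, and prefetch waits around each region. A minimal standalone sketch of the pattern, with prints standing in for the H2D/D2H actions:

    import torch

    class PreFwdPostBwd(torch.autograd.Function):
        """Identity on `input_`; runs side effects before fwd and after bwd."""

        @staticmethod
        def forward(ctx, input_, fwd_info, bwd_info):
            ctx.bwd_info = bwd_info
            print("pre-forward action:", fwd_info)   # e.g. upload/free a region
            return input_

        @staticmethod
        def backward(ctx, grad_output):
            print("post-backward action:", ctx.bwd_info)  # e.g. prefetch a region
            # one gradient per forward input; the two info dicts get no gradient
            return grad_output, None, None

    x = torch.randn(3, requires_grad=True)
    y = PreFwdPostBwd.apply(x, {"h2d_rid": 0}, {"h2d_rid": 1})
    y.sum().backward()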
colossalai/auto_parallel/offload/solver.py
View file @
9e768b59
import
time
from
typing
import
List
,
Dict
,
Type
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Type
NOT_NVML
=
False
try
:
...
...
@@ -10,10 +10,11 @@ except:
import
torch
from
torch.fx.node
import
Node
from
colossalai.utils.cuda
import
get_current_device
from
.training_simulator
import
TrainingSimulator
,
SynTrainingSimulator
,
AsynTrainingSimulator
from
.region
import
Region
from
.training_simulator
import
AsynTrainingSimulator
,
SynTrainingSimulator
,
TrainingSimulator
from
.util
import
NodeInfo
,
NvDevicePower
...
...
@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
"""
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
error_factor
:
float
=
0.95
)
->
None
:
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
error_factor
:
float
=
0.95
)
->
None
:
self
.
region_list
=
region_list
self
.
error_factor
:
float
=
error_factor
if
memory_budget
>
0
:
self
.
memory_budget
=
memory_budget
*
self
.
error_factor
else
:
self
.
memory_budget
=
torch
.
cuda
.
get_device_properties
(
get_current_device
()).
total_memory
*
self
.
error_factor
self
.
memory_budget
=
torch
.
cuda
.
get_device_properties
(
get_current_device
()).
total_memory
*
self
.
error_factor
self
.
link_to_bandwidth
:
Dict
[
str
,
Dict
[
float
,
float
]]
=
self
.
_profile_bandwidth
()
self
.
comp_power
:
float
=
self
.
_extract_computing_power
()
...
...
@@ -94,7 +90,7 @@ class Solver(ABC):
if
extra_cost
==
0
:
# means data transfer overhead can be completely overlapped
return
(
float
(
'
inf
'
),
total_mem_saving
,
peak_mem_saving
)
return
(
float
(
"
inf
"
),
total_mem_saving
,
peak_mem_saving
)
return
(
total_mem_saving
/
extra_cost
,
total_mem_saving
,
peak_mem_saving
)
def
_compare_profit
(
self
,
profit_a
:
tuple
,
profit_b
:
tuple
)
->
bool
:
...
...
@@ -122,9 +118,7 @@ class Solver(ABC):
self
.
best_ts
=
best_ts
self
.
_update_node_mem_info
(
best_ts
.
fwd_node_mem
,
best_ts
.
bwd_node_mem
)
def
_update_node_mem_info
(
self
,
fwd_mem_info
:
Dict
[
Node
,
float
],
bwd_mem_info
:
Dict
[
Node
,
float
]):
def
_update_node_mem_info
(
self
,
fwd_mem_info
:
Dict
[
Node
,
float
],
bwd_mem_info
:
Dict
[
Node
,
float
]):
"""
Update the runtime memory information of the node.
...
...
@@ -134,12 +128,10 @@ class Solver(ABC):
"""
for
node
,
mem
in
fwd_mem_info
.
items
():
assert
hasattr
(
node
,
'node_info'
)
and
isinstance
(
node
.
node_info
,
NodeInfo
)
assert
hasattr
(
node
,
"node_info"
)
and
isinstance
(
node
.
node_info
,
NodeInfo
)
node
.
node_info
.
runtime_fwd_mem
=
mem
for
node
,
mem
in
bwd_mem_info
.
items
():
assert
hasattr
(
node
,
'node_info'
)
and
isinstance
(
node
.
node_info
,
NodeInfo
)
assert
hasattr
(
node
,
"node_info"
)
and
isinstance
(
node
.
node_info
,
NodeInfo
)
node
.
node_info
.
runtime_bwd_mem
=
mem
def
_extract_computing_power
(
self
):
...
...
@@ -159,12 +151,12 @@ class Solver(ABC):
return
NvDevicePower
.
RTX3080_FP16
*
units
elif
device_name
.
__contains__
(
"RTX 3090"
):
return
NvDevicePower
.
RTX3090_FP16
*
units
elif
device_name
.
__contains__
(
'
V100
'
):
elif
device_name
.
__contains__
(
"
V100
"
):
return
NvDevicePower
.
V100_FP16
*
units
elif
device_name
.
__contains__
(
"A100"
):
return
NvDevicePower
.
A100_FP16
*
units
else
:
raise
TypeError
(
f
'
Unknown NVIDIA GPU device name
{
device_name
}
'
)
raise
TypeError
(
f
"
Unknown NVIDIA GPU device name
{
device_name
}
"
)
def
_profile_bandwidth
(
self
):
"""
...
...
@@ -172,9 +164,9 @@ class Solver(ABC):
using data volumes ranging from 1KB to 1GB.
"""
print
(
'
profiling bandwidth ......
'
)
print
(
"
profiling bandwidth ......
"
)
link_to_bandwidth
=
{}
links
=
[
'
h2d
'
,
'
d2h
'
]
links
=
[
"
h2d
"
,
"
d2h
"
]
for
link
in
links
:
t_size
=
1024
...
...
@@ -182,24 +174,22 @@ class Solver(ABC):
# from 1KB to 1GB
for
i
in
range
(
21
):
if
link
==
'h2d'
:
src_tensor
=
torch
.
ones
(
int
(
t_size
),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
dst_tensor
=
torch
.
ones
(
(
int
(
t_size
)),
dtype
=
torch
.
int8
,
device
=
'cuda'
)
elif
link
==
'd2h'
:
src_tensor
=
torch
.
ones
(
int
(
t_size
),
dtype
=
torch
.
int8
,
device
=
'cuda'
)
dst_tensor
=
torch
.
ones
(
(
int
(
t_size
)),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
if
link
==
"h2d"
:
src_tensor
=
torch
.
ones
(
int
(
t_size
),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
dst_tensor
=
torch
.
ones
((
int
(
t_size
)),
dtype
=
torch
.
int8
,
device
=
"cuda"
)
elif
link
==
"d2h"
:
src_tensor
=
torch
.
ones
(
int
(
t_size
),
dtype
=
torch
.
int8
,
device
=
"cuda"
)
dst_tensor
=
torch
.
ones
((
int
(
t_size
)),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
def
func
():
dst_tensor
.
copy_
(
src_tensor
)
size_to_bandwidth
[
t_size
]
=
t_size
/
benchmark_func
(
func
,
number
=
5
,
repeat
=
3
)
print
(
f
'size:
{
t_size
/
1024
**
2
:.
3
f
}
MB, '
f
'
{
src_tensor
.
device
.
type
}
-to-
{
dst_tensor
.
device
.
type
}
'
f
'bandwidth:
{
size_to_bandwidth
[
t_size
]
/
1024
**
3
:.
3
f
}
GB/s'
)
print
(
f
"size:
{
t_size
/
1024
**
2
:.
3
f
}
MB, "
f
"
{
src_tensor
.
device
.
type
}
-to-
{
dst_tensor
.
device
.
type
}
"
f
"bandwidth:
{
size_to_bandwidth
[
t_size
]
/
1024
**
3
:.
3
f
}
GB/s"
)
t_size
*=
2
...
...
@@ -208,10 +198,7 @@ class Solver(ABC):
class
SynGreedySolver
(
Solver
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
)
->
None
:
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
)
->
None
:
super
().
__init__
(
region_list
,
memory_budget
)
self
.
best_ts
:
SynTrainingSimulator
=
None
...
...
@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
else
:
raise
NotImplementedError
(
f
"can't find the offload strategy met the memory budget
{
self
.
memory_budget
/
1024
**
2
}
MB, "
f
"it needs
{
self
.
best_ts
.
peak_mem
/
1024
**
2
:.
3
f
}
MB at least!"
)
f
"it needs
{
self
.
best_ts
.
peak_mem
/
1024
**
2
:.
3
f
}
MB at least!"
)
def
_call_solver_l2l
(
self
):
"""
...
...
@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
region
.
is_syn
=
True
def
_try_to_offload
(
self
,
offload_region
:
Region
):
# record previous information
orig_need_offload
=
offload_region
.
need_offload
assert
not
orig_need_offload
...
...
@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
ts
=
SynTrainingSimulator
(
self
.
region_list
,
self
.
comp_power
,
self
.
link_to_bandwidth
)
ts
.
execute
()
extra_comm_cost
=
2.0
*
\
ts
.
_get_communication_overhead
(
'h2d'
,
offload_region
.
param_size
)
extra_comm_cost
=
2.0
*
ts
.
_get_communication_overhead
(
"h2d"
,
offload_region
.
param_size
)
# the shared region needs to be moved twice
if
offload_region
.
r_id
<
offload_region
.
shared_rid
:
extra_comm_cost
*=
2.0
profit
=
self
.
_compute_offload_profit
(
ts
.
total_mem_saving
,
self
.
best_ts
.
peak_mem
-
ts
.
peak_mem
,
extra_comm_cost
)
profit
=
self
.
_compute_offload_profit
(
ts
.
total_mem_saving
,
self
.
best_ts
.
peak_mem
-
ts
.
peak_mem
,
extra_comm_cost
)
return
ts
,
profit
class
AsynGreedySolver
(
Solver
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
search_window_size
:
int
=
3
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
search_window_size
:
int
=
3
):
super
().
__init__
(
region_list
,
memory_budget
)
self
.
search_window_size
=
search_window_size
...
...
@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
         ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
         ts.execute()
         self._update_state(ts)
-        print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
+        print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB")

     def _call_solver(self):
         """
...
...
@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
             best_pref_ts = None
             # search when to prefetch the region offloaded
-            for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
+            for host_region in self.region_list[region.r_id + 1 : region.r_id + 1 + self.search_window_size]:
                 if host_region.bwd_prefetch_region is not None:
                     continue
                 temp_ts, profit = self._try_to_offload(host_region, region)
                 if self._compare_profit(profit, max_prefetch_profit):
                     region_to_region_map[region.r_id] = host_region
                     max_prefetch_profit = profit
                     best_pref_ts = temp_ts
-                    if profit[0] == float('inf'):
+                    if profit[0] == float("inf"):
                         break

             if self._compare_profit(max_prefetch_profit, max_offload_profit):
...
...
@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
         else:
             raise NotImplementedError(
                 f"can't find the offload strategy met the memory budget {self.memory_budget / 1024**2} MB, "
                 f"it needs {self.best_ts.peak_mem / 1024**2:.3f} MB at least!"
             )

         region_to_region_map.clear()
...
...
@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
         peak_mem_saving = 0

         while len(self.region_to_region_map) and peak_mem_saving <= 0:
             max_profit = (0,)
             best_ts = None
             undo_host_region = None
...
...
@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
                 assert offload_region.need_offload
                 assert not offload_region.is_syn
                 ts, profit = self._try_convert_to_syn_upload(host_region, offload_region)

                 if self._compare_profit(profit, max_profit):
                     undo_host_region = host_region
...
...
@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
                     best_ts = ts

             if best_ts is None:
-                raise NotImplementedError('repair error!')
+                raise NotImplementedError("repair error!")

             assert not undo_offload_region.is_syn
             undo_offload_region.is_syn = True
...
...
@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
         ts.execute()

         extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
         profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)

         return ts, profit
 class SolverFactory:
-    solvers: Dict[str, Type[Solver]] = {'syn': SynGreedySolver, 'asyn': AsynGreedySolver}
+    solvers: Dict[str, Type[Solver]] = {"syn": SynGreedySolver, "asyn": AsynGreedySolver}

     @staticmethod
     def create(solver_name: str) -> Type[Solver]:
...
...
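Judging from the factory table above, the offload solver is selected by a string key and `create` returns the class itself. A sketch of how a caller might obtain and instantiate one (the `region_list` argument and the 8 GB budget are assumptions for illustration, not documented usage):

# hypothetical usage; region_list comes from the offload region-building stage
solver_cls = SolverFactory.create("asyn")   # returns the AsynGreedySolver class
solver = solver_cls(region_list, memory_budget=8 * 1024**3, search_window_size=3)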
colossalai/auto_parallel/offload/training_simulator.py View file @ 9e768b59
 import bisect
-from typing import List, Dict
-from collections import OrderedDict
 from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Dict, List

 from torch.fx.node import Node
...
...
@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
         link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
     """

     def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
         self.region_list = region_list
         self.region_num = len(region_list)
...
...
@@ -87,11 +84,7 @@ class TrainingSimulator(ABC):
 class SynTrainingSimulator(TrainingSimulator):
     def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
         super().__init__(region_list, comp_power, link_to_bw)

     def execute(self):
...
...
@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator):
             self.runtime_mem += region.param_size

         for node in region.nodes:
-            self.runtime_mem += calculate_fwd_tmp(node) + \
-                calculate_fwd_out(node)
+            self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
             self.fwd_node_mem[node] = self.runtime_mem
             self.peak_mem = max(self.runtime_mem, self.peak_mem)
             self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
...
@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator):
self
.
runtime_mem
+=
region
.
param_size
for
node
in
region
.
nodes
.
__reversed__
():
self
.
runtime_mem
-=
calculate_fwd_out
(
node
)
self
.
runtime_mem
+=
node
.
meta
[
'bwd_mem_tmp'
]
+
\
node
.
meta
[
'bwd_mem_out'
]
self
.
runtime_mem
+=
node
.
meta
[
"bwd_mem_tmp"
]
+
node
.
meta
[
"bwd_mem_out"
]
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
# The memory savings of a node may be negative due to parameter prefetch.
self
.
total_mem_saving
+=
node
.
node_info
.
runtime_bwd_mem
-
self
.
runtime_mem
self
.
bwd_node_mem
[
node
]
=
self
.
runtime_mem
self
.
runtime_mem
-=
(
node
.
meta
[
'bwd_mem_tmp'
]
+
calculate_fwd_tmp
(
node
))
self
.
runtime_mem
-=
node
.
meta
[
"bwd_mem_tmp"
]
+
calculate_fwd_tmp
(
node
)
# free bwd_mem_out
self
.
bwd_node_deps
[
node
]
=
len
(
node
.
all_input_nodes
)
...
...
@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator):
                 if user_node in self.bwd_node_deps:
                     self.bwd_node_deps[user_node] -= 1
                     if self.bwd_node_deps[user_node] <= 0:
-                        self.runtime_mem -= user_node.meta['bwd_mem_out']
+                        self.runtime_mem -= user_node.meta["bwd_mem_out"]

             if self.runtime_mem < 0:
                 raise ValueError(
                     f"region id: {region.r_id}, node name: {node.name}, "
                     f"runtime_mem: {self.runtime_mem / 1024**2:.3f} MB ---"
                     f"runtime memory computed less than 0, which is miscalculated!"
                 )

         # release parameter and offload gradient in region
         if region.r_id == region.shared_rid:
...
...
@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator):
 class AsynTrainingSimulator(TrainingSimulator):
     def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None:
         super().__init__(region_list, comp_power, link_to_bw)

         self.iter_end_time: int = 0
         # the last computation execution period
         self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
         # the last parameter prefetch execution period
         self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
         # the last gradient offload execution period
         self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0)
         # the forward computation execution period of the region
         self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
         # the forward parameter prefetch execution period of the region
...
...
@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator):
         self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
         # the gradient offload execution period of the region
         # which is divided into those that are waiting and those that have been released
         self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict()
         self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict()
         # the region buffer, which records regions that are offloaded but not released
         self.reg_buffer_to_free: List[int] = []
...
...
@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator):
         # the region execution flow,
         # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
         # when the execution reaches the i-th region.
         self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()
         self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool()

     def execute(self):
         """
...
...
@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator):
         for reg in self.region_list:
             if reg.param_size and reg.r_id < self.region_num - 1:
-                for nr in self.region_list[reg.r_id + 1:]:
+                for nr in self.region_list[reg.r_id + 1 :]:
                     if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
                         reg.fwd_prefetch_region = nr
                         break
...
...
@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator):
             self.runtime_mem -= self.region_list[reg_id].param_size
         self.bwd_reg_to_offl_waiting.clear()

         self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time)

     def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
         """
...
...
@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
pref_start_time
=
max
(
self
.
last_h2d
.
end_time
,
self
.
last_comp
.
end_time
)
pref_end_time
=
pref_start_time
+
\
2.0
*
self
.
_get_communication_overhead
(
'h2d'
,
region
.
param_size
)
pref_ep
=
ExecutionPeriod
(
start_time
=
pref_start_time
,
end_time
=
pref_end_time
)
pref_end_time
=
pref_start_time
+
2.0
*
self
.
_get_communication_overhead
(
"h2d"
,
region
.
param_size
)
pref_ep
=
ExecutionPeriod
(
start_time
=
pref_start_time
,
end_time
=
pref_end_time
)
if
is_fwd
:
self
.
fwd_reg_to_pref
[
region
.
r_id
]
=
pref_ep
else
:
...
...
@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator):
         if is_fwd:
             reg_to_comp = self.fwd_reg_to_comp
             reg_to_pref = self.fwd_reg_to_pref
-            flop_key = 'fwd_flop'
+            flop_key = "fwd_flop"
         else:
             reg_to_comp = self.bwd_reg_to_comp
             reg_to_pref = self.bwd_reg_to_pref
-            flop_key = 'bwd_flop'
+            flop_key = "bwd_flop"

         comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time)
-        comp_end_time = comp_start_time + \
-            sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes])
+        comp_end_time = comp_start_time + sum(
+            [self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]
+        )
         comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time)
         reg_to_comp[region.r_id] = comp_ep

         self.last_comp = comp_ep
...
...
@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
offl_start_time
=
max
(
self
.
last_d2h
.
end_time
,
self
.
last_comp
.
end_time
)
offl_end_time
=
offl_start_time
+
\
self
.
_get_communication_overhead
(
'd2h'
,
region
.
param_size
)
offl_ep
=
ExecutionPeriod
(
start_time
=
offl_start_time
,
end_time
=
offl_end_time
)
offl_end_time
=
offl_start_time
+
self
.
_get_communication_overhead
(
"d2h"
,
region
.
param_size
)
offl_ep
=
ExecutionPeriod
(
start_time
=
offl_start_time
,
end_time
=
offl_end_time
)
self
.
bwd_reg_to_offl_waiting
[
region
.
r_id
]
=
offl_ep
self
.
last_d2h
=
offl_ep
...
...
@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator):
             self.fwd_reg_flow[region.r_id, region.r_id] = True
         else:
             self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
             self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False
             self.reg_buffer_to_free.clear()

         # prefetch parameters of the next region
         fwd_prefetch_region = region.fwd_prefetch_region
         if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
             self.runtime_mem += fwd_prefetch_region.param_size
             self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True

         for node in region.nodes:
-            self.runtime_mem += calculate_fwd_tmp(node) + \
-                calculate_fwd_out(node)
+            self.runtime_mem += calculate_fwd_tmp(node) + calculate_fwd_out(node)
             self.peak_mem = max(self.runtime_mem, self.peak_mem)
             self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
...
...
@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator):
         if region.need_offload:
             self.runtime_mem -= region.param_size

-            assert len(self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
+            assert len(self.reg_buffer_to_free) <= 1, f"{len(self.reg_buffer_to_free)}"
             self.reg_buffer_to_free.append(region.r_id)

     def _eval_bwd_cost_per_region(self, region: Region):
...
...
@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator):
             self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
         else:
             self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
         self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False

         # free gradients in the buffer
         while len(self.reg_buffer_to_free):
...
...
@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator):
         bwd_prefetch_region = region.bwd_prefetch_region
         if bwd_prefetch_region:
             self.runtime_mem += bwd_prefetch_region.param_size
             self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True

         # add the gradient of the parameter
         if region.r_id < region.shared_rid:
...
...
@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator):
         self.runtime_mem += region.param_size

         for node in region.nodes.__reversed__():
             self.runtime_mem -= calculate_fwd_out(node)
-            self.runtime_mem += node.meta['bwd_mem_tmp'] + \
-                node.meta['bwd_mem_out']
+            self.runtime_mem += node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
             self.peak_mem = max(self.runtime_mem, self.peak_mem)

             # The memory savings of a node may be negative due to parameter prefetch.
...
...
@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator):
             self.bwd_node_mem[node] = self.runtime_mem

-            self.runtime_mem -= (node.meta['bwd_mem_tmp'] + calculate_fwd_tmp(node))
+            self.runtime_mem -= node.meta["bwd_mem_tmp"] + calculate_fwd_tmp(node)

             # free bwd_mem_out
             self.bwd_node_deps[node] = len(node.all_input_nodes)
...
...
@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator):
                 if user_node in self.bwd_node_deps:
                     self.bwd_node_deps[user_node] -= 1
                     if self.bwd_node_deps[user_node] <= 0:
-                        self.runtime_mem -= user_node.meta['bwd_mem_out']
+                        self.runtime_mem -= user_node.meta["bwd_mem_out"]

             if self.runtime_mem < 0:
                 raise ValueError(
                     f"region id: {region.r_id}, node name: {node.name}, "
                     f"runtime_mem: {self.runtime_mem / 1024**2:.3f} MB ---"
                     f"runtime memory computed less than 0, which is miscalculated!"
                 )

         # release parameters of the region
         if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
...
...
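The asynchronous simulator above books each region onto conceptual streams (compute, h2d prefetch, d2h offload), where a new period can only begin once both its own stream and the period it depends on have finished. A reduced sketch of that scheduling rule, with a stand-in ExecutionPeriod dataclass (the repository's class is assumed to carry the same two fields):

from dataclasses import dataclass

@dataclass
class ExecutionPeriod:
    start_time: float
    end_time: float

def schedule(last_on_stream: ExecutionPeriod, dependency: ExecutionPeriod, duration: float) -> ExecutionPeriod:
    """A period starts once its stream is idle and its dependency has completed."""
    start = max(last_on_stream.end_time, dependency.end_time)
    return ExecutionPeriod(start, start + duration)

# a compute period that must wait for its parameter prefetch to land:
last_comp = ExecutionPeriod(0.0, 5.0)
prefetch = ExecutionPeriod(2.0, 7.0)
comp = schedule(last_comp, prefetch, duration=3.0)  # starts at 7.0, ends at 10.0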
colossalai/auto_parallel/offload/util.py View file @ 9e768b59
...
...
@@ -35,7 +35,6 @@ class NvDevicePower:
 class GlobalRuntimeInfo(metaclass=SingletonMeta):
     def __init__(self):
         self.h2d_stream = torch.cuda.Stream()
         self.d2h_stream = torch.cuda.Stream()
...
...
@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
     # forward
     for region in region_list:
         for node in region.nodes:
-            runtime_mem = runtime_mem + \
-                calculate_fwd_tmp(node) + calculate_fwd_out(node)
+            runtime_mem = runtime_mem + calculate_fwd_tmp(node) + calculate_fwd_out(node)
             act_peak_mem = max(runtime_mem, act_peak_mem)

     # backward
     bwd_deps = {}
     for region in region_list.__reversed__():
         for node in region.nodes.__reversed__():
             runtime_mem -= calculate_fwd_out(node)
-            runtime_mem = runtime_mem + \
-                node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
+            runtime_mem = runtime_mem + node.meta["bwd_mem_tmp"] + node.meta["bwd_mem_out"]
             act_peak_mem = max(runtime_mem, act_peak_mem)
-            runtime_mem = runtime_mem - \
-                node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
+            runtime_mem = runtime_mem - node.meta["bwd_mem_tmp"] - calculate_fwd_tmp(node)

             # free bwd_mem_out
             bwd_deps[node] = len(node.all_input_nodes)
...
...
@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
             if user_node in bwd_deps:
                 bwd_deps[user_node] -= 1
                 if bwd_deps[user_node] <= 0:
-                    runtime_mem -= user_node.meta['bwd_mem_out']
+                    runtime_mem -= user_node.meta["bwd_mem_out"]

     return act_peak_mem
...
...
@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
 def requires_upload_p_in_fwd(shared_reg: Region):
     return (shared_reg.r_id >= shared_reg.shared_rid) or (
         shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
     )


 def requires_release_p_in_bwd(shared_reg: Region):
     return (shared_reg.r_id >= shared_reg.shared_rid) or (
         shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
     )


 def requires_offload_g_in_bwd(region: Region):
...
...
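Both predicates above encode one rule for shared parameters: the later occurrence of a shared region (r_id >= shared_rid) always moves its parameters, while the earlier occurrence only does so if it was itself chosen for offload. A small self-contained check of that rule, with a stand-in Region carrying only the three fields the predicate reads (the real Region class has more state):

from dataclasses import dataclass

@dataclass
class Region:
    r_id: int
    shared_rid: int
    need_offload: bool = False

def requires_upload_p_in_fwd(shared_reg: Region) -> bool:
    return (shared_reg.r_id >= shared_reg.shared_rid) or (
        shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload
    )

print(requires_upload_p_in_fwd(Region(5, 2)))                     # True: later occurrence of the shared weight
print(requires_upload_p_in_fwd(Region(2, 5)))                     # False: earlier occurrence kept on the GPU
print(requires_upload_p_in_fwd(Region(2, 5, need_offload=True)))  # True: earlier occurrence, but offloaded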
colossalai/auto_parallel/passes/comm_metainfo_pass.py View file @ 9e768b59
...
...
@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 shape_consistency_manager = ShapeConsistencyManager()


 def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, target_sharding_spec: ShardingSpec) -> ShardMetaInfo:
     # get comm_action_sequence and total_cost from shape_consistency_manager
     _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(origin_sharding_spec, target_sharding_spec)

     meta_info = ShardMetaInfo()

     # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
     # get mem cost for ShardMetaInfo
     mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence)
     # extract user that has _meta_data and extract element length
-    input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data'))
+    input_node = next(n for n in node._input_nodes if hasattr(n, "_meta_data"))
     element_length = input_node._meta_data.element_size()

     mem_cost.fwd.activation *= element_length
...
...
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     meta_info.memory_cost = mem_cost

     # get computation cost for ShardMetaInfo
-    meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
-                                            total_cost['backward'] * element_length,
-                                            total_cost['total'] * element_length)
+    meta_info.compute_cost = TrainCycleItem(
+        total_cost["forward"] * element_length,
+        total_cost["backward"] * element_length,
+        total_cost["total"] * element_length,
+    )

     # get tensor shape for ShardMetaInfo
     origin_sharding_spec: ShardingSpec
...
...
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
     input_shape = origin_sharding_spec.get_sharded_shape_per_device()
     output_shape = target_sharding_spec.get_sharded_shape_per_device()
-    meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+    meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
     meta_info.fwd_buffer = []
-    meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+    meta_info.fwd_out = [torch.rand(output_shape, device="meta")]

     return meta_info
...
...
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
     # extract node index and user node index
     args = node.args
     node_index, user_node_index = args[3], args[4]
-    origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][user_node_index]
+    origin_sharding_spec, target_sharding_spec = (
+        origin_spec_dict[node_index],
+        sharding_spec_dict[node_index][user_node_index],
+    )

     return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)
...
...
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
         # this case is for all_reduce, there will be no memory cost
         meta_info = ShardMetaInfo()
         meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost)

-        output_node = next(n for n in node.users if hasattr(n, '_meta_data'))
+        output_node = next(n for n in node.users if hasattr(n, "_meta_data"))
         element_length = output_node._meta_data.element_size()

         total_cost = comm_action.comm_spec.get_comm_cost()
-        meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length,
-                                                total_cost['backward'] * element_length,
-                                                total_cost['total'] * element_length)
+        meta_info.compute_cost = TrainCycleItem(
+            total_cost["forward"] * element_length,
+            total_cost["backward"] * element_length,
+            total_cost["total"] * element_length,
+        )

         input_shape = output_shape = comm_action.comm_spec.sharding_spec.get_sharded_shape_per_device()
-        meta_info.fwd_in = [torch.rand(input_shape, device='meta')]
+        meta_info.fwd_in = [torch.rand(input_shape, device="meta")]
         meta_info.fwd_buffer = []
-        meta_info.fwd_out = [torch.rand(output_shape, device='meta')]
+        meta_info.fwd_out = [torch.rand(output_shape, device="meta")]
     else:
         # this case will be handled by shape consistency manager
-        origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec['tgt_spec']
+        origin_sharding_spec, target_sharding_spec = (
+            comm_action.comm_spec["src_spec"],
+            comm_action.comm_spec["tgt_spec"],
+        )
         meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec)

     return meta_info
 def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_dict: Dict, comm_actions_dict: Dict) -> GraphModule:
     """
     The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
     """
     for node in gm.graph.nodes:
         if node.target == runtime_apply:
-            setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
+            setattr(node, "best_strategy_info", _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict))
         elif node.target == runtime_comm_spec_apply:
-            setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
+            setattr(node, "best_strategy_info", _runtime_comm_spec_apply_meta_info(node, comm_actions_dict))
         else:
             pass
     return gm
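As the NOTE in _construct_shard_meta_info points out, the shape-consistency costs are tracked as element counts and converted to bytes by multiplying with element_size(). A one-line illustration of that conversion (the tensor shape here is made up):

import torch

t = torch.rand(4, 1024, dtype=torch.float32)
numel_cost = t.numel()                     # cost tracked as a number of elements: 4096
byte_cost = numel_cost * t.element_size()  # float32 is 4 bytes per element -> 16384 bytes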
colossalai/auto_parallel/passes/meta_info_prop.py View file @ 9e768b59
...
...
@@ -21,16 +21,15 @@ def _normalize_tuple(x):
 @compatibility(is_backward_compatible=False)
 class MetaInfoProp:
     def __init__(self, module: GraphModule) -> None:
         self.module = module
         self.func_dict = {
-            'placeholder': self.placeholder_handler,
-            'get_attr': self.get_attr_handler,
-            'output': self.output_handler,
-            'call_function': self.node_handler,
-            'call_module': self.node_handler,
-            'call_method': self.node_handler,
+            "placeholder": self.placeholder_handler,
+            "get_attr": self.get_attr_handler,
+            "output": self.output_handler,
+            "call_function": self.node_handler,
+            "call_module": self.node_handler,
+            "call_method": self.node_handler,
         }

     def _set_data_ptr(self, x):
...
...
@@ -46,7 +45,7 @@ class MetaInfoProp:
"""
Check if the node is inplace operation.
"""
if
node
.
op
==
'
call_module
'
:
if
node
.
op
==
"
call_module
"
:
return
node
.
graph
.
owning_module
.
get_submodule
(
node
.
target
).
__class__
in
OUTPUT_SAVED_MOD
elif
node
.
op
==
"call_function"
:
return
node
.
target
in
OUTPUT_SAVED_OPS
...
...
@@ -66,7 +65,7 @@ class MetaInfoProp:
         Handle the placeholder node.
         """
         graph_info = GraphInfo()
-        out = _normalize_tuple(getattr(node, '_meta_data', None))
+        out = _normalize_tuple(getattr(node, "_meta_data", None))
         graph_info.fwd_out = list(out) if out[0] is not None else []
         node.meta = {**asdict(graph_info)}
...
...
@@ -96,7 +95,7 @@ class MetaInfoProp:
"""
Handle other kind of nodes
"""
assert
hasattr
(
node
,
'
best_strategy_info
'
),
f
"Cannot find best_strategy_info in node
{
node
}
,
{
node
.
op
}
"
assert
hasattr
(
node
,
"
best_strategy_info
"
),
f
"Cannot find best_strategy_info in node
{
node
}
,
{
node
.
op
}
"
graph_info
=
GraphInfo
()
meta_info
=
node
.
best_strategy_info
meta_info
:
ShardMetaInfo
...
...
@@ -126,7 +125,8 @@ class MetaInfoProp:
             for tensor in par.meta.get("fwd_out", []):
                 tensor: torch.Tensor
-                target_input_tensor = next((x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape),
-                                           None)
+                target_input_tensor = next(
+                    (x for x in input_tensors if not x.data_ptr() and x.shape == tensor.shape), None
+                )
                 if target_input_tensor is not None:
                     target_input_tensor.data_ptr = tensor.data_ptr
...
...
@@ -148,7 +148,7 @@ class MetaInfoProp:
         graph_info.fwd_tmp = buffer_tensors
         graph_info.fwd_out = output_tensors

-        # fetch other memory informations
+        # fetch other memory information
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
         graph_info.fwd_mem_out = memory_cost.fwd.activation
...
...
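MetaInfoProp routes every FX node to a handler through the func_dict table keyed on node.op. A minimal sketch of that dispatch pattern outside torch.fx (the Node namedtuple here is a stand-in for illustration, not the fx class):

from collections import namedtuple

Node = namedtuple("Node", ["op", "name"])

class TinyProp:
    """Dispatch nodes to per-op handlers, mirroring MetaInfoProp.func_dict."""

    def __init__(self):
        self.func_dict = {
            "placeholder": self.placeholder_handler,
            "call_function": self.node_handler,
        }

    def placeholder_handler(self, node):
        print("placeholder:", node.name)

    def node_handler(self, node):
        print("compute node:", node.name)

    def run(self, nodes):
        for node in nodes:
            self.func_dict[node.op](node)

TinyProp().run([Node("placeholder", "x"), Node("call_function", "add")])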
colossalai/auto_parallel/passes/runtime_apply_pass.py View file @ 9e768b59
 from copy import deepcopy
 from typing import Dict, List

 import torch
 from torch.fx.node import Node

 from colossalai._analyzer.fx.node_util import MetaInfo
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationData,
-    OperationDataType,
-    TrainCycleItem,
-)
-from colossalai.device.device_mesh import DeviceMesh
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
 from colossalai.tensor.comm_spec import CommSpec
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
...
...
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
     return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)


 def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int, user_node_index: int):
     """
     This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
     is converted into the user node expected form.
     """
     rst = []
     for index, (origin_sharding_spec, target_sharding_spec) in enumerate(
         zip(origin_dict[node_index], input_dict[node_index][user_node_index])
     ):
         rst.append(
             shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec, target_sharding_spec)
         )
     rst = type(node)(rst)
     return rst
...
...
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
     if isinstance(comm_action.comm_spec, CommSpec):
         rst = comm_action.comm_spec.covert_spec_to_action(tensor)
     else:
-        origin_sharding_spec = comm_action.comm_spec['src_spec']
-        tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
+        origin_sharding_spec = comm_action.comm_spec["src_spec"]
+        tgt_sharding_spec = comm_action.comm_spec["tgt_spec"]
         rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
     return rst
...
...
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
     node_to_index_dict = {}
     index = 0
     for node in nodes:
-        if node.target == 'sharding_spec_convert_dict':
+        if node.target == "sharding_spec_convert_dict":
             input_dict_node = node
             continue
-        if node.target == 'origin_node_sharding_spec_dict':
+        if node.target == "origin_node_sharding_spec_dict":
             origin_dict_node = node
             continue
-        if node.target == 'comm_actions_dict':
+        if node.target == "comm_actions_dict":
             comm_actions_dict_node = node
             continue
-        if not hasattr(node, 'best_strategy'):
+        if not hasattr(node, "best_strategy"):
             continue
         node_to_index_dict[node] = index
         index += 1
...
...
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
     input_dict_node, origin_dict_node, _, node_to_index_dict = _preprocess_graph(nodes)

     for node in nodes:
-        if not hasattr(node, 'best_strategy') or node.op == 'output':
+        if not hasattr(node, "best_strategy") or node.op == "output":
             continue

         for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
             if isinstance(node.sharding_spec, (list, tuple)):
-                assert isinstance(node.target_sharding_specs,
-                                  (list, tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
+                assert isinstance(
+                    node.target_sharding_specs, (list, tuple)
+                ), "target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
                 total_difference = 0
-                for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
-                                                               node.target_sharding_specs[user_node_index]):
+                for sharding_spec, target_sharding_spec in zip(
+                    node.sharding_spec, node.target_sharding_specs[user_node_index]
+                ):
                     total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
                 if total_difference == 0:
                     continue
                 with mod_graph.inserting_before(user_node):
-                    shape_consistency_node = mod_graph.create_node('call_function',
-                                                                   runtime_apply_for_iterable_object,
-                                                                   args=(node, origin_dict_node, input_dict_node,
-                                                                         node_to_index_dict[node], user_node_index))
+                    shape_consistency_node = mod_graph.create_node(
+                        "call_function",
+                        runtime_apply_for_iterable_object,
+                        args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+                    )

             else:
-                assert isinstance(node.sharding_spec,
-                                  ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
+                assert isinstance(
+                    node.sharding_spec, ShardingSpec
+                ), "node.sharding_spec should be type of ShardingSpec, tuple or list."
                 if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
                     continue
                 with mod_graph.inserting_before(user_node):
-                    shape_consistency_node = mod_graph.create_node('call_function',
-                                                                   runtime_apply,
-                                                                   args=(node, origin_dict_node, input_dict_node,
-                                                                         node_to_index_dict[node], user_node_index))
-                if hasattr(user_node.meta['info'], 'activation_checkpoint'):
-                    MetaInfo(shape_consistency_node,
-                             mod_dir=user_node.meta['info'].mod_dir,
-                             activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint))
+                    shape_consistency_node = mod_graph.create_node(
+                        "call_function",
+                        runtime_apply,
+                        args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index),
+                    )
+                if hasattr(user_node.meta["info"], "activation_checkpoint"):
+                    MetaInfo(
+                        shape_consistency_node,
+                        mod_dir=user_node.meta["info"].mod_dir,
+                        activation_checkpoint=tuple(user_node.meta["info"].activation_checkpoint),
+                    )

             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
             # the origin node may be a positional argument or key word argument of user node
...
...
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
     _, _, comm_actions_dict_node, node_to_index_dict = _preprocess_graph(nodes)

     for node in nodes:
-        if not hasattr(node, 'best_strategy') or node.op == 'output':
+        if not hasattr(node, "best_strategy") or node.op == "output":
             continue

         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
             if comm_action.comm_type == CommType.HOOK:
                 continue
             if comm_action.comm_type == CommType.BEFORE:
...
...
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
             else:
                 comm_object = node.args[comm_action.arg_index]
             with mod_graph.inserting_before(node):
-                comm_spec_apply_node = mod_graph.create_node('call_function',
-                                                             runtime_comm_spec_apply,
-                                                             args=(comm_object, comm_actions_dict_node,
-                                                                   node_to_index_dict[node], op_data.name))
+                comm_spec_apply_node = mod_graph.create_node(
+                    "call_function",
+                    runtime_comm_spec_apply,
+                    args=(comm_object, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+                )
             # the origin node may be a positional argument or key word argument of user node
             if comm_action.key_for_kwarg is not None:
                 # substitute the origin node with comm_spec_apply_node
...
...
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
         elif comm_action.comm_type == CommType.AFTER:
             with mod_graph.inserting_after(node):
-                comm_spec_apply_node = mod_graph.create_node('call_function',
-                                                             runtime_comm_spec_apply,
-                                                             args=(node, comm_actions_dict_node,
-                                                                   node_to_index_dict[node], op_data.name))
+                comm_spec_apply_node = mod_graph.create_node(
+                    "call_function",
+                    runtime_comm_spec_apply,
+                    args=(node, comm_actions_dict_node, node_to_index_dict[node], op_data.name),
+                )
             user_list = list(node.users.keys())
             for user in user_list:
                 if user == comm_spec_apply_node:
...
...
@@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     # substitute the origin node with comm_spec_apply_node
                     new_kwargs[str(node)] = comm_spec_apply_node
                     user.kwargs = new_kwargs

-        if hasattr(node.meta['info'], 'activation_checkpoint'):
-            MetaInfo(comm_spec_apply_node,
-                     mod_dir=node.meta['info'].mod_dir,
-                     activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+        if hasattr(node.meta["info"], "activation_checkpoint"):
+            MetaInfo(
+                comm_spec_apply_node,
+                mod_dir=node.meta["info"].mod_dir,
+                activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+            )

     return gm


-def _act_annotataion_pass(gm: torch.fx.GraphModule):
+def _act_annotation_pass(gm: torch.fx.GraphModule):
     """
     This pass is used to add the act annotation to the new inserted nodes.
     """
...
...
@@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule):
     nodes = tuple(mod_graph.nodes)
     for node in nodes:
-        if not hasattr(node.meta, 'activation_checkpoint'):
-            from .runtime_preparation_pass import size_processing
+        if not hasattr(node.meta, "activation_checkpoint"):
+            pass

         user_act_annotation = -1
         input_act_annotation = -1
         for user_node in node.users.keys():
-            if 'activation_checkpoint' in user_node.meta:
-                user_act_annotation = user_node.meta['activation_checkpoint']
+            if "activation_checkpoint" in user_node.meta:
+                user_act_annotation = user_node.meta["activation_checkpoint"]
                 break
         for input_node in node._input_nodes.keys():
-            if 'activation_checkpoint' in input_node.meta:
-                input_act_annotation = input_node.meta['activation_checkpoint']
+            if "activation_checkpoint" in input_node.meta:
+                input_act_annotation = input_node.meta["activation_checkpoint"]
                 break
         if user_act_annotation == input_act_annotation and user_act_annotation != -1:
-            node.meta['activation_checkpoint'] = user_act_annotation
+            node.meta["activation_checkpoint"] = user_act_annotation
     return gm
...
...
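Both passes in this file splice extra call_function nodes into the traced graph with inserting_before/inserting_after and then rewire the consumer's arguments. A self-contained torch.fx sketch of that insertion pattern on a toy module (annotate is a hypothetical stand-in for runtime_apply):

import operator

import torch
from torch.fx import symbolic_trace

def annotate(x):
    return x  # identity stand-in for the runtime shape-consistency call

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = symbolic_trace(Toy())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target == operator.add:
        input_node = node.args[0]
        with gm.graph.inserting_before(node):
            new_node = gm.graph.create_node("call_function", annotate, args=(input_node,))
        node.args = (new_node,) + tuple(node.args[1:])  # reroute the add through annotate
gm.graph.lint()
gm.recompile()
print(gm(torch.ones(2)))  # tensor([2., 2.])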
colossalai/auto_parallel/passes/runtime_preparation_pass.py View file @ 9e768b59
 import operator
 from copy import deepcopy
 from typing import Dict, List, Union

 import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node

 from colossalai._analyzer.fx.node_util import MetaInfo
 from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
-    CommAction,
-    CommType,
-    OperationDataType,
-    ShardingStrategy,
-)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommType, OperationDataType
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import _all_reduce
...
...
@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
 shape_consistency_manager = ShapeConsistencyManager()


 def size_processing(
     size: Union[int, torch.Size],
     dim_partition_dict: Dict[int, List[int]],
     device_mesh_info: Dict[int, int],
     target_dim: int = None,
     node_name: str = None,
 ):
     """
     This method will be invoked during runtime to convert size node value depending on distributed information.
     """
...
...
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
     return size


-def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor):
+def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int], strategies_constructor: StrategiesConstructor):
     """
     This method is used to stick the solution strategy to the nodes and add the information
     required in runtime into graph as placeholder nodes.
...
...
@@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
     for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
         strategies_vector = node.strategies_vector
         # stick the solution strategy to the corresponding node
-        setattr(node, 'best_strategy', strategies_vector[strategy_index])
-        setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
+        setattr(node, "best_strategy", strategies_vector[strategy_index])
+        setattr(node, "sharding_spec", strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
         origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
             str(node)
         )

         # attach the corresponding metainfo if node has the attribute `strategies_info`
-        if hasattr(node, 'strategies_info'):
-            setattr(node, 'best_strategy_info', node.strategies_info[strategy_index])
+        if hasattr(node, "strategies_info"):
+            setattr(node, "best_strategy_info", node.strategies_info[strategy_index])

     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
...
...
@@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
             target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
-        setattr(node, 'target_sharding_specs', target_sharding_specs)
+        setattr(node, "target_sharding_specs", target_sharding_specs)

         # the get_attr node strategy is kind of pending strategy, which means we will change it
         # to the same strategy of the user node.
-        if node.op == 'get_attr':
-            assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
+        if node.op == "get_attr":
+            assert len(target_sharding_specs) == 1, f"sharing weight is not supported in current version."
             target_node = node.strategies_vector.successor_nodes[0]
             node_name = str(node)
-            if target_node.op == 'call_function' and target_node.target in RESHAPE_FUNC_OP:
+            if target_node.op == "call_function" and target_node.target in RESHAPE_FUNC_OP:
                 node_name = str(target_node)
                 target_node = target_node.strategies_vector.successor_nodes[0]
             user_strategy = target_node.best_strategy
...
...
@@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
     # add above dicts into graph
     for node in nodes:
-        if node.op != 'placeholder':
+        if node.op != "placeholder":
             with mod_graph.inserting_before(node):
-                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
-                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
-                comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
+                input_specs_node = mod_graph.create_node("placeholder", target="sharding_spec_convert_dict")
+                origin_specs_node = mod_graph.create_node("placeholder", target="origin_node_sharding_spec_dict")
+                comm_actions_dict_node = mod_graph.create_node("placeholder", target="comm_actions_dict")
             break

     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
...
...
@@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
     # DeviceMesh information instructs the scaling of the size value
     device_mesh_info = {}
-    for dim, dim_size in enumerate(device_mesh.mesh_shape):
+    for dim, dim_size in enumerate(device_mesh.shape):
         device_mesh_info[dim] = dim_size

     def _extract_target_dim(node):
-        '''
+        """
         A helper function to extract the target dimension from size node.
         There are two usages of torch.Tensor.size:
         1. tensor.size()
...
...
@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
         If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
         Otherwise, the output will be in type of torch.Size and this function will return None.
-        '''
+        """
         target_dim = None
         if len(node.args) > 1:
             target_dim = node.args[1]
...
...
@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
         return target_dim

     def _post_processing(node, size_processing_node):
-        '''
+        """
         This function is used to process the dependency between the size node and its users after
         inserting the size_process_node.
-        '''
-        # store original node and processing node pair in node_pairs dictioanry
+        """
+        # store original node and processing node pair in node_pairs dictionary
         # It will be used to replace the original node with processing node in slice object
         node_pairs[node] = size_processing_node
         size_processing_node._meta_data = node._meta_data

-        if hasattr(node.meta['info'], 'activation_checkpoint'):
-            MetaInfo(size_processing_node,
-                     mod_dir=node.meta['info'].mod_dir,
-                     activation_checkpoint=tuple(node.meta['info'].activation_checkpoint))
+        if hasattr(node.meta["info"], "activation_checkpoint"):
+            MetaInfo(
+                size_processing_node,
+                mod_dir=node.meta["info"].mod_dir,
+                activation_checkpoint=tuple(node.meta["info"].activation_checkpoint),
+            )

         user_list = list(node.users.keys())
         for user in user_list:
...
...
@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
                 user.kwargs = new_kwargs

     def _update_slice_object_args(slice_object):
-        '''
+        """
         This function is used to update the slice object argument list.
         If the slice object contains the Node argument, then the size node will be replaced with
-        '''
+        """
         if isinstance(slice_object, slice):
             start = slice_object.start
             stop = slice_object.stop
...
...
@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
             raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")

     for node in nodes:
-        if node.op == 'call_method' and node.target == 'size':
+        if node.op == "call_method" and node.target == "size":
             # extract useful information from size node
             # dim_partition_dict will instruct the size value on which
             # dimension should be enlarged.
...
...
@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
             # insert size_processing node
             with mod_graph.inserting_after(node):
-                size_processing_node = mod_graph.create_node('call_function',
-                                                             size_processing,
-                                                             args=(node, dim_partition_dict, device_mesh_info,
-                                                                   target_dim, node.name))
+                size_processing_node = mod_graph.create_node(
+                    "call_function",
+                    size_processing,
+                    args=(node, dim_partition_dict, device_mesh_info, target_dim, node.name),
+                )
             _post_processing(node, size_processing_node)

-        if node.op == 'call_function' and node.target == operator.getitem:
+        if node.op == "call_function" and node.target == operator.getitem:
             getitem_index = node.args[1]
             # slice object is quite special in torch.fx graph,
             # On one side, we treat slice object same as type of int,
...
...
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
     nodes = tuple(mod_graph.nodes)

     def _extract_info_from_sharding_spec(sharding_spec):
-        '''
+        """
         This function is used to extract the dim_partition_dict and device_mesh from
         sharding spec instance or a list of sharding spec.
-        '''
+        """
         if isinstance(sharding_spec, ShardingSpec):
             dim_partition_dict = sharding_spec.dim_partition_dict
             device_mesh = sharding_spec.device_mesh
             return dim_partition_dict, device_mesh
         if sharding_spec is None:
             return None, None
-        assert isinstance(sharding_spec,
-                          (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+        assert isinstance(
+            sharding_spec, (tuple, list)
+        ), "sharding_spec should be type of ShardingSpec, tuple, list or None"
         device_mesh = sharding_spec[0].device_mesh
         dim_partition_dict = []
...
...
@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
             else:
                 new_args.append(arg)
         else:
-            assert isinstance(arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+            assert isinstance(
+                arg, (int, tuple, list)
+            ), "The argument in view node should be either type of Node or int."
             if isinstance(arg, (tuple, list)):
                 new_args.extend(arg)
             else:
...
...
@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
     def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
         new_args = _process_node_arguments(node)
-        if node.op == 'call_method':
+        if node.op == "call_method":
             args_to_process = list(new_args[1:])
         else:
             args_to_process = list(new_args)
...
...
@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
         args_to_process = tuple(args_to_process)

-        if node.op == 'call_method':
+        if node.op == "call_method":
             new_args = (new_args[0],) + args_to_process
         else:
             new_args = args_to_process
...
...
@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
         node.args = new_args

     def _filter_node_with_shape_args(node):
-        if node.op == 'call_method':
+        if node.op == "call_method":
             target = getattr(node.args[0]._meta_data.__class__, node.target)
-        elif node.op == 'call_function':
+        elif node.op == "call_function":
             target = node.target
         else:
             target = None
...
...
@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
     for node in nodes:
         # skip the placeholder node added in _solution_annotation pass
-        if not hasattr(node, 'sharding_spec'):
+        if not hasattr(node, "sharding_spec"):
             continue

         output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
...
...
@@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
"""
mod_graph
=
gm
.
graph
nodes
=
tuple
(
mod_graph
.
nodes
)
# This stream is created for overlaping the communication and computation.
# This stream is created for overlap
p
ing the communication and computation.
reduction_stream
=
torch
.
cuda
.
Stream
()
def
_add_hook_for_grad_communication
(
node
,
param
,
name
=
None
):
comm_actions
=
node
.
best_strategy
.
communication_actions
def
_filter_param_to_hook
(
node
,
op_data
,
comm_action
,
name
):
if
node
.
op
==
'call_module'
and
op_data
.
type
==
OperationDataType
.
PARAM
and
op_data
.
name
==
name
and
comm_action
.
comm_type
==
CommType
.
HOOK
:
if
(
node
.
op
==
"call_module"
and
op_data
.
type
==
OperationDataType
.
PARAM
and
op_data
.
name
==
name
and
comm_action
.
comm_type
==
CommType
.
HOOK
):
return
True
if
node
.
op
==
'get_attr'
and
isinstance
(
node
.
_meta_data
,
torch
.
nn
.
parameter
.
Parameter
)
and
comm_action
.
comm_type
==
CommType
.
HOOK
:
if
(
node
.
op
==
"get_attr"
and
isinstance
(
node
.
_meta_data
,
torch
.
nn
.
parameter
.
Parameter
)
and
comm_action
.
comm_type
==
CommType
.
HOOK
):
return
True
return
False
...
...
@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
         if _filter_param_to_hook(node, operation_data, comm_action, name=name):

             def wrapper(param, comm_spec, stream, overlap):
                 def hook_fn(grad):
                     if overlap:
                         with torch.cuda.stream(stream):
...
...
@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
         # apply the sharding spec of parameters
         if target_sharding_spec.dim_partition_dict != {}:
             origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
-            setattr(param, 'sharding_spec', origin_sharding_spec)
+            setattr(param, "sharding_spec", origin_sharding_spec)
             # TODO: build a ColoParameter class to manager the distributed parameters
             # we could use .data here, because all the operations just happen before the real training
             # loop, so we don't need to track these operations in the autograd graph.
             param = torch.nn.Parameter(
-                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
-                                                                         target_sharding_spec).detach().clone())
+                shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, target_sharding_spec)
+                .detach()
+                .clone()
+            )

         return param

     for node in nodes:
-        if node.op == 'call_module':
+        if node.op == "call_module":
             target_module = node.graph.owning_module.get_submodule(node.target)
             # TODO: we need to do more actions to take care of the shared parameters.
-            if hasattr(target_module, 'processed') and target_module.processed:
+            if hasattr(target_module, "processed") and target_module.processed:
                 continue
-            setattr(target_module, 'processed', True)
+            setattr(target_module, "processed", True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 param = _shard_param(param, target_sharding_spec)
...
...
@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             # apply the sharding spec of buffers
             for name, buffer in target_module.named_buffers():
                 origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
-                setattr(buffer, 'sharding_spec', origin_sharding_spec)
+                setattr(buffer, "sharding_spec", origin_sharding_spec)
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
                 sharded_buffer_dict[name] = buffer_sharded
...
...
@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
             for name, buffer_sharded in sharded_buffer_dict.items():
                 setattr(target_module, name, buffer_sharded.detach().clone())

-        if node.op == 'get_attr':
+        if node.op == "get_attr":
             root = node.graph.owning_module
             atoms = node.target.split(".")
             attr_len = len(atoms)
...
...
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
replace the origin kernel into kernel with implicit communication inside.
"""
pass
def
runtime_preparation_pass
(
gm
:
torch
.
fx
.
GraphModule
,
def
runtime_preparation_pass
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
device_mesh
:
DeviceMesh
,
strategies_constructor
:
StrategiesConstructor
,
overlap
=
False
):
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
=
solution_annotatation_pass
(
gm
,
solution
,
strategies_constructor
)
overlap
=
False
,
):
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
=
solution_annotation_pass
(
gm
,
solution
,
strategies_constructor
)
gm
=
size_value_converting_pass
(
gm
,
device_mesh
)
gm
=
node_args_converting_pass
(
gm
,
device_mesh
)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
...
...
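Taken together, the preparation pipeline sticks the solved strategies onto the nodes, rewrites size values and node arguments for the device mesh, and shards module parameters. Judging from the signature above, a caller would wire it roughly as follows (gm, solution, device_mesh and strategies_constructor come from the earlier auto-parallel solver stage; this is a sketch, not the documented entry point):

gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
    gm, solution, device_mesh, strategies_constructor, overlap=False
)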