Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9e768b59
Commit
9e768b59
authored
Oct 10, 2023
by
zhuwenwen
Browse files
Merge branch 'main' of
https://github.com/hpcaitech/ColossalAI
parents
7bc5a8e3
8aed02b9
Changes
436
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
581 additions
and
607 deletions
+581
-607
colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
...lai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
+1
-1
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
+56
-46
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
...alai/auto_parallel/meta_profiler/meta_registry/pooling.py
+8
-6
colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
...salai/auto_parallel/meta_profiler/meta_registry/tensor.py
+28
-15
colossalai/auto_parallel/meta_profiler/meta_registry/where.py
...ssalai/auto_parallel/meta_profiler/meta_registry/where.py
+16
-11
colossalai/auto_parallel/meta_profiler/registry.py
colossalai/auto_parallel/meta_profiler/registry.py
+3
-5
colossalai/auto_parallel/meta_profiler/shard_metainfo.py
colossalai/auto_parallel/meta_profiler/shard_metainfo.py
+14
-18
colossalai/auto_parallel/offload/amp_optimizer.py
colossalai/auto_parallel/offload/amp_optimizer.py
+38
-35
colossalai/auto_parallel/offload/base_offload_module.py
colossalai/auto_parallel/offload/base_offload_module.py
+6
-8
colossalai/auto_parallel/offload/mem_optimize.py
colossalai/auto_parallel/offload/mem_optimize.py
+6
-8
colossalai/auto_parallel/offload/region.py
colossalai/auto_parallel/offload/region.py
+5
-5
colossalai/auto_parallel/offload/region_manager.py
colossalai/auto_parallel/offload/region_manager.py
+62
-75
colossalai/auto_parallel/offload/runtime.py
colossalai/auto_parallel/offload/runtime.py
+34
-34
colossalai/auto_parallel/offload/solver.py
colossalai/auto_parallel/offload/solver.py
+40
-65
colossalai/auto_parallel/offload/training_simulator.py
colossalai/auto_parallel/offload/training_simulator.py
+48
-82
colossalai/auto_parallel/offload/util.py
colossalai/auto_parallel/offload/util.py
+10
-12
colossalai/auto_parallel/passes/comm_metainfo_pass.py
colossalai/auto_parallel/passes/comm_metainfo_pass.py
+34
-23
colossalai/auto_parallel/passes/meta_info_prop.py
colossalai/auto_parallel/passes/meta_info_prop.py
+12
-12
colossalai/auto_parallel/passes/runtime_apply_pass.py
colossalai/auto_parallel/passes/runtime_apply_pass.py
+67
-64
colossalai/auto_parallel/passes/runtime_preparation_pass.py
colossalai/auto_parallel/passes/runtime_preparation_pass.py
+93
-82
No files found.
Too many changes to show.
To preserve performance only
436 of 436+
files are displayed.
Plain diff
Email patch
colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
View file @
9e768b59
...
@@ -3,7 +3,7 @@ from typing import List, Tuple
...
@@ -3,7 +3,7 @@ from typing import List, Tuple
import
torch
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
TrainCycleItem
from
..registry
import
meta_register
from
..registry
import
meta_register
...
...
colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
View file @
9e768b59
from
typing
import
Callable
,
Dict
,
List
,
Tuple
,
Union
from
typing
import
List
,
Tuple
import
torch
import
torch
from
colossalai._analyzer._subclasses.flop_tensor
import
flop_mapping
from
colossalai._analyzer._subclasses.flop_tensor
import
flop_mapping
from
colossalai._analyzer.fx.node_util
import
compute_size_in_bytes
from
colossalai._analyzer.fx.node_util
import
compute_size_in_bytes
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
MemoryCost
,
OperationData
,
OperationDataType
,
ShardingStrategy
,
StrategiesVector
,
TrainCycleItem
,
)
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
..registry
import
meta_register
from
..registry
import
meta_register
__all__
=
[
'
batchnormnd_meta_info
'
,
'
layernorm_meta_info
'
]
__all__
=
[
"
batchnormnd_meta_info
"
,
"
layernorm_meta_info
"
]
@
meta_register
.
register
(
torch
.
nn
.
BatchNorm1d
)
@
meta_register
.
register
(
torch
.
nn
.
BatchNorm1d
)
...
@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
...
@@ -65,7 +57,15 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# saved inv std and some other args indicating the status of the module
# saved inv std and some other args indicating the status of the module
# the bwd outputs are input grad, weight grad and bias grad
# the bwd outputs are input grad, weight grad and bias grad
bwd_in_args
=
[
bwd_in_args
=
[
output_tensor
,
output_tensor
,
weight_tensor
,
mean_tensor
,
var_tensor
,
mean_tensor
,
var_tensor
,
1e-5
,
num_batch
output_tensor
,
output_tensor
,
weight_tensor
,
mean_tensor
,
var_tensor
,
mean_tensor
,
var_tensor
,
1e-5
,
num_batch
,
]
]
bwd_out_args
=
[
input_tensor
,
weight_tensor
,
bias_tensor
]
bwd_out_args
=
[
input_tensor
,
weight_tensor
,
bias_tensor
]
...
@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
...
@@ -77,29 +77,34 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
# calculate memory cost
# calculate memory cost
# the fwd activation cost is output plus saved mean and saved inv std
# the fwd activation cost is output plus saved mean and saved inv std
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
(
fwd_memory_cost
=
MemoryCost
(
[
input_tensor
,
output_tensor
,
mean_tensor
,
var_tensor
]),
activation
=
compute_size_in_bytes
([
input_tensor
,
output_tensor
,
mean_tensor
,
var_tensor
]),
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
temp
=
0
,
temp
=
0
,
buffer
=
compute_size_in_bytes
([
mean_tensor
,
var_tensor
]))
buffer
=
compute_size_in_bytes
([
mean_tensor
,
var_tensor
]),
)
# the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
# the bwd memory cost is quite tricky here, BatchNorm will remove saved mean
# and saved inv std during backward phase
# and saved inv std during backward phase
bwd_memory_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
([
input_tensor
]),
bwd_memory_cost
=
MemoryCost
(
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
activation
=
compute_size_in_bytes
([
input_tensor
]),
temp
=
compute_size_in_bytes
([
mean_tensor
,
var_tensor
]),
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
buffer
=
compute_size_in_bytes
([
mean_tensor
,
var_tensor
]))
temp
=
compute_size_in_bytes
([
mean_tensor
,
var_tensor
]),
buffer
=
compute_size_in_bytes
([
mean_tensor
,
var_tensor
]),
)
# total cost is the sum of forward and backward cost
# total cost is the sum of forward and backward cost
total_cost
=
MemoryCost
(
activation
=
fwd_memory_cost
.
activation
+
bwd_memory_cost
.
activation
,
total_cost
=
MemoryCost
(
parameter
=
fwd_memory_cost
.
parameter
+
bwd_memory_cost
.
parameter
)
activation
=
fwd_memory_cost
.
activation
+
bwd_memory_cost
.
activation
,
parameter
=
fwd_memory_cost
.
parameter
+
bwd_memory_cost
.
parameter
,
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_memory_cost
,
bwd
=
bwd_memory_cost
,
total
=
total_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_memory_cost
,
bwd
=
bwd_memory_cost
,
total
=
total_cost
)
# store fwd_in, fwd_buffer, fwd_out
# store fwd_in, fwd_buffer, fwd_out
fwd_in
=
[
torch
.
zeros_like
(
input_tensor
,
device
=
'
meta
'
)]
fwd_in
=
[
torch
.
zeros_like
(
input_tensor
,
device
=
"
meta
"
)]
fwd_buffer
=
[
torch
.
zeros_like
(
mean_tensor
,
device
=
'
meta
'
),
torch
.
zeros_like
(
var_tensor
,
device
=
'
meta
'
)]
fwd_buffer
=
[
torch
.
zeros_like
(
mean_tensor
,
device
=
"
meta
"
),
torch
.
zeros_like
(
var_tensor
,
device
=
"
meta
"
)]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
'
meta
'
)]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
"
meta
"
)]
return
compute_cost
,
memory_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
return
compute_cost
,
memory_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
...
@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
...
@@ -116,8 +121,8 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
output_tensor
=
next
(
filter
(
lambda
x
:
x
.
type
==
OperationDataType
.
OUTPUT
,
args
)).
data
output_tensor
=
next
(
filter
(
lambda
x
:
x
.
type
==
OperationDataType
.
OUTPUT
,
args
)).
data
weight_tensor
=
next
(
filter
(
lambda
x
:
x
.
name
==
"weight"
,
args
)).
data
weight_tensor
=
next
(
filter
(
lambda
x
:
x
.
name
==
"weight"
,
args
)).
data
bias_tensor
=
next
(
filter
(
lambda
x
:
x
.
name
==
"bias"
,
args
)).
data
bias_tensor
=
next
(
filter
(
lambda
x
:
x
.
name
==
"bias"
,
args
)).
data
running_mean
=
torch
.
rand
(
input_tensor
.
shape
[
0
],
1
,
device
=
'
meta
'
)
running_mean
=
torch
.
rand
(
input_tensor
.
shape
[
0
],
1
,
device
=
"
meta
"
)
running_var
=
torch
.
rand
(
input_tensor
.
shape
[
0
],
1
,
device
=
'
meta
'
)
running_var
=
torch
.
rand
(
input_tensor
.
shape
[
0
],
1
,
device
=
"
meta
"
)
# construct args
# construct args
fwd_in_args
=
[
input_tensor
,
[
input_tensor
.
shape
[
0
]],
weight_tensor
]
fwd_in_args
=
[
input_tensor
,
[
input_tensor
.
shape
[
0
]],
weight_tensor
]
...
@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
...
@@ -132,27 +137,32 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# memory cost
# memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
(
fwd_memory_cost
=
MemoryCost
(
[
input_tensor
,
output_tensor
,
weight_tensor
,
bias_tensor
]),
activation
=
compute_size_in_bytes
([
input_tensor
,
output_tensor
,
weight_tensor
,
bias_tensor
]),
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
temp
=
0
,
temp
=
0
,
buffer
=
compute_size_in_bytes
([
running_mean
,
running_var
]))
buffer
=
compute_size_in_bytes
([
running_mean
,
running_var
]),
)
bwd_memory_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
([
input_tensor
,
weight_tensor
,
bias_tensor
]),
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
bwd_memory_cost
=
MemoryCost
(
temp
=
compute_size_in_bytes
([
running_mean
,
running_var
]),
activation
=
compute_size_in_bytes
([
input_tensor
,
weight_tensor
,
bias_tensor
]),
buffer
=
compute_size_in_bytes
([
running_mean
,
running_var
]))
parameter
=
compute_size_in_bytes
([
weight_tensor
,
bias_tensor
]),
temp
=
compute_size_in_bytes
([
running_mean
,
running_var
]),
total_cost
=
MemoryCost
(
activation
=
fwd_memory_cost
.
activation
+
bwd_memory_cost
.
activation
,
buffer
=
compute_size_in_bytes
([
running_mean
,
running_var
]),
parameter
=
fwd_memory_cost
.
parameter
+
bwd_memory_cost
.
parameter
,
)
temp
=
fwd_memory_cost
.
temp
+
bwd_memory_cost
.
temp
,
buffer
=
fwd_memory_cost
.
buffer
+
bwd_memory_cost
.
buffer
)
total_cost
=
MemoryCost
(
activation
=
fwd_memory_cost
.
activation
+
bwd_memory_cost
.
activation
,
parameter
=
fwd_memory_cost
.
parameter
+
bwd_memory_cost
.
parameter
,
temp
=
fwd_memory_cost
.
temp
+
bwd_memory_cost
.
temp
,
buffer
=
fwd_memory_cost
.
buffer
+
bwd_memory_cost
.
buffer
,
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_memory_cost
,
bwd
=
bwd_memory_cost
,
total
=
total_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_memory_cost
,
bwd
=
bwd_memory_cost
,
total
=
total_cost
)
# store fwd_in, fwd_buffer, fwd_out
# store fwd_in, fwd_buffer, fwd_out
fwd_in
=
[
torch
.
zeros_like
(
input_tensor
,
device
=
'
meta
'
)]
fwd_in
=
[
torch
.
zeros_like
(
input_tensor
,
device
=
"
meta
"
)]
fwd_buffer
=
[
torch
.
zeros_like
(
running_mean
,
device
=
'
meta
'
),
torch
.
zeros_like
(
running_var
,
device
=
'
meta
'
)]
fwd_buffer
=
[
torch
.
zeros_like
(
running_mean
,
device
=
"
meta
"
),
torch
.
zeros_like
(
running_var
,
device
=
"
meta
"
)]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
'
meta
'
)]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
"
meta
"
)]
return
compute_cost
,
memory_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
return
compute_cost
,
memory_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
View file @
9e768b59
...
@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
...
@@ -63,7 +63,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
# store fwd_in, fwd_buffer, fwd_out
# store fwd_in, fwd_buffer, fwd_out
fwd_in
=
[]
fwd_in
=
[]
fwd_buffer
=
[]
fwd_buffer
=
[]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
'
meta
'
)]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
"
meta
"
)]
return
compute_cost
,
mem_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
return
compute_cost
,
mem_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
...
@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
...
@@ -117,8 +117,10 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
fwd_mem_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
([
input_tensor
,
output_tensor
,
index_matrix
]))
fwd_mem_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
([
input_tensor
,
output_tensor
,
index_matrix
]))
# temp memory for backward is the index matrix to be discarded
# temp memory for backward is the index matrix to be discarded
bwd_mem_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
(
input_tensor
)
-
compute_size_in_bytes
(
index_matrix
),
bwd_mem_cost
=
MemoryCost
(
temp
=
compute_size_in_bytes
(
index_matrix
))
activation
=
compute_size_in_bytes
(
input_tensor
)
-
compute_size_in_bytes
(
index_matrix
),
temp
=
compute_size_in_bytes
(
index_matrix
),
)
# total cost
# total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
,
temp
=
bwd_mem_cost
.
temp
)
total_mem_cost
=
MemoryCost
(
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
,
temp
=
bwd_mem_cost
.
temp
)
...
@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
...
@@ -126,8 +128,8 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
mem_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
mem_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
# store fwd_in, fwd_buffer, fwd_out
# store fwd_in, fwd_buffer, fwd_out
fwd_in
=
[
torch
.
zeros_like
(
input_tensor
,
device
=
'
meta
'
)]
fwd_in
=
[
torch
.
zeros_like
(
input_tensor
,
device
=
"
meta
"
)]
fwd_buffer
=
[
torch
.
zeros_like
(
index_matrix
,
device
=
'
meta
'
)]
fwd_buffer
=
[
torch
.
zeros_like
(
index_matrix
,
device
=
"
meta
"
)]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
'
meta
'
)]
fwd_out
=
[
torch
.
zeros_like
(
output_tensor
,
device
=
"
meta
"
)]
return
compute_cost
,
mem_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
return
compute_cost
,
mem_cost
,
fwd_in
,
fwd_buffer
,
fwd_out
colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
View file @
9e768b59
...
@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
...
@@ -2,7 +2,6 @@ from typing import Callable, List, Tuple
import
torch
import
torch
from
colossalai._analyzer._subclasses.flop_tensor
import
flop_mapping
from
colossalai._analyzer.fx.node_util
import
compute_size_in_bytes
from
colossalai._analyzer.fx.node_util
import
compute_size_in_bytes
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
...
@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
...
@@ -37,15 +36,19 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_mem_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
(
outputs
)
*
2
,
parameter
=
0
,
temp
=
0
,
buffer
=
0
)
fwd_mem_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
(
outputs
)
*
2
,
parameter
=
0
,
temp
=
0
,
buffer
=
0
)
bwd_mem_cost
=
MemoryCost
(
activation
=
compute_size_in_bytes
(
outputs
)
*
bwd_mem_out_factor
,
bwd_mem_cost
=
MemoryCost
(
parameter
=
0
,
activation
=
compute_size_in_bytes
(
outputs
)
*
bwd_mem_out_factor
,
temp
=
compute_size_in_bytes
(
outputs
)
*
bwd_mem_tmp_factor
,
parameter
=
0
,
buffer
=
0
)
temp
=
compute_size_in_bytes
(
outputs
)
*
bwd_mem_tmp_factor
,
buffer
=
0
,
)
total_mem_cost
=
MemoryCost
(
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_mem_cost
.
parameter
+
bwd_mem_cost
.
parameter
,
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
,
temp
=
fwd_mem_cost
.
temp
+
bwd_mem_cost
.
temp
,
parameter
=
fwd_mem_cost
.
parameter
+
bwd_mem_cost
.
parameter
,
buffer
=
fwd_mem_cost
.
buffer
+
bwd_mem_cost
.
buffer
)
temp
=
fwd_mem_cost
.
temp
+
bwd_mem_cost
.
temp
,
buffer
=
fwd_mem_cost
.
buffer
+
bwd_mem_cost
.
buffer
,
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
...
@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
...
@@ -66,14 +69,24 @@ def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: f
# register torch.Tensor related metainfo
# register torch.Tensor related metainfo
# (0, 0)
# (0, 0)
meta_register
.
register
([
torch
.
tensor
,
torch
.
Tensor
.
to
,
torch
.
Tensor
.
unsqueeze
,
torch
.
unsqueeze
,
meta_register
.
register
([
torch
.
tensor
,
torch
.
Tensor
.
to
,
torch
.
Tensor
.
unsqueeze
,
torch
.
unsqueeze
,
torch
.
arange
])(
torch
.
arange
])(
tensor_related_metainfo
(
0
,
0
))
tensor_related_metainfo
(
0
,
0
)
)
# (1, 0)
# (1, 0)
meta_register
.
register
([
meta_register
.
register
(
torch
.
Tensor
.
flatten
,
torch
.
flatten
,
torch
.
Tensor
.
transpose
,
torch
.
transpose
,
torch
.
Tensor
.
permute
,
torch
.
permute
,
[
torch
.
Tensor
.
split
,
torch
.
split
,
torch
.
Tensor
.
view
torch
.
Tensor
.
flatten
,
])(
tensor_related_metainfo
(
1
,
0
))
torch
.
flatten
,
torch
.
Tensor
.
transpose
,
torch
.
transpose
,
torch
.
Tensor
.
permute
,
torch
.
permute
,
torch
.
Tensor
.
split
,
torch
.
split
,
torch
.
Tensor
.
view
,
]
)(
tensor_related_metainfo
(
1
,
0
))
# (1, 1)
# (1, 1)
meta_register
.
register
([
torch
.
Tensor
.
type
,
torch
.
Tensor
.
contiguous
])(
tensor_related_metainfo
(
1
,
1
))
meta_register
.
register
([
torch
.
Tensor
.
type
,
torch
.
Tensor
.
contiguous
])(
tensor_related_metainfo
(
1
,
1
))
colossalai/auto_parallel/meta_profiler/meta_registry/where.py
View file @
9e768b59
...
@@ -4,7 +4,7 @@ import torch
...
@@ -4,7 +4,7 @@ import torch
from
colossalai._analyzer._subclasses.flop_tensor
import
flop_mapping
from
colossalai._analyzer._subclasses.flop_tensor
import
flop_mapping
from
colossalai._analyzer.fx.node_util
import
compute_size_in_bytes
as
activation_size
from
colossalai._analyzer.fx.node_util
import
compute_size_in_bytes
as
activation_size
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
OperationDataType
,
TrainCycleItem
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
TrainCycleItem
from
..registry
import
meta_register
from
..registry
import
meta_register
...
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
...
@@ -39,16 +39,21 @@ def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Li
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
# NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
fwd_mem_cost
=
MemoryCost
(
activation
=
activation_size
([
condition_tensor
,
x_tensor
,
y_tensor
,
output_tensor
]))
fwd_mem_cost
=
MemoryCost
(
activation
=
activation_size
([
condition_tensor
,
x_tensor
,
y_tensor
,
output_tensor
]))
bwd_mem_cost
=
MemoryCost
(
activation
=
activation_size
([
x_tensor
,
y_tensor
])
-
activation_size
([
condition_tensor
]),
bwd_mem_cost
=
MemoryCost
(
parameter
=
0
,
activation
=
activation_size
([
x_tensor
,
y_tensor
])
-
activation_size
([
condition_tensor
]),
temp
=
activation_size
([
output_tensor
])
*
3
+
activation_size
([
condition_tensor
])
-
parameter
=
0
,
activation_size
([
x_tensor
,
y_tensor
]),
temp
=
activation_size
([
output_tensor
])
*
3
buffer
=
0
)
+
activation_size
([
condition_tensor
])
-
activation_size
([
x_tensor
,
y_tensor
]),
total_mem_cost
=
MemoryCost
(
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
,
buffer
=
0
,
parameter
=
fwd_mem_cost
.
parameter
+
bwd_mem_cost
.
parameter
,
)
temp
=
fwd_mem_cost
.
temp
+
bwd_mem_cost
.
temp
,
buffer
=
fwd_mem_cost
.
buffer
+
bwd_mem_cost
.
buffer
)
total_mem_cost
=
MemoryCost
(
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
,
parameter
=
fwd_mem_cost
.
parameter
+
bwd_mem_cost
.
parameter
,
temp
=
fwd_mem_cost
.
temp
+
bwd_mem_cost
.
temp
,
buffer
=
fwd_mem_cost
.
buffer
+
bwd_mem_cost
.
buffer
,
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
...
...
colossalai/auto_parallel/meta_profiler/registry.py
View file @
9e768b59
__all__
=
[
'
Registry
'
]
__all__
=
[
"
Registry
"
]
class
Registry
:
class
Registry
:
def
__init__
(
self
,
name
):
def
__init__
(
self
,
name
):
self
.
name
=
name
self
.
name
=
name
self
.
store
=
{}
self
.
store
=
{}
def
register
(
self
,
source
):
def
register
(
self
,
source
):
def
wrapper
(
func
):
def
wrapper
(
func
):
if
isinstance
(
source
,
(
list
,
tuple
)):
if
isinstance
(
source
,
(
list
,
tuple
)):
# support register a list of items for this func
# support register a list of items for this func
...
@@ -21,7 +19,7 @@ class Registry:
...
@@ -21,7 +19,7 @@ class Registry:
return
wrapper
return
wrapper
def
get
(
self
,
source
):
def
get
(
self
,
source
):
assert
source
in
self
.
store
,
f
'
{
source
}
not found in the
{
self
.
name
}
registry
'
assert
source
in
self
.
store
,
f
"
{
source
}
not found in the
{
self
.
name
}
registry
"
target
=
self
.
store
[
source
]
target
=
self
.
store
[
source
]
return
target
return
target
...
@@ -29,4 +27,4 @@ class Registry:
...
@@ -29,4 +27,4 @@ class Registry:
return
source
in
self
.
store
return
source
in
self
.
store
meta_register
=
Registry
(
'
meta
'
)
meta_register
=
Registry
(
"
meta
"
)
colossalai/auto_parallel/meta_profiler/shard_metainfo.py
View file @
9e768b59
...
@@ -2,20 +2,13 @@ from typing import Callable, List
...
@@ -2,20 +2,13 @@ from typing import Callable, List
import
torch
import
torch
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
OperationData
,
ShardingStrategy
,
TrainCycleItem
MemoryCost
,
OperationData
,
OperationDataType
,
ShardingStrategy
,
StrategiesVector
,
TrainCycleItem
,
)
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
.constants
import
INPLACE_MODULE
,
INPLACE_OPS
,
NO_SAVE_ACTIVATION
from
.constants
import
INPLACE_MODULE
,
INPLACE_OPS
,
NO_SAVE_ACTIVATION
from
.registry
import
meta_register
from
.registry
import
meta_register
__all__
=
[
'
ShardMetaInfo
'
]
__all__
=
[
"
ShardMetaInfo
"
]
class
ShardMetaInfo
:
class
ShardMetaInfo
:
...
@@ -76,10 +69,12 @@ class ShardMetaInfo:
...
@@ -76,10 +69,12 @@ class ShardMetaInfo:
"""
"""
if
isinstance
(
sharding_spec
,
ShardingSpec
):
if
isinstance
(
sharding_spec
,
ShardingSpec
):
op_data
=
OperationData
(
name
=
operation_data
.
name
,
op_data
=
OperationData
(
data
=
torch
.
zeros
(
sharding_spec
.
get_sharded_shape_per_device
(),
device
=
"meta"
),
name
=
operation_data
.
name
,
type
=
operation_data
.
type
,
data
=
torch
.
zeros
(
sharding_spec
.
get_sharded_shape_per_device
(),
device
=
"meta"
),
logical_shape
=
operation_data
.
logical_shape
)
type
=
operation_data
.
type
,
logical_shape
=
operation_data
.
logical_shape
,
)
elif
isinstance
(
sharding_spec
,
(
list
,
tuple
)):
elif
isinstance
(
sharding_spec
,
(
list
,
tuple
)):
data
=
operation_data
.
data
data
=
operation_data
.
data
assert
isinstance
(
data
,
(
list
,
tuple
)),
f
"Data Should be list or tuple, but got
{
type
(
data
)
}
."
assert
isinstance
(
data
,
(
list
,
tuple
)),
f
"Data Should be list or tuple, but got
{
type
(
data
)
}
."
...
@@ -97,8 +92,9 @@ class ShardMetaInfo:
...
@@ -97,8 +92,9 @@ class ShardMetaInfo:
"""
"""
Compute meta info based on sharding strategy and the given target function.
Compute meta info based on sharding strategy and the given target function.
"""
"""
assert
meta_register
.
has
(
self
.
_target
.
__class__
)
or
meta_register
.
has
(
self
.
_target
),
\
assert
meta_register
.
has
(
self
.
_target
.
__class__
)
or
meta_register
.
has
(
f
"Meta info for
{
self
.
_target
}
is not registered."
self
.
_target
),
f
"Meta info for
{
self
.
_target
}
is not registered."
if
meta_register
.
has
(
self
.
_target
.
__class__
):
if
meta_register
.
has
(
self
.
_target
.
__class__
):
# module
# module
meta_func
=
meta_register
.
get
(
self
.
_target
.
__class__
)
meta_func
=
meta_register
.
get
(
self
.
_target
.
__class__
)
...
@@ -117,11 +113,11 @@ class ShardMetaInfo:
...
@@ -117,11 +113,11 @@ class ShardMetaInfo:
# construct kwargs
# construct kwargs
if
self
.
target
in
INPLACE_MODULE
:
if
self
.
target
in
INPLACE_MODULE
:
kwargs
=
{
'
inplace
'
:
self
.
target
.
inplace
}
kwargs
=
{
"
inplace
"
:
self
.
target
.
inplace
}
elif
self
.
target
in
INPLACE_OPS
:
elif
self
.
target
in
INPLACE_OPS
:
kwargs
=
{
'
inplace
'
:
True
}
kwargs
=
{
"
inplace
"
:
True
}
else
:
else
:
kwargs
=
{
'
inplace
'
:
False
}
kwargs
=
{
"
inplace
"
:
False
}
# compute metainfo with meta_func
# compute metainfo with meta_func
self
.
compute_cost
,
self
.
memory_cost
,
self
.
fwd_in
,
self
.
fwd_buffer
,
self
.
fwd_out
=
meta_func
(
*
args
,
**
kwargs
)
self
.
compute_cost
,
self
.
memory_cost
,
self
.
fwd_in
,
self
.
fwd_buffer
,
self
.
fwd_out
=
meta_func
(
*
args
,
**
kwargs
)
...
...
colossalai/auto_parallel/offload/amp_optimizer.py
View file @
9e768b59
from
typing
import
Dict
,
Tuple
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Dict
,
Tuple
import
torch
import
torch
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.amp.naive_amp.grad_scaler
import
DynamicGradScaler
from
colossalai.amp.naive_amp.grad_scaler
import
DynamicGradScaler
from
colossalai.interface
import
OptimizerWrapper
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
.base_offload_module
import
BaseOffloadModule
from
.base_offload_module
import
BaseOffloadModule
from
.region_manager
import
RegionManager
from
.region
import
Region
from
.region
import
Region
from
.region_manager
import
RegionManager
class
OptimState
(
Enum
):
class
OptimState
(
Enum
):
SCALED
=
0
SCALED
=
0
UNSCALED
=
1
UNSCALED
=
1
class
AMPOptimizer
(
ColossalaiOptimizer
):
class
AMPOptimizer
(
OptimizerWrapper
):
"""
"""
A wrapper for Optimizer.
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
...
@@ -36,19 +37,20 @@ class AMPOptimizer(ColossalaiOptimizer):
...
@@ -36,19 +37,20 @@ class AMPOptimizer(ColossalaiOptimizer):
norm_type (float, optional): norm_type used for `clip_grad_norm`.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
"""
"""
def
__init__
(
self
,
def
__init__
(
optimizer
:
Optimizer
,
self
,
module
:
BaseOffloadModule
,
optimizer
:
Optimizer
,
initial_scale
:
float
=
2
**
16
,
module
:
BaseOffloadModule
,
growth_factor
:
float
=
2
,
initial_scale
:
float
=
2
**
16
,
backoff_factor
:
float
=
0.5
,
growth_factor
:
float
=
2
,
growth_interval
:
int
=
1000
,
backoff_factor
:
float
=
0.5
,
hysteresis
:
int
=
2
,
growth_interval
:
int
=
1000
,
min_scale
:
float
=
1
,
hysteresis
:
int
=
2
,
max_scale
:
float
=
2
**
32
,
min_scale
:
float
=
1
,
clipping_norm
:
float
=
0.0
,
max_scale
:
float
=
2
**
32
,
norm_type
:
float
=
2.0
):
clipping_norm
:
float
=
0.0
,
norm_type
:
float
=
2.0
,
):
super
().
__init__
(
optimizer
)
super
().
__init__
(
optimizer
)
self
.
module
=
module
self
.
module
=
module
...
@@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer):
...
@@ -68,19 +70,21 @@ class AMPOptimizer(ColossalaiOptimizer):
self
.
__init__optimizer
()
self
.
__init__optimizer
()
# Grad scaler
# Grad scaler
self
.
grad_scaler
=
DynamicGradScaler
(
initial_scale
=
initial_scale
,
self
.
grad_scaler
=
DynamicGradScaler
(
min_scale
=
min_scale
,
initial_scale
=
initial_scale
,
growth_factor
=
growth_factor
,
min_scale
=
min_scale
,
backoff_factor
=
backoff_factor
,
growth_factor
=
growth_factor
,
growth_interval
=
growth_interval
,
backoff_factor
=
backoff_factor
,
hysteresis
=
hysteresis
,
growth_interval
=
growth_interval
,
max_scale
=
max_scale
)
hysteresis
=
hysteresis
,
max_scale
=
max_scale
,
)
self
.
_found_overflow
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
get_current_device
())
self
.
_found_overflow
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
dtype
=
torch
.
int64
,
device
=
get_current_device
())
self
.
_logger
=
get_dist_logger
()
self
.
_logger
=
get_dist_logger
()
def
_set_grad_ptr
(
self
):
def
_set_grad_ptr
(
self
):
for
group
in
self
.
param_groups
:
for
group
in
self
.
param_groups
:
for
fake_param
in
group
[
'
params
'
]:
for
fake_param
in
group
[
"
params
"
]:
region
=
self
.
param_to_region
[
fake_param
]
region
=
self
.
param_to_region
[
fake_param
]
begin
,
end
=
self
.
param_to_range
[
fake_param
]
begin
,
end
=
self
.
param_to_range
[
fake_param
]
...
@@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer):
...
@@ -91,7 +95,7 @@ class AMPOptimizer(ColossalaiOptimizer):
def
_update_fp16_params
(
self
):
def
_update_fp16_params
(
self
):
none_tensor
=
torch
.
empty
([
0
])
none_tensor
=
torch
.
empty
([
0
])
for
group
in
self
.
param_groups
:
for
group
in
self
.
param_groups
:
for
fake_param
in
group
[
'
params
'
]:
for
fake_param
in
group
[
"
params
"
]:
assert
fake_param
.
grad
is
None
assert
fake_param
.
grad
is
None
fake_param
.
data
=
none_tensor
fake_param
.
data
=
none_tensor
self
.
param_to_region
[
fake_param
].
cpu_grad
=
None
self
.
param_to_region
[
fake_param
].
cpu_grad
=
None
...
@@ -129,10 +133,10 @@ class AMPOptimizer(ColossalaiOptimizer):
...
@@ -129,10 +133,10 @@ class AMPOptimizer(ColossalaiOptimizer):
found_inf
=
self
.
_check_overflow
()
found_inf
=
self
.
_check_overflow
()
if
found_inf
:
if
found_inf
:
self
.
optim_state
=
OptimState
.
UNSCALED
# no need to unscale grad
self
.
optim_state
=
OptimState
.
UNSCALED
# no need to unscale grad
self
.
grad_scaler
.
update
(
found_inf
)
# update gradient scaler
self
.
grad_scaler
.
update
(
found_inf
)
# update gradient scaler
self
.
_logger
.
info
(
f
'
Found overflow. Skip step
'
)
self
.
_logger
.
info
(
f
"
Found overflow. Skip step
"
)
self
.
zero_grad
()
# reset all gradients
self
.
zero_grad
()
# reset all gradients
self
.
_update_fp16_params
()
self
.
_update_fp16_params
()
return
return
...
@@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer):
...
@@ -155,11 +159,10 @@ class AMPOptimizer(ColossalaiOptimizer):
self
.
module
.
backward
(
loss
)
self
.
module
.
backward
(
loss
)
def
__init__optimizer
(
self
):
def
__init__optimizer
(
self
):
for
group
in
self
.
optim
.
param_groups
:
for
group
in
self
.
optim
.
param_groups
:
fake_params_list
=
list
()
fake_params_list
=
list
()
for
param
in
group
[
'
params
'
]:
for
param
in
group
[
"
params
"
]:
region
=
self
.
region_manager
.
get_region
(
param
)
region
=
self
.
region_manager
.
get_region
(
param
)
fake_param
=
torch
.
nn
.
Parameter
(
torch
.
empty
([
0
]))
fake_param
=
torch
.
nn
.
Parameter
(
torch
.
empty
([
0
]))
self
.
param_to_range
[
fake_param
]
=
region
.
param_to_range
[
param
]
self
.
param_to_range
[
fake_param
]
=
region
.
param_to_range
[
param
]
...
@@ -170,8 +173,8 @@ class AMPOptimizer(ColossalaiOptimizer):
...
@@ -170,8 +173,8 @@ class AMPOptimizer(ColossalaiOptimizer):
if
param
in
self
.
optim
.
state
:
if
param
in
self
.
optim
.
state
:
self
.
optim
.
state
[
fake_param
]
=
self
.
optim
.
state
.
pop
(
param
)
self
.
optim
.
state
[
fake_param
]
=
self
.
optim
.
state
.
pop
(
param
)
group
[
'
params
'
]
=
fake_params_list
group
[
"
params
"
]
=
fake_params_list
# Leverage state_dict() and load_state_dict() to
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
# recast preexisting per-param state tensors
self
.
optim
.
load_state_dict
(
self
.
optim
.
state_dict
())
self
.
optim
.
load_state_dict
(
self
.
optim
.
state_dict
())
\ No newline at end of file
colossalai/auto_parallel/offload/base_offload_module.py
View file @
9e768b59
...
@@ -4,7 +4,7 @@ from typing import Optional, Set
...
@@ -4,7 +4,7 @@ from typing import Optional, Set
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
colossalai.
nn.parallel.data_parallel
import
_cast_float
from
colossalai.
utils
import
_cast_float
from
colossalai.zero.legacy.gemini.tensor_utils
import
free_storage
from
colossalai.zero.legacy.gemini.tensor_utils
import
free_storage
from
.region_manager
import
RegionManager
from
.region_manager
import
RegionManager
...
@@ -22,7 +22,6 @@ class BaseOffloadModule:
...
@@ -22,7 +22,6 @@ class BaseOffloadModule:
"""
"""
def
__init__
(
self
,
model
:
nn
.
Module
,
region_manager
:
RegionManager
,
is_sync
=
True
):
def
__init__
(
self
,
model
:
nn
.
Module
,
region_manager
:
RegionManager
,
is_sync
=
True
):
self
.
model
=
model
self
.
model
=
model
self
.
region_manager
=
region_manager
self
.
region_manager
=
region_manager
self
.
grad_hook_list
=
[]
self
.
grad_hook_list
=
[]
...
@@ -91,17 +90,16 @@ class BaseOffloadModule:
...
@@ -91,17 +90,16 @@ class BaseOffloadModule:
def
parameters
(
self
,
recurse
:
bool
=
True
):
def
parameters
(
self
,
recurse
:
bool
=
True
):
return
self
.
model
.
parameters
(
recurse
)
return
self
.
model
.
parameters
(
recurse
)
def
named_parameters
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
):
def
named_parameters
(
self
,
prefix
:
str
=
""
,
recurse
:
bool
=
True
):
return
self
.
model
.
named_parameters
(
prefix
,
recurse
)
return
self
.
model
.
named_parameters
(
prefix
,
recurse
)
def
named_buffers
(
self
,
prefix
:
str
=
''
,
recurse
:
bool
=
True
):
def
named_buffers
(
self
,
prefix
:
str
=
""
,
recurse
:
bool
=
True
):
return
self
.
model
.
named_buffers
(
prefix
,
recurse
)
return
self
.
model
.
named_buffers
(
prefix
,
recurse
)
def
named_children
(
self
):
def
named_children
(
self
):
return
self
.
model
.
named_children
()
return
self
.
model
.
named_children
()
def
named_modules
(
self
,
def
named_modules
(
memo
:
Optional
[
Set
[
torch
.
nn
.
Module
]]
=
None
,
self
,
memo
:
Optional
[
Set
[
torch
.
nn
.
Module
]]
=
None
,
prefix
:
str
=
""
,
remove_duplicate
:
bool
=
True
prefix
:
str
=
''
,
):
remove_duplicate
:
bool
=
True
):
return
self
.
model
.
named_modules
(
memo
,
prefix
,
remove_duplicate
)
return
self
.
model
.
named_modules
(
memo
,
prefix
,
remove_duplicate
)
colossalai/auto_parallel/offload/mem_optimize.py
View file @
9e768b59
...
@@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_
...
@@ -14,11 +14,9 @@ from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_
from
.util
import
GlobalRuntimeInfo
,
compute_act_peak_mem
,
compute_max_param_mem
,
compute_total_param_mem
from
.util
import
GlobalRuntimeInfo
,
compute_act_peak_mem
,
compute_max_param_mem
,
compute_total_param_mem
def
memory_optimize
(
model
:
torch
.
nn
.
Module
,
def
memory_optimize
(
inps
:
Dict
[
str
,
torch
.
Tensor
],
model
:
torch
.
nn
.
Module
,
inps
:
Dict
[
str
,
torch
.
Tensor
],
memory_budget
:
float
=
-
1.0
,
solver_name
:
str
=
"asyn"
memory_budget
:
float
=
-
1.0
,
):
solver_name
:
str
=
'asyn'
):
model
=
model
.
cpu
().
half
()
model
=
model
.
cpu
().
half
()
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
assert
is_compatible_with_meta
()
assert
is_compatible_with_meta
()
...
@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
...
@@ -40,13 +38,13 @@ def memory_optimize(model: torch.nn.Module,
f
"act_peak_mem=
{
act_peak_mem
:.
3
f
}
MB | max_param_mem=
{
max_param_mem
:.
3
f
}
MB | total_param_mem=
{
total_param_mem
:.
3
f
}
"
f
"act_peak_mem=
{
act_peak_mem
:.
3
f
}
MB | max_param_mem=
{
max_param_mem
:.
3
f
}
MB | total_param_mem=
{
total_param_mem
:.
3
f
}
"
)
)
if
solver_name
==
'
syn
'
:
if
solver_name
==
"
syn
"
:
gm
=
runtime_syn_offload_apply_pass
(
gm
,
region_manager
.
region_list
)
gm
=
runtime_syn_offload_apply_pass
(
gm
,
region_manager
.
region_list
)
elif
solver_name
==
'
asyn
'
:
elif
solver_name
==
"
asyn
"
:
gm
=
runtime_asyn_offload_apply_pass
(
gm
,
region_manager
.
region_list
)
gm
=
runtime_asyn_offload_apply_pass
(
gm
,
region_manager
.
region_list
)
else
:
else
:
raise
TypeError
(
f
"Unknown solver name
{
solver_name
}
!"
)
raise
TypeError
(
f
"Unknown solver name
{
solver_name
}
!"
)
gm
.
recompile
()
gm
.
recompile
()
optimized_model
=
BaseOffloadModule
(
gm
,
region_manager
,
solver_name
==
'
syn
'
)
optimized_model
=
BaseOffloadModule
(
gm
,
region_manager
,
solver_name
==
"
syn
"
)
return
optimized_model
return
optimized_model
colossalai/auto_parallel/offload/region.py
View file @
9e768b59
...
@@ -55,13 +55,13 @@ class Region:
...
@@ -55,13 +55,13 @@ class Region:
Map the parameters in the region to a contiguous memory space.
Map the parameters in the region to a contiguous memory space.
"""
"""
self
.
fp16_data
=
torch
.
zeros
(
self
.
param_num
,
dtype
=
torch
.
half
,
device
=
'
cuda
'
)
self
.
fp16_data
=
torch
.
zeros
(
self
.
param_num
,
dtype
=
torch
.
half
,
device
=
"
cuda
"
)
offset
=
0
offset
=
0
for
param
in
self
.
fp16_params
:
for
param
in
self
.
fp16_params
:
param
.
data
=
param
.
data
.
cuda
()
param
.
data
=
param
.
data
.
cuda
()
p_num
=
param
.
data
.
numel
()
p_num
=
param
.
data
.
numel
()
self
.
fp16_data
[
offset
:
offset
+
p_num
].
copy_
(
param
.
data
.
flatten
())
self
.
fp16_data
[
offset
:
offset
+
p_num
].
copy_
(
param
.
data
.
flatten
())
param
.
data
=
self
.
fp16_data
[
offset
:
offset
+
p_num
].
view
(
param
.
data
.
shape
)
param
.
data
=
self
.
fp16_data
[
offset
:
offset
+
p_num
].
view
(
param
.
data
.
shape
)
self
.
param_to_range
[
param
]
=
(
offset
,
offset
+
p_num
)
self
.
param_to_range
[
param
]
=
(
offset
,
offset
+
p_num
)
offset
+=
p_num
offset
+=
p_num
...
@@ -83,7 +83,7 @@ class Region:
...
@@ -83,7 +83,7 @@ class Region:
self
.
temp_fp32_data
.
record_stream
(
torch
.
cuda
.
current_stream
())
self
.
temp_fp32_data
.
record_stream
(
torch
.
cuda
.
current_stream
())
if
not
self
.
in_mem_pool_flag
:
if
not
self
.
in_mem_pool_flag
:
alloc_storage
(
self
.
fp16_data
)
alloc_storage
(
self
.
fp16_data
)
self
.
fp16_data
[:
self
.
param_num
].
copy_
(
self
.
temp_fp32_data
)
self
.
fp16_data
[:
self
.
param_num
].
copy_
(
self
.
temp_fp32_data
)
self
.
fp16_data
.
record_stream
(
torch
.
cuda
.
current_stream
())
self
.
fp16_data
.
record_stream
(
torch
.
cuda
.
current_stream
())
self
.
__update_params_ptr
()
self
.
__update_params_ptr
()
...
@@ -94,7 +94,7 @@ class Region:
...
@@ -94,7 +94,7 @@ class Region:
"""
"""
self
.
cpu_grad
=
torch
.
empty
(
self
.
param_num
,
dtype
=
torch
.
half
,
pin_memory
=
True
)
self
.
cpu_grad
=
torch
.
empty
(
self
.
param_num
,
dtype
=
torch
.
half
,
pin_memory
=
True
)
self
.
cpu_grad
.
copy_
(
self
.
fp16_data
[:
self
.
param_num
],
non_blocking
=
True
)
self
.
cpu_grad
.
copy_
(
self
.
fp16_data
[:
self
.
param_num
],
non_blocking
=
True
)
self
.
fp16_data
.
record_stream
(
torch
.
cuda
.
current_stream
())
self
.
fp16_data
.
record_stream
(
torch
.
cuda
.
current_stream
())
if
not
self
.
in_mem_pool_flag
:
if
not
self
.
in_mem_pool_flag
:
self
.
free_cuda_data
()
self
.
free_cuda_data
()
...
...
colossalai/auto_parallel/offload/region_manager.py
View file @
9e768b59
from
typing
import
List
,
Any
,
Dict
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Tuple
import
torch
import
torch
from
torch.fx
import
Graph
,
Node
from
torch.fx
import
Graph
,
Node
from
.region
import
Region
from
.solver
import
SolverFactory
from
.solver
import
SolverFactory
from
.training_simulator
import
TrainingSimulator
from
.training_simulator
import
TrainingSimulator
from
.region
import
Region
from
.util
import
NodeInfo
from
.util
import
NodeInfo
...
@@ -19,14 +20,9 @@ class RegionManager:
...
@@ -19,14 +20,9 @@ class RegionManager:
cnode (List[str], optional): Common node List, should be the subset of input.
cnode (List[str], optional): Common node List, should be the subset of input.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
graph
:
Graph
,
solver_name
:
str
=
"asyn"
,
memory_budget
:
float
=
-
1.0
,
cnode
:
List
[
str
]
=
None
):
graph
:
Graph
,
solver_name
:
str
=
'asyn'
,
memory_budget
:
float
=
-
1.0
,
cnode
:
List
[
str
]
=
None
):
self
.
graph
=
graph
self
.
graph
=
graph
assert
graph
.
owning_module
is
not
None
,
'
The given graph is not associated with a owning_module
'
assert
graph
.
owning_module
is
not
None
,
"
The given graph is not associated with a owning_module
"
self
.
root_module
=
self
.
graph
.
owning_module
self
.
root_module
=
self
.
graph
.
owning_module
self
.
nodes
=
list
(
graph
.
nodes
)
self
.
nodes
=
list
(
graph
.
nodes
)
self
.
cnode
=
cnode
self
.
cnode
=
cnode
...
@@ -39,7 +35,7 @@ class RegionManager:
...
@@ -39,7 +35,7 @@ class RegionManager:
self
.
memory_budget
=
memory_budget
self
.
memory_budget
=
memory_budget
self
.
solver_name
=
solver_name
self
.
solver_name
=
solver_name
self
.
require_pool
:
bool
=
solver_name
==
'
asyn
'
self
.
require_pool
:
bool
=
solver_name
==
"
asyn
"
self
.
reg_to_block
:
Dict
[
int
,
int
]
=
dict
()
self
.
reg_to_block
:
Dict
[
int
,
int
]
=
dict
()
...
@@ -61,22 +57,19 @@ class RegionManager:
...
@@ -61,22 +57,19 @@ class RegionManager:
self
.
_post_process
(
solver
.
best_ts
)
self
.
_post_process
(
solver
.
best_ts
)
def
_pre_process
(
self
):
def
_pre_process
(
self
):
init_region_list
=
self
.
_linearize_graph
()
init_region_list
=
self
.
_linearize_graph
()
if
len
(
self
.
shared_region_pairs
)
>
1
:
if
len
(
self
.
shared_region_pairs
)
>
1
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"The current version only considers at most one pair of parameter sharing."
)
'The current version only considers at most one pair of parameter sharing.'
)
elif
len
(
self
.
shared_region_pairs
)
==
1
:
elif
len
(
self
.
shared_region_pairs
)
==
1
:
shared_regs
=
self
.
shared_region_pairs
[
0
]
shared_regs
=
self
.
shared_region_pairs
[
0
]
assert
shared_regs
[
0
].
shared_rid
==
shared_regs
[
1
].
r_id
\
assert
shared_regs
[
0
].
shared_rid
==
shared_regs
[
1
].
r_id
and
shared_regs
[
1
].
shared_rid
==
shared_regs
[
0
].
r_id
and
shared_regs
[
1
].
shared_rid
==
shared_regs
[
0
].
r_id
fst_id
=
shared_regs
[
0
].
r_id
fst_id
=
shared_regs
[
0
].
r_id
lst_id
=
shared_regs
[
1
].
r_id
lst_id
=
shared_regs
[
1
].
r_id
regs_left_out
=
init_region_list
[:
fst_id
+
1
]
regs_left_out
=
init_region_list
[:
fst_id
+
1
]
regs_right_out
=
init_region_list
[
lst_id
:]
regs_right_out
=
init_region_list
[
lst_id
:]
hold_regs
=
init_region_list
[
fst_id
+
1
:
lst_id
]
hold_regs
=
init_region_list
[
fst_id
+
1
:
lst_id
]
else
:
else
:
regs_left_out
=
[]
regs_left_out
=
[]
regs_right_out
=
[]
regs_right_out
=
[]
...
@@ -122,12 +115,9 @@ class RegionManager:
...
@@ -122,12 +115,9 @@ class RegionManager:
it may not find a suitable region placement strategy for the given execution flow.
it may not find a suitable region placement strategy for the given execution flow.
"""
"""
reg_flow
=
torch
.
cat
(
reg_flow
=
torch
.
cat
([
ts
.
fwd_reg_flow
,
ts
.
bwd_reg_flow
],
dim
=
0
)
[
ts
.
fwd_reg_flow
,
ts
.
bwd_reg_flow
],
dim
=
0
)
mem_block_num
=
torch
.
max
(
torch
.
sum
(
reg_flow
[:,
self
.
rid_in_pool
],
dim
=
1
))
mem_block_num
=
torch
.
max
(
coexist_matrix
=
torch
.
logical_or
(
ts
.
fwd_reg_flow
,
ts
.
bwd_reg_flow
)
torch
.
sum
(
reg_flow
[:,
self
.
rid_in_pool
],
dim
=
1
))
coexist_matrix
=
torch
.
logical_or
(
ts
.
fwd_reg_flow
,
ts
.
bwd_reg_flow
)
block_to_regs
=
{}
block_to_regs
=
{}
for
block_idx
in
range
(
mem_block_num
):
for
block_idx
in
range
(
mem_block_num
):
...
@@ -135,8 +125,7 @@ class RegionManager:
...
@@ -135,8 +125,7 @@ class RegionManager:
for
reg
in
self
.
region_list
:
for
reg
in
self
.
region_list
:
if
reg
.
r_id
in
self
.
rid_in_pool
:
if
reg
.
r_id
in
self
.
rid_in_pool
:
cur_reg_appears
=
coexist_matrix
[:,
reg
.
r_id
]
cur_reg_appears
=
coexist_matrix
[:,
reg
.
r_id
]
cur_reg_coexists
=
torch
.
sum
(
cur_reg_coexists
=
torch
.
sum
(
coexist_matrix
[
cur_reg_appears
],
dim
=
0
).
bool
()
coexist_matrix
[
cur_reg_appears
],
dim
=
0
).
bool
()
for
block_idx
in
range
(
mem_block_num
):
for
block_idx
in
range
(
mem_block_num
):
if
not
any
(
cur_reg_coexists
[
block_to_regs
[
block_idx
]]):
if
not
any
(
cur_reg_coexists
[
block_to_regs
[
block_idx
]]):
block_to_regs
[
block_idx
].
append
(
reg
.
r_id
)
block_to_regs
[
block_idx
].
append
(
reg
.
r_id
)
...
@@ -145,9 +134,12 @@ class RegionManager:
...
@@ -145,9 +134,12 @@ class RegionManager:
if
reg
.
r_id
not
in
self
.
reg_to_block
:
if
reg
.
r_id
not
in
self
.
reg_to_block
:
raise
NotImplementedError
(
raise
NotImplementedError
(
f
'can not find a block from the memory pool to store parameters of the region'
)
f
"can not find a block from the memory pool to store parameters of the region"
self
.
memory_pool
=
torch
.
chunk
(
torch
.
zeros
(
int
(
)
mem_block_num
*
self
.
mem_block_size
/
2
),
dtype
=
torch
.
half
,
device
=
'cuda'
),
chunks
=
int
(
mem_block_num
))
self
.
memory_pool
=
torch
.
chunk
(
torch
.
zeros
(
int
(
mem_block_num
*
self
.
mem_block_size
/
2
),
dtype
=
torch
.
half
,
device
=
"cuda"
),
chunks
=
int
(
mem_block_num
),
)
def
_merge_small_regions
(
self
,
orig_reg_list
:
List
[
Region
])
->
List
[
Region
]:
def
_merge_small_regions
(
self
,
orig_reg_list
:
List
[
Region
])
->
List
[
Region
]:
"""
"""
...
@@ -178,10 +170,9 @@ class RegionManager:
...
@@ -178,10 +170,9 @@ class RegionManager:
return
region_list
return
region_list
def
_search_block_size
(
self
,
def
_search_block_size
(
region_list
:
List
[
Region
],
self
,
region_list
:
List
[
Region
],
search_interval_byte
:
int
=
1024
,
search_range_byte
:
int
=
128
*
1024
**
2
search_interval_byte
:
int
=
1024
,
)
->
int
:
search_range_byte
:
int
=
128
*
1024
**
2
)
->
int
:
"""
"""
Search for a suitable memory block size.
Search for a suitable memory block size.
...
@@ -208,11 +199,10 @@ class RegionManager:
...
@@ -208,11 +199,10 @@ class RegionManager:
acc_wasted
+=
blk_size
-
left
acc_wasted
+=
blk_size
-
left
return
acc_wasted
return
acc_wasted
param_size_list
=
[
param_size_list
=
[
region
.
param_size
for
region
in
region_list
if
region
.
r_id
==
region
.
shared_rid
]
region
.
param_size
for
region
in
region_list
if
region
.
r_id
==
region
.
shared_rid
]
start_size
=
max
(
param_size_list
)
start_size
=
max
(
param_size_list
)
min_mem_waste
=
float
(
'
+inf
'
)
min_mem_waste
=
float
(
"
+inf
"
)
best_block_size
=
start_size
best_block_size
=
start_size
for
block_size
in
range
(
start_size
,
start_size
+
search_range_byte
+
1
,
search_interval_byte
):
for
block_size
in
range
(
start_size
,
start_size
+
search_range_byte
+
1
,
search_interval_byte
):
...
@@ -229,7 +219,7 @@ class RegionManager:
...
@@ -229,7 +219,7 @@ class RegionManager:
Initialize region data, which maps the parameters in the region to a contiguous memory space.
Initialize region data, which maps the parameters in the region to a contiguous memory space.
"""
"""
self
.
temp_fp32_data
=
torch
.
zeros
(
self
.
max_param_num
,
device
=
'
cuda
'
,
dtype
=
torch
.
float32
)
self
.
temp_fp32_data
=
torch
.
zeros
(
self
.
max_param_num
,
device
=
"
cuda
"
,
dtype
=
torch
.
float32
)
for
region
in
self
.
region_list
:
for
region
in
self
.
region_list
:
pre_alloc_tensor
=
None
pre_alloc_tensor
=
None
...
@@ -244,8 +234,7 @@ class RegionManager:
...
@@ -244,8 +234,7 @@ class RegionManager:
region
.
fp16_data
=
shared_region
.
fp16_data
region
.
fp16_data
=
shared_region
.
fp16_data
region
.
fp32_data
=
shared_region
.
fp32_data
region
.
fp32_data
=
shared_region
.
fp32_data
region
.
param_to_range
=
shared_region
.
param_to_range
region
.
param_to_range
=
shared_region
.
param_to_range
region
.
temp_fp32_data
=
self
.
temp_fp32_data
[:
region
.
param_num
].
detach
(
region
.
temp_fp32_data
=
self
.
temp_fp32_data
[:
region
.
param_num
].
detach
()
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -259,13 +248,14 @@ class RegionManager:
...
@@ -259,13 +248,14 @@ class RegionManager:
former_reg
,
latter_reg
=
self
.
shared_region_pairs
[
0
]
former_reg
,
latter_reg
=
self
.
shared_region_pairs
[
0
]
assert
latter_reg
.
param_num
>=
former_reg
.
param_num
assert
latter_reg
.
param_num
>=
former_reg
.
param_num
embedding_node
=
former_reg
.
nodes
[
-
1
]
embedding_node
=
former_reg
.
nodes
[
-
1
]
assert
embedding_node
.
op
==
'call_module'
and
isinstance
(
assert
embedding_node
.
op
==
"call_module"
and
isinstance
(
self
.
root_module
.
get_submodule
(
embedding_node
.
target
),
torch
.
nn
.
Embedding
)
self
.
root_module
.
get_submodule
(
embedding_node
.
target
),
torch
.
nn
.
Embedding
)
if
latter_reg
.
param_num
>
former_reg
.
param_num
:
if
latter_reg
.
param_num
>
former_reg
.
param_num
:
for
idx
,
n
in
enumerate
(
latter_reg
.
nodes
):
for
idx
,
n
in
enumerate
(
latter_reg
.
nodes
):
if
(
n
.
op
==
'call_module'
and
isinstance
(
self
.
root_module
.
get_submodule
(
n
.
target
),
if
(
torch
.
nn
.
Linear
)
)
or
\
n
.
op
==
"call_module"
and
isinstance
(
self
.
root_module
.
get_submodule
(
n
.
target
),
torch
.
nn
.
Linear
)
(
n
.
op
==
'
call_function
'
and
n
.
target
is
torch
.
nn
.
functional
.
linear
):
)
or
(
n
.
op
==
"
call_function
"
and
n
.
target
is
torch
.
nn
.
functional
.
linear
):
cut_node_idx
=
idx
+
1
cut_node_idx
=
idx
+
1
break
break
assert
len
(
latter_reg
.
fp16_params
)
==
2
assert
len
(
latter_reg
.
fp16_params
)
==
2
...
@@ -273,7 +263,7 @@ class RegionManager:
...
@@ -273,7 +263,7 @@ class RegionManager:
for
p
in
new_reg
.
fp16_params
:
for
p
in
new_reg
.
fp16_params
:
self
.
param_region_map
[
p
]
=
new_reg
self
.
param_region_map
[
p
]
=
new_reg
self
.
region_list
.
insert
(
new_reg
.
r_id
,
new_reg
)
self
.
region_list
.
insert
(
new_reg
.
r_id
,
new_reg
)
for
reg
in
self
.
region_list
[
new_reg
.
r_id
+
1
:]:
for
reg
in
self
.
region_list
[
new_reg
.
r_id
+
1
:]:
reg
.
r_id
+=
1
reg
.
r_id
+=
1
latter_reg
.
shared_rid
=
former_reg
.
r_id
latter_reg
.
shared_rid
=
former_reg
.
r_id
former_reg
.
shared_rid
=
latter_reg
.
r_id
former_reg
.
shared_rid
=
latter_reg
.
r_id
...
@@ -344,8 +334,8 @@ class RegionManager:
...
@@ -344,8 +334,8 @@ class RegionManager:
target
=
n
.
target
target
=
n
.
target
submod
=
self
.
root_module
.
get_submodule
(
target
)
submod
=
self
.
root_module
.
get_submodule
(
target
)
if
(
if
(
len
(
list
(
submod
.
named_parameters
(
recurse
=
False
)))
!=
0
len
(
list
(
submod
.
named_parameters
(
recurse
=
False
)))
!=
0
or
len
(
list
(
submod
.
named_buffers
(
recurse
=
False
)))
!=
0
or
len
(
list
(
submod
.
named_buffers
(
recurse
=
False
)))
!=
0
):
):
label
=
True
label
=
True
...
@@ -362,14 +352,12 @@ class RegionManager:
...
@@ -362,14 +352,12 @@ class RegionManager:
"""
"""
def
_is_inplace
(
n
:
Node
):
def
_is_inplace
(
n
:
Node
):
"""Get the inplace argument from ``torch.fx.Node``
"""Get the inplace argument from ``torch.fx.Node``"""
"""
inplace
=
False
inplace
=
False
if
n
.
op
==
"call_function"
:
if
n
.
op
==
"call_function"
:
inplace
=
n
.
kwargs
.
get
(
"inplace"
,
False
)
inplace
=
n
.
kwargs
.
get
(
"inplace"
,
False
)
elif
n
.
op
==
"call_module"
:
elif
n
.
op
==
"call_module"
:
inplace
=
getattr
(
n
.
graph
.
owning_module
.
get_submodule
(
inplace
=
getattr
(
n
.
graph
.
owning_module
.
get_submodule
(
n
.
target
),
"inplace"
,
False
)
n
.
target
),
"inplace"
,
False
)
return
inplace
return
inplace
label
=
False
label
=
False
...
@@ -378,28 +366,30 @@ class RegionManager:
...
@@ -378,28 +366,30 @@ class RegionManager:
target
=
n
.
target
target
=
n
.
target
submod
=
self
.
root_module
.
get_submodule
(
target
)
submod
=
self
.
root_module
.
get_submodule
(
target
)
if
(
if
(
len
(
list
(
submod
.
named_parameters
(
recurse
=
False
)))
!=
0
len
(
list
(
submod
.
named_parameters
(
recurse
=
False
)))
!=
0
or
len
(
list
(
submod
.
named_buffers
(
recurse
=
False
)))
!=
0
or
len
(
list
(
submod
.
named_buffers
(
recurse
=
False
)))
!=
0
):
):
label
=
True
label
=
True
elif
n
.
op
==
"call_function"
:
elif
n
.
op
==
"call_function"
:
label
=
any
(
map
(
lambda
x
:
x
.
name
in
self
.
only_param_ops
,
n
.
all_input_nodes
))
and
any
(
label
=
any
(
map
(
lambda
x
:
x
.
name
in
self
.
only_param_ops
,
n
.
all_input_nodes
))
and
any
(
map
(
lambda
x
:
x
.
name
not
in
self
.
only_param_ops
and
not
_is_cop
(
n
.
target
),
n
.
all_input_nodes
))
map
(
lambda
x
:
x
.
name
not
in
self
.
only_param_ops
and
not
_is_cop
(
n
.
target
),
n
.
all_input_nodes
)
)
return
label
and
not
sum
([
v
for
_
,
v
in
param_op_deps
.
items
()])
and
not
any
(
map
(
_is_inplace
,
n
.
users
))
return
label
and
not
sum
([
v
for
_
,
v
in
param_op_deps
.
items
()])
and
not
any
(
map
(
_is_inplace
,
n
.
users
))
def
_exception_node_handling
():
def
_exception_node_handling
():
# TODO meta info prop bug
# TODO meta info prop bug
if
n
.
name
.
__contains__
(
"transpose"
)
and
n
.
meta
[
'
fwd_out
'
][
0
].
dim
()
<=
2
:
if
n
.
name
.
__contains__
(
"transpose"
)
and
n
.
meta
[
"
fwd_out
"
][
0
].
dim
()
<=
2
:
n
.
meta
[
'
fwd_out
'
]
=
[]
n
.
meta
[
"
fwd_out
"
]
=
[]
# make sure that item in cnode is valid
# make sure that item in cnode is valid
if
self
.
cnode
:
if
self
.
cnode
:
for
name
in
self
.
cnode
:
for
name
in
self
.
cnode
:
try
:
try
:
assert
next
(
node
for
node
in
self
.
graph
.
nodes
if
node
.
name
==
name
).
op
==
"placeholder"
,
\
assert
(
f
"Common node
{
name
}
is not an input of the model."
next
(
node
for
node
in
self
.
graph
.
nodes
if
node
.
name
==
name
).
op
==
"placeholder"
),
f
"Common node
{
name
}
is not an input of the model."
except
StopIteration
:
except
StopIteration
:
raise
ValueError
(
f
"Common node name
{
name
}
not in graph."
)
raise
ValueError
(
f
"Common node name
{
name
}
not in graph."
)
else
:
else
:
...
@@ -428,8 +418,8 @@ class RegionManager:
...
@@ -428,8 +418,8 @@ class RegionManager:
ns
=
[]
ns
=
[]
border_n_idx
=
region
.
nodes
.
index
(
act_n
)
border_n_idx
=
region
.
nodes
.
index
(
act_n
)
if
border_n_idx
<
len
(
region
.
nodes
):
if
border_n_idx
<
len
(
region
.
nodes
):
ns
=
region
.
nodes
[
border_n_idx
+
1
:]
ns
=
region
.
nodes
[
border_n_idx
+
1
:]
region
.
nodes
=
region
.
nodes
[:
border_n_idx
+
1
]
region
.
nodes
=
region
.
nodes
[:
border_n_idx
+
1
]
region_list
.
append
(
region
)
region_list
.
append
(
region
)
region_id
+=
1
region_id
+=
1
region
=
Region
(
r_id
=
region_id
)
region
=
Region
(
r_id
=
region_id
)
...
@@ -448,19 +438,21 @@ class RegionManager:
...
@@ -448,19 +438,21 @@ class RegionManager:
region
=
Region
(
r_id
=
region_id
)
region
=
Region
(
r_id
=
region_id
)
# propagate common node attr if possible
# propagate common node attr if possible
if
len
(
n
.
all_input_nodes
)
==
len
([
node
for
node
in
n
.
all_input_nodes
if
node
.
name
in
self
.
cnode
if
len
(
n
.
all_input_nodes
)
==
len
(
])
or
_is_cop
(
n
.
target
):
[
node
for
node
in
n
.
all_input_nodes
if
node
.
name
in
self
.
cnode
]
)
or
_is_cop
(
n
.
target
):
self
.
cnode
.
append
(
n
.
name
)
self
.
cnode
.
append
(
n
.
name
)
else
:
else
:
deps
[
n
]
=
len
(
deps
[
n
]
=
len
([
user
for
user
in
n
.
users
if
user
.
op
!=
"output"
])
[
user
for
user
in
n
.
users
if
user
.
op
!=
"output"
])
# propagate param node attr if possible
# propagate param node attr if possible
if
len
(
n
.
all_input_nodes
)
==
len
([
node
for
node
in
n
.
all_input_nodes
if
node
.
name
in
self
.
only_param_ops
if
(
])
or
n
.
op
==
"get_attr"
:
len
(
n
.
all_input_nodes
)
==
len
([
node
for
node
in
n
.
all_input_nodes
if
node
.
name
in
self
.
only_param_ops
])
or
n
.
op
==
"get_attr"
):
self
.
only_param_ops
.
append
(
n
.
name
)
self
.
only_param_ops
.
append
(
n
.
name
)
param_op_deps
[
n
]
=
len
(
param_op_deps
[
n
]
=
len
([
user
for
user
in
n
.
users
if
user
.
op
!=
"output"
])
[
user
for
user
in
n
.
users
if
user
.
op
!=
"output"
])
# record last activation node
# record last activation node
if
_is_act
(
n
.
_meta_data
):
if
_is_act
(
n
.
_meta_data
):
...
@@ -472,19 +464,16 @@ class RegionManager:
...
@@ -472,19 +464,16 @@ class RegionManager:
return
region_list
return
region_list
def
_set_node_and_region_info
(
self
,
node_id
:
int
,
cur_n
:
Node
,
cur_reg
:
Region
):
def
_set_node_and_region_info
(
self
,
node_id
:
int
,
cur_n
:
Node
,
cur_reg
:
Region
):
cur_n
.
node_info
=
NodeInfo
(
node_id
)
cur_n
.
node_info
=
NodeInfo
(
node_id
)
if
cur_n
.
op
==
'
call_module
'
:
if
cur_n
.
op
==
"
call_module
"
:
target
=
cur_n
.
target
target
=
cur_n
.
target
submod
=
self
.
root_module
.
get_submodule
(
target
)
submod
=
self
.
root_module
.
get_submodule
(
target
)
for
p
in
list
(
submod
.
parameters
(
recurse
=
False
)):
for
p
in
list
(
submod
.
parameters
(
recurse
=
False
)):
if
p
in
self
.
param_region_map
:
if
p
in
self
.
param_region_map
:
cur_reg
.
shared_rid
=
self
.
param_region_map
[
p
].
r_id
cur_reg
.
shared_rid
=
self
.
param_region_map
[
p
].
r_id
self
.
param_region_map
[
p
].
shared_rid
=
cur_reg
.
r_id
self
.
param_region_map
[
p
].
shared_rid
=
cur_reg
.
r_id
self
.
shared_region_pairs
.
append
(
self
.
shared_region_pairs
.
append
((
self
.
param_region_map
[
p
],
cur_reg
))
(
self
.
param_region_map
[
p
],
cur_reg
))
else
:
else
:
self
.
param_region_map
[
p
]
=
cur_reg
self
.
param_region_map
[
p
]
=
cur_reg
...
@@ -499,12 +488,10 @@ class RegionManager:
...
@@ -499,12 +488,10 @@ class RegionManager:
attr_itr
=
getattr
(
attr_itr
,
atom
)
attr_itr
=
getattr
(
attr_itr
,
atom
)
if
isinstance
(
attr_itr
,
torch
.
nn
.
Parameter
):
if
isinstance
(
attr_itr
,
torch
.
nn
.
Parameter
):
if
attr_itr
in
self
.
param_region_map
:
if
attr_itr
in
self
.
param_region_map
:
cur_reg
.
shared_rid
=
self
.
param_region_map
[
attr_itr
].
r_id
cur_reg
.
shared_rid
=
self
.
param_region_map
[
attr_itr
].
r_id
self
.
param_region_map
[
attr_itr
].
shared_rid
=
cur_reg
.
r_id
self
.
param_region_map
[
attr_itr
].
shared_rid
=
cur_reg
.
r_id
self
.
shared_region_pairs
.
append
(
self
.
shared_region_pairs
.
append
((
self
.
param_region_map
[
attr_itr
],
cur_reg
))
(
self
.
param_region_map
[
attr_itr
],
cur_reg
))
else
:
else
:
self
.
param_region_map
[
attr_itr
]
=
cur_reg
self
.
param_region_map
[
attr_itr
]
=
cur_reg
...
...
colossalai/auto_parallel/offload/runtime.py
View file @
9e768b59
...
@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
...
@@ -22,13 +22,13 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
input_
,
fwd_info
,
bwd_info
):
def
forward
(
ctx
,
input_
,
fwd_info
,
bwd_info
):
ctx
.
bwd_info
=
bwd_info
ctx
.
bwd_info
=
bwd_info
d2h_rid
=
fwd_info
.
get
(
'
d2h_rid
'
,
None
)
d2h_rid
=
fwd_info
.
get
(
"
d2h_rid
"
,
None
)
if
d2h_rid
is
not
None
:
if
d2h_rid
is
not
None
:
free_region
=
GlobalRuntimeInfo
().
region_list
[
d2h_rid
]
free_region
=
GlobalRuntimeInfo
().
region_list
[
d2h_rid
]
assert
isinstance
(
free_region
,
Region
)
assert
isinstance
(
free_region
,
Region
)
free_region
.
free_cuda_data
()
free_region
.
free_cuda_data
()
h2d_rid
=
fwd_info
.
get
(
'
h2d_rid
'
,
None
)
h2d_rid
=
fwd_info
.
get
(
"
h2d_rid
"
,
None
)
if
h2d_rid
is
not
None
:
if
h2d_rid
is
not
None
:
h2d_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
h2d_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
assert
isinstance
(
h2d_region
,
Region
)
assert
isinstance
(
h2d_region
,
Region
)
...
@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
...
@@ -38,8 +38,7 @@ class SynPreFwdPostBwdOP(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
h2d_rid
=
ctx
.
bwd_info
.
get
(
"h2d_rid"
,
None
)
h2d_rid
=
ctx
.
bwd_info
.
get
(
'h2d_rid'
,
None
)
if
h2d_rid
is
not
None
:
if
h2d_rid
is
not
None
:
pref_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
pref_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
assert
isinstance
(
pref_region
,
Region
)
assert
isinstance
(
pref_region
,
Region
)
...
@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
...
@@ -64,13 +63,13 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def
forward
(
ctx
,
input_
,
fwd_info
,
bwd_info
):
def
forward
(
ctx
,
input_
,
fwd_info
,
bwd_info
):
ctx
.
bwd_info
=
bwd_info
ctx
.
bwd_info
=
bwd_info
sync_rid
=
fwd_info
.
get
(
'
sync_rid
'
,
None
)
sync_rid
=
fwd_info
.
get
(
"
sync_rid
"
,
None
)
if
sync_rid
is
not
None
:
if
sync_rid
is
not
None
:
prefetch_event
=
GlobalRuntimeInfo
().
fwd_prefetch_event_map
.
get
(
sync_rid
,
None
)
prefetch_event
=
GlobalRuntimeInfo
().
fwd_prefetch_event_map
.
get
(
sync_rid
,
None
)
if
prefetch_event
:
if
prefetch_event
:
prefetch_event
.
wait
()
prefetch_event
.
wait
()
h2d_rid
=
fwd_info
.
get
(
'
h2d_rid
'
,
None
)
h2d_rid
=
fwd_info
.
get
(
"
h2d_rid
"
,
None
)
if
h2d_rid
is
not
None
:
if
h2d_rid
is
not
None
:
pref_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
pref_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
assert
isinstance
(
pref_region
,
Region
)
assert
isinstance
(
pref_region
,
Region
)
...
@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
...
@@ -87,8 +86,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_output
):
def
backward
(
ctx
,
grad_output
):
sync_rid
=
ctx
.
bwd_info
.
get
(
"sync_rid"
,
None
)
sync_rid
=
ctx
.
bwd_info
.
get
(
'sync_rid'
,
None
)
if
sync_rid
is
not
None
:
if
sync_rid
is
not
None
:
wait_region
=
GlobalRuntimeInfo
().
region_list
[
sync_rid
]
wait_region
=
GlobalRuntimeInfo
().
region_list
[
sync_rid
]
assert
isinstance
(
wait_region
,
Region
)
assert
isinstance
(
wait_region
,
Region
)
...
@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
...
@@ -98,7 +96,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
else
:
else
:
wait_region
.
move_param_to_cuda
()
wait_region
.
move_param_to_cuda
()
h2d_rid
=
ctx
.
bwd_info
.
get
(
'
h2d_rid
'
,
None
)
h2d_rid
=
ctx
.
bwd_info
.
get
(
"
h2d_rid
"
,
None
)
if
h2d_rid
is
not
None
:
if
h2d_rid
is
not
None
:
pref_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
pref_region
=
GlobalRuntimeInfo
().
region_list
[
h2d_rid
]
assert
isinstance
(
pref_region
,
Region
)
assert
isinstance
(
pref_region
,
Region
)
...
@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
...
@@ -114,7 +112,7 @@ class AsynPreFwdPostBwdOP(torch.autograd.Function):
def
convert_fwd_upload_bwd_offload_to_action
(
tensor
,
fwd_info
,
bwd_info
):
def
convert_fwd_upload_bwd_offload_to_action
(
tensor
,
fwd_info
,
bwd_info
):
'''
"""
Convert Upload and Offload operation into runtime action.
Convert Upload and Offload operation into runtime action.
Argument:
Argument:
...
@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
...
@@ -123,14 +121,14 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be uploaded, or freed during forward pass.
that need to be uploaded, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
bwd_info(dict): information dict, which contains region indices
that need to be uploaded during backward pass.
that need to be uploaded during backward pass.
'''
"""
with
torch
.
_C
.
DisableTorchFunction
():
with
torch
.
_C
.
DisableTorchFunction
():
ret
=
SynPreFwdPostBwdOP
.
apply
(
tensor
,
fwd_info
,
bwd_info
)
ret
=
SynPreFwdPostBwdOP
.
apply
(
tensor
,
fwd_info
,
bwd_info
)
return
ret
return
ret
def
convert_fwd_prefetch_bwd_offload_to_action
(
tensor
,
fwd_info
,
bwd_info
):
def
convert_fwd_prefetch_bwd_offload_to_action
(
tensor
,
fwd_info
,
bwd_info
):
'''
"""
Convert Prefetch and Offload operation into runtime action.
Convert Prefetch and Offload operation into runtime action.
Argument:
Argument:
...
@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
...
@@ -139,7 +137,7 @@ def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
that need to be prefetched, waited, or freed during forward pass.
that need to be prefetched, waited, or freed during forward pass.
bwd_info(dict): information dict, which contains region indices
bwd_info(dict): information dict, which contains region indices
that need to be prefetched or waited during backward pass.
that need to be prefetched or waited during backward pass.
'''
"""
with
torch
.
_C
.
DisableTorchFunction
():
with
torch
.
_C
.
DisableTorchFunction
():
ret
=
AsynPreFwdPostBwdOP
.
apply
(
tensor
,
fwd_info
,
bwd_info
)
ret
=
AsynPreFwdPostBwdOP
.
apply
(
tensor
,
fwd_info
,
bwd_info
)
return
ret
return
ret
...
@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
...
@@ -176,22 +174,22 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R
# forward upload
# forward upload
fwd_info
=
{}
fwd_info
=
{}
if
requires_upload_p_in_fwd
(
region_list
[
region
.
shared_rid
]):
if
requires_upload_p_in_fwd
(
region_list
[
region
.
shared_rid
]):
fwd_info
[
'
h2d_rid
'
]
=
region
.
r_id
fwd_info
[
"
h2d_rid
"
]
=
region
.
r_id
# forward offload
# forward offload
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
fwd_info
[
'
d2h_rid
'
]
=
r_idx
-
1
fwd_info
[
"
d2h_rid
"
]
=
r_idx
-
1
bwd_info
=
{}
bwd_info
=
{}
# backward upload
# backward upload
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
bwd_info
[
'
h2d_rid
'
]
=
region_list
[
r_idx
-
1
].
r_id
bwd_info
[
"
h2d_rid
"
]
=
region_list
[
r_idx
-
1
].
r_id
if
fwd_info
or
bwd_info
:
if
fwd_info
or
bwd_info
:
with
mod_graph
.
inserting_after
(
last_inp_node
):
with
mod_graph
.
inserting_after
(
last_inp_node
):
new_node
=
mod_graph
.
create_node
(
'call_function'
,
new_node
=
mod_graph
.
create_node
(
convert_fwd_upload_bwd_offload_to_action
,
"call_function"
,
convert_fwd_upload_bwd_offload_to_action
,
args
=
(
last_inp_node
,
fwd_info
,
bwd_info
)
args
=
(
last_inp_node
,
fwd_info
,
bwd_info
)
)
)
replace_node_users
(
last_inp_node
,
new_node
)
replace_node_users
(
last_inp_node
,
new_node
)
last_inp_node
=
region
.
nodes
[
-
1
]
last_inp_node
=
region
.
nodes
[
-
1
]
...
@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
...
@@ -210,9 +208,9 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
first_region_with_p
=
[
region
for
region
in
region_list
if
region
.
param_size
][
0
]
first_region_with_p
=
[
region
for
region
in
region_list
if
region
.
param_size
][
0
]
fwd_info
=
{
"h2d_rid"
:
first_region_with_p
.
r_id
}
fwd_info
=
{
"h2d_rid"
:
first_region_with_p
.
r_id
}
with
mod_graph
.
inserting_after
(
last_inp_node
):
with
mod_graph
.
inserting_after
(
last_inp_node
):
upload_apply_node
=
mod_graph
.
create_node
(
'call_function'
,
upload_apply_node
=
mod_graph
.
create_node
(
convert_fwd_upload_bwd_offload_to_action
,
"call_function"
,
convert_fwd_upload_bwd_offload_to_action
,
args
=
(
last_inp_node
,
fwd_info
,
{})
args
=
(
last_inp_node
,
fwd_info
,
{})
)
)
replace_node_users
(
last_inp_node
,
upload_apply_node
)
replace_node_users
(
last_inp_node
,
upload_apply_node
)
last_inp_node
=
upload_apply_node
last_inp_node
=
upload_apply_node
...
@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
...
@@ -220,37 +218,39 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[
# forward prefetch
# forward prefetch
fwd_info
=
{}
fwd_info
=
{}
if
region
.
param_size
:
if
region
.
param_size
:
fwd_info
[
'
sync_rid
'
]
=
region
.
r_id
fwd_info
[
"
sync_rid
"
]
=
region
.
r_id
fwd_prefetch_region
=
region
.
fwd_prefetch_region
fwd_prefetch_region
=
region
.
fwd_prefetch_region
if
fwd_prefetch_region
and
requires_upload_p_in_fwd
(
region_list
[
fwd_prefetch_region
.
shared_rid
]):
if
fwd_prefetch_region
and
requires_upload_p_in_fwd
(
region_list
[
fwd_prefetch_region
.
shared_rid
]):
fwd_info
[
'
h2d_rid
'
]
=
fwd_prefetch_region
.
r_id
fwd_info
[
"
h2d_rid
"
]
=
fwd_prefetch_region
.
r_id
# forward offload
# forward offload
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
fwd_info
[
'
d2h_rid
'
]
=
r_idx
-
1
fwd_info
[
"
d2h_rid
"
]
=
r_idx
-
1
bwd_info
=
{}
bwd_info
=
{}
# backward prefetch
# backward prefetch
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
need_offload
:
bwd_info
[
'
sync_rid
'
]
=
r_idx
-
1
bwd_info
[
"
sync_rid
"
]
=
r_idx
-
1
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
bwd_prefetch_region
:
if
r_idx
>
0
and
region_list
[
r_idx
-
1
].
bwd_prefetch_region
:
bwd_info
[
'
h2d_rid
'
]
=
region_list
[
r_idx
-
1
].
bwd_prefetch_region
.
r_id
bwd_info
[
"
h2d_rid
"
]
=
region_list
[
r_idx
-
1
].
bwd_prefetch_region
.
r_id
if
fwd_info
or
bwd_info
:
if
fwd_info
or
bwd_info
:
with
mod_graph
.
inserting_after
(
last_inp_node
):
with
mod_graph
.
inserting_after
(
last_inp_node
):
new_node
=
mod_graph
.
create_node
(
'call_function'
,
new_node
=
mod_graph
.
create_node
(
convert_fwd_prefetch_bwd_offload_to_action
,
"call_function"
,
args
=
(
last_inp_node
,
fwd_info
,
bwd_info
))
convert_fwd_prefetch_bwd_offload_to_action
,
args
=
(
last_inp_node
,
fwd_info
,
bwd_info
),
)
replace_node_users
(
last_inp_node
,
new_node
)
replace_node_users
(
last_inp_node
,
new_node
)
last_inp_node
=
region
.
nodes
[
-
1
]
last_inp_node
=
region
.
nodes
[
-
1
]
if
region
.
bwd_prefetch_region
:
if
region
.
bwd_prefetch_region
:
bwd_info
=
{
'
h2d_rid
'
:
region
.
bwd_prefetch_region
.
r_id
}
bwd_info
=
{
"
h2d_rid
"
:
region
.
bwd_prefetch_region
.
r_id
}
with
mod_graph
.
inserting_after
(
last_inp_node
):
with
mod_graph
.
inserting_after
(
last_inp_node
):
new_node
=
mod_graph
.
create_node
(
'call_function'
,
new_node
=
mod_graph
.
create_node
(
convert_fwd_prefetch_bwd_offload_to_action
,
"call_function"
,
convert_fwd_prefetch_bwd_offload_to_action
,
args
=
(
last_inp_node
,
{},
bwd_info
)
args
=
(
last_inp_node
,
{},
bwd_info
)
)
)
replace_node_users
(
last_inp_node
,
new_node
)
replace_node_users
(
last_inp_node
,
new_node
)
# gm.graph.print_tabular()
# gm.graph.print_tabular()
return
gm
return
gm
colossalai/auto_parallel/offload/solver.py
View file @
9e768b59
import
time
import
time
from
typing
import
List
,
Dict
,
Type
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Type
NOT_NVML
=
False
NOT_NVML
=
False
try
:
try
:
...
@@ -10,10 +10,11 @@ except:
...
@@ -10,10 +10,11 @@ except:
import
torch
import
torch
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
.training_simulator
import
TrainingSimulator
,
SynTrainingSimulator
,
AsynTrainingSimulator
from
.region
import
Region
from
.region
import
Region
from
.training_simulator
import
AsynTrainingSimulator
,
SynTrainingSimulator
,
TrainingSimulator
from
.util
import
NodeInfo
,
NvDevicePower
from
.util
import
NodeInfo
,
NvDevicePower
...
@@ -49,19 +50,14 @@ class Solver(ABC):
...
@@ -49,19 +50,14 @@ class Solver(ABC):
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
error_factor
:
float
=
0.95
)
->
None
:
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
error_factor
:
float
=
0.95
)
->
None
:
self
.
region_list
=
region_list
self
.
region_list
=
region_list
self
.
error_factor
:
float
=
error_factor
self
.
error_factor
:
float
=
error_factor
if
memory_budget
>
0
:
if
memory_budget
>
0
:
self
.
memory_budget
=
memory_budget
*
self
.
error_factor
self
.
memory_budget
=
memory_budget
*
self
.
error_factor
else
:
else
:
self
.
memory_budget
=
torch
.
cuda
.
get_device_properties
(
self
.
memory_budget
=
torch
.
cuda
.
get_device_properties
(
get_current_device
()).
total_memory
*
self
.
error_factor
get_current_device
()).
total_memory
*
self
.
error_factor
self
.
link_to_bandwidth
:
Dict
[
str
,
Dict
[
float
,
float
]]
=
self
.
_profile_bandwidth
()
self
.
link_to_bandwidth
:
Dict
[
str
,
Dict
[
float
,
float
]]
=
self
.
_profile_bandwidth
()
self
.
comp_power
:
float
=
self
.
_extract_computing_power
()
self
.
comp_power
:
float
=
self
.
_extract_computing_power
()
...
@@ -94,7 +90,7 @@ class Solver(ABC):
...
@@ -94,7 +90,7 @@ class Solver(ABC):
if
extra_cost
==
0
:
if
extra_cost
==
0
:
# means data transfer overhead can be completely overlapped
# means data transfer overhead can be completely overlapped
return
(
float
(
'
inf
'
),
total_mem_saving
,
peak_mem_saving
)
return
(
float
(
"
inf
"
),
total_mem_saving
,
peak_mem_saving
)
return
(
total_mem_saving
/
extra_cost
,
total_mem_saving
,
peak_mem_saving
)
return
(
total_mem_saving
/
extra_cost
,
total_mem_saving
,
peak_mem_saving
)
def
_compare_profit
(
self
,
profit_a
:
tuple
,
profit_b
:
tuple
)
->
bool
:
def
_compare_profit
(
self
,
profit_a
:
tuple
,
profit_b
:
tuple
)
->
bool
:
...
@@ -122,9 +118,7 @@ class Solver(ABC):
...
@@ -122,9 +118,7 @@ class Solver(ABC):
self
.
best_ts
=
best_ts
self
.
best_ts
=
best_ts
self
.
_update_node_mem_info
(
best_ts
.
fwd_node_mem
,
best_ts
.
bwd_node_mem
)
self
.
_update_node_mem_info
(
best_ts
.
fwd_node_mem
,
best_ts
.
bwd_node_mem
)
def
_update_node_mem_info
(
self
,
def
_update_node_mem_info
(
self
,
fwd_mem_info
:
Dict
[
Node
,
float
],
bwd_mem_info
:
Dict
[
Node
,
float
]):
fwd_mem_info
:
Dict
[
Node
,
float
],
bwd_mem_info
:
Dict
[
Node
,
float
]):
"""
"""
Update the runtime memory information of the node.
Update the runtime memory information of the node.
...
@@ -134,12 +128,10 @@ class Solver(ABC):
...
@@ -134,12 +128,10 @@ class Solver(ABC):
"""
"""
for
node
,
mem
in
fwd_mem_info
.
items
():
for
node
,
mem
in
fwd_mem_info
.
items
():
assert
hasattr
(
node
,
'node_info'
)
and
isinstance
(
assert
hasattr
(
node
,
"node_info"
)
and
isinstance
(
node
.
node_info
,
NodeInfo
)
node
.
node_info
,
NodeInfo
)
node
.
node_info
.
runtime_fwd_mem
=
mem
node
.
node_info
.
runtime_fwd_mem
=
mem
for
node
,
mem
in
bwd_mem_info
.
items
():
for
node
,
mem
in
bwd_mem_info
.
items
():
assert
hasattr
(
node
,
'node_info'
)
and
isinstance
(
assert
hasattr
(
node
,
"node_info"
)
and
isinstance
(
node
.
node_info
,
NodeInfo
)
node
.
node_info
,
NodeInfo
)
node
.
node_info
.
runtime_bwd_mem
=
mem
node
.
node_info
.
runtime_bwd_mem
=
mem
def
_extract_computing_power
(
self
):
def
_extract_computing_power
(
self
):
...
@@ -159,12 +151,12 @@ class Solver(ABC):
...
@@ -159,12 +151,12 @@ class Solver(ABC):
return
NvDevicePower
.
RTX3080_FP16
*
units
return
NvDevicePower
.
RTX3080_FP16
*
units
elif
device_name
.
__contains__
(
"RTX 3090"
):
elif
device_name
.
__contains__
(
"RTX 3090"
):
return
NvDevicePower
.
RTX3090_FP16
*
units
return
NvDevicePower
.
RTX3090_FP16
*
units
elif
device_name
.
__contains__
(
'
V100
'
):
elif
device_name
.
__contains__
(
"
V100
"
):
return
NvDevicePower
.
V100_FP16
*
units
return
NvDevicePower
.
V100_FP16
*
units
elif
device_name
.
__contains__
(
"A100"
):
elif
device_name
.
__contains__
(
"A100"
):
return
NvDevicePower
.
A100_FP16
*
units
return
NvDevicePower
.
A100_FP16
*
units
else
:
else
:
raise
TypeError
(
f
'
Unknown NVIDIA GPU device name
{
device_name
}
'
)
raise
TypeError
(
f
"
Unknown NVIDIA GPU device name
{
device_name
}
"
)
def
_profile_bandwidth
(
self
):
def
_profile_bandwidth
(
self
):
"""
"""
...
@@ -172,9 +164,9 @@ class Solver(ABC):
...
@@ -172,9 +164,9 @@ class Solver(ABC):
using data volumes ranging from 1KB to 1GB.
using data volumes ranging from 1KB to 1GB.
"""
"""
print
(
'
profiling bandwidth ......
'
)
print
(
"
profiling bandwidth ......
"
)
link_to_bandwidth
=
{}
link_to_bandwidth
=
{}
links
=
[
'
h2d
'
,
'
d2h
'
]
links
=
[
"
h2d
"
,
"
d2h
"
]
for
link
in
links
:
for
link
in
links
:
t_size
=
1024
t_size
=
1024
...
@@ -182,24 +174,22 @@ class Solver(ABC):
...
@@ -182,24 +174,22 @@ class Solver(ABC):
# from 1KB to 1GB
# from 1KB to 1GB
for
i
in
range
(
21
):
for
i
in
range
(
21
):
if
link
==
'h2d'
:
if
link
==
"h2d"
:
src_tensor
=
torch
.
ones
(
src_tensor
=
torch
.
ones
(
int
(
t_size
),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
int
(
t_size
),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
dst_tensor
=
torch
.
ones
((
int
(
t_size
)),
dtype
=
torch
.
int8
,
device
=
"cuda"
)
dst_tensor
=
torch
.
ones
(
elif
link
==
"d2h"
:
(
int
(
t_size
)),
dtype
=
torch
.
int8
,
device
=
'cuda'
)
src_tensor
=
torch
.
ones
(
int
(
t_size
),
dtype
=
torch
.
int8
,
device
=
"cuda"
)
elif
link
==
'd2h'
:
dst_tensor
=
torch
.
ones
((
int
(
t_size
)),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
src_tensor
=
torch
.
ones
(
int
(
t_size
),
dtype
=
torch
.
int8
,
device
=
'cuda'
)
dst_tensor
=
torch
.
ones
(
(
int
(
t_size
)),
dtype
=
torch
.
int8
,
pin_memory
=
True
)
def
func
():
def
func
():
dst_tensor
.
copy_
(
src_tensor
)
dst_tensor
.
copy_
(
src_tensor
)
size_to_bandwidth
[
t_size
]
=
t_size
/
benchmark_func
(
func
,
number
=
5
,
repeat
=
3
)
size_to_bandwidth
[
t_size
]
=
t_size
/
benchmark_func
(
func
,
number
=
5
,
repeat
=
3
)
print
(
f
'size:
{
t_size
/
1024
**
2
:.
3
f
}
MB, '
print
(
f
'
{
src_tensor
.
device
.
type
}
-to-
{
dst_tensor
.
device
.
type
}
'
f
"size:
{
t_size
/
1024
**
2
:.
3
f
}
MB, "
f
'bandwidth:
{
size_to_bandwidth
[
t_size
]
/
1024
**
3
:.
3
f
}
GB/s'
)
f
"
{
src_tensor
.
device
.
type
}
-to-
{
dst_tensor
.
device
.
type
}
"
f
"bandwidth:
{
size_to_bandwidth
[
t_size
]
/
1024
**
3
:.
3
f
}
GB/s"
)
t_size
*=
2
t_size
*=
2
...
@@ -208,10 +198,7 @@ class Solver(ABC):
...
@@ -208,10 +198,7 @@ class Solver(ABC):
class
SynGreedySolver
(
Solver
):
class
SynGreedySolver
(
Solver
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
)
->
None
:
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
)
->
None
:
super
().
__init__
(
region_list
,
memory_budget
)
super
().
__init__
(
region_list
,
memory_budget
)
self
.
best_ts
:
SynTrainingSimulator
=
None
self
.
best_ts
:
SynTrainingSimulator
=
None
...
@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
...
@@ -258,7 +245,8 @@ class SynGreedySolver(Solver):
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"can't find the offload strategy met the memory budget
{
self
.
memory_budget
/
1024
**
2
}
MB, "
f
"can't find the offload strategy met the memory budget
{
self
.
memory_budget
/
1024
**
2
}
MB, "
f
"it needs
{
self
.
best_ts
.
peak_mem
/
1024
**
2
:.
3
f
}
MB at least!"
)
f
"it needs
{
self
.
best_ts
.
peak_mem
/
1024
**
2
:.
3
f
}
MB at least!"
)
def
_call_solver_l2l
(
self
):
def
_call_solver_l2l
(
self
):
"""
"""
...
@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
...
@@ -270,7 +258,6 @@ class SynGreedySolver(Solver):
region
.
is_syn
=
True
region
.
is_syn
=
True
def
_try_to_offload
(
self
,
offload_region
:
Region
):
def
_try_to_offload
(
self
,
offload_region
:
Region
):
# record previous information
# record previous information
orig_need_offload
=
offload_region
.
need_offload
orig_need_offload
=
offload_region
.
need_offload
assert
not
orig_need_offload
assert
not
orig_need_offload
...
@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
...
@@ -297,23 +284,17 @@ class SynGreedySolver(Solver):
ts
=
SynTrainingSimulator
(
self
.
region_list
,
self
.
comp_power
,
self
.
link_to_bandwidth
)
ts
=
SynTrainingSimulator
(
self
.
region_list
,
self
.
comp_power
,
self
.
link_to_bandwidth
)
ts
.
execute
()
ts
.
execute
()
extra_comm_cost
=
2.0
*
\
extra_comm_cost
=
2.0
*
ts
.
_get_communication_overhead
(
"h2d"
,
offload_region
.
param_size
)
ts
.
_get_communication_overhead
(
'h2d'
,
offload_region
.
param_size
)
# the shared region needs to be moved twice
# the shared region needs to be moved twice
if
offload_region
.
r_id
<
offload_region
.
shared_rid
:
if
offload_region
.
r_id
<
offload_region
.
shared_rid
:
extra_comm_cost
*=
2.0
extra_comm_cost
*=
2.0
profit
=
self
.
_compute_offload_profit
(
profit
=
self
.
_compute_offload_profit
(
ts
.
total_mem_saving
,
self
.
best_ts
.
peak_mem
-
ts
.
peak_mem
,
extra_comm_cost
)
ts
.
total_mem_saving
,
self
.
best_ts
.
peak_mem
-
ts
.
peak_mem
,
extra_comm_cost
)
return
ts
,
profit
return
ts
,
profit
class
AsynGreedySolver
(
Solver
):
class
AsynGreedySolver
(
Solver
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
search_window_size
:
int
=
3
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
memory_budget
:
float
=
-
1.0
,
search_window_size
:
int
=
3
):
super
().
__init__
(
region_list
,
memory_budget
)
super
().
__init__
(
region_list
,
memory_budget
)
self
.
search_window_size
=
search_window_size
self
.
search_window_size
=
search_window_size
...
@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
...
@@ -331,7 +312,7 @@ class AsynGreedySolver(Solver):
ts
=
AsynTrainingSimulator
(
self
.
region_list
,
self
.
comp_power
,
self
.
link_to_bandwidth
)
ts
=
AsynTrainingSimulator
(
self
.
region_list
,
self
.
comp_power
,
self
.
link_to_bandwidth
)
ts
.
execute
()
ts
.
execute
()
self
.
_update_state
(
ts
)
self
.
_update_state
(
ts
)
print
(
"init peak memory"
,
self
.
best_ts
.
peak_mem
/
1024
**
2
,
"MB"
)
print
(
"init peak memory"
,
self
.
best_ts
.
peak_mem
/
1024
**
2
,
"MB"
)
def
_call_solver
(
self
):
def
_call_solver
(
self
):
"""
"""
...
@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
...
@@ -358,18 +339,17 @@ class AsynGreedySolver(Solver):
best_pref_ts
=
None
best_pref_ts
=
None
# search when to prefetch the region offloaded
# search when to prefetch the region offloaded
for
host_region
in
self
.
region_list
[
region
.
r_id
+
1
:
region
.
r_id
+
1
+
self
.
search_window_size
]:
for
host_region
in
self
.
region_list
[
region
.
r_id
+
1
:
region
.
r_id
+
1
+
self
.
search_window_size
]:
if
host_region
.
bwd_prefetch_region
is
not
None
:
if
host_region
.
bwd_prefetch_region
is
not
None
:
continue
continue
temp_ts
,
profit
=
self
.
_try_to_offload
(
temp_ts
,
profit
=
self
.
_try_to_offload
(
host_region
,
region
)
host_region
,
region
)
if
self
.
_compare_profit
(
profit
,
max_prefetch_profit
):
if
self
.
_compare_profit
(
profit
,
max_prefetch_profit
):
region_to_region_map
[
region
.
r_id
]
=
host_region
region_to_region_map
[
region
.
r_id
]
=
host_region
max_prefetch_profit
=
profit
max_prefetch_profit
=
profit
best_pref_ts
=
temp_ts
best_pref_ts
=
temp_ts
if
profit
[
0
]
==
float
(
'
inf
'
):
if
profit
[
0
]
==
float
(
"
inf
"
):
break
break
if
self
.
_compare_profit
(
max_prefetch_profit
,
max_offload_profit
):
if
self
.
_compare_profit
(
max_prefetch_profit
,
max_offload_profit
):
...
@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
...
@@ -392,7 +372,8 @@ class AsynGreedySolver(Solver):
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"can't find the offload strategy met the memory budget
{
self
.
memory_budget
/
1024
**
2
}
MB, "
f
"can't find the offload strategy met the memory budget
{
self
.
memory_budget
/
1024
**
2
}
MB, "
f
"it needs
{
self
.
best_ts
.
peak_mem
/
1024
**
2
:.
3
f
}
MB at least!"
)
f
"it needs
{
self
.
best_ts
.
peak_mem
/
1024
**
2
:.
3
f
}
MB at least!"
)
region_to_region_map
.
clear
()
region_to_region_map
.
clear
()
...
@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
...
@@ -452,7 +433,6 @@ class AsynGreedySolver(Solver):
peak_mem_saving
=
0
peak_mem_saving
=
0
while
len
(
self
.
region_to_region_map
)
and
peak_mem_saving
<=
0
:
while
len
(
self
.
region_to_region_map
)
and
peak_mem_saving
<=
0
:
max_profit
=
(
0
,)
max_profit
=
(
0
,)
best_ts
=
None
best_ts
=
None
undo_host_region
=
None
undo_host_region
=
None
...
@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
...
@@ -464,8 +444,7 @@ class AsynGreedySolver(Solver):
assert
offload_region
.
need_offload
assert
offload_region
.
need_offload
assert
not
offload_region
.
is_syn
assert
not
offload_region
.
is_syn
ts
,
profit
=
self
.
_try_convert_to_syn_upload
(
host_region
,
ts
,
profit
=
self
.
_try_convert_to_syn_upload
(
host_region
,
offload_region
)
offload_region
)
if
self
.
_compare_profit
(
profit
,
max_profit
):
if
self
.
_compare_profit
(
profit
,
max_profit
):
undo_host_region
=
host_region
undo_host_region
=
host_region
...
@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
...
@@ -474,7 +453,7 @@ class AsynGreedySolver(Solver):
best_ts
=
ts
best_ts
=
ts
if
best_ts
is
None
:
if
best_ts
is
None
:
raise
NotImplementedError
(
'
repair error!
'
)
raise
NotImplementedError
(
"
repair error!
"
)
assert
not
undo_offload_region
.
is_syn
assert
not
undo_offload_region
.
is_syn
undo_offload_region
.
is_syn
=
True
undo_offload_region
.
is_syn
=
True
...
@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
...
@@ -500,17 +479,13 @@ class AsynGreedySolver(Solver):
ts
.
execute
()
ts
.
execute
()
extra_comm_cost
=
max
(
ts
.
iter_end_time
-
self
.
best_ts
.
iter_end_time
,
0
)
extra_comm_cost
=
max
(
ts
.
iter_end_time
-
self
.
best_ts
.
iter_end_time
,
0
)
profit
=
self
.
_compute_offload_profit
(
profit
=
self
.
_compute_offload_profit
(
ts
.
total_mem_saving
,
self
.
best_ts
.
peak_mem
-
ts
.
peak_mem
,
extra_comm_cost
)
ts
.
total_mem_saving
,
self
.
best_ts
.
peak_mem
-
ts
.
peak_mem
,
extra_comm_cost
)
return
ts
,
profit
return
ts
,
profit
class
SolverFactory
:
class
SolverFactory
:
solvers
:
Dict
[
str
,
Type
[
Solver
]]
=
{
solvers
:
Dict
[
str
,
Type
[
Solver
]]
=
{
"syn"
:
SynGreedySolver
,
"asyn"
:
AsynGreedySolver
}
'syn'
:
SynGreedySolver
,
'asyn'
:
AsynGreedySolver
}
@
staticmethod
@
staticmethod
def
create
(
solver_name
:
str
)
->
Type
[
Solver
]:
def
create
(
solver_name
:
str
)
->
Type
[
Solver
]:
...
...
colossalai/auto_parallel/offload/training_simulator.py
View file @
9e768b59
import
bisect
import
bisect
from
typing
import
List
,
Dict
from
collections
import
OrderedDict
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
OrderedDict
from
typing
import
Dict
,
List
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
...
@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
...
@@ -26,10 +26,7 @@ class TrainingSimulator(ABC):
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
region_list
:
List
[
Region
],
comp_power
:
float
,
link_to_bw
:
Dict
[
str
,
Dict
[
float
,
float
]])
->
None
:
region_list
:
List
[
Region
],
comp_power
:
float
,
link_to_bw
:
Dict
[
str
,
Dict
[
float
,
float
]])
->
None
:
self
.
region_list
=
region_list
self
.
region_list
=
region_list
self
.
region_num
=
len
(
region_list
)
self
.
region_num
=
len
(
region_list
)
...
@@ -87,11 +84,7 @@ class TrainingSimulator(ABC):
...
@@ -87,11 +84,7 @@ class TrainingSimulator(ABC):
class
SynTrainingSimulator
(
TrainingSimulator
):
class
SynTrainingSimulator
(
TrainingSimulator
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
comp_power
:
float
,
link_to_bw
:
Dict
[
str
,
Dict
[
float
,
float
]])
->
None
:
def
__init__
(
self
,
region_list
:
List
[
Region
],
comp_power
:
float
,
link_to_bw
:
Dict
[
str
,
Dict
[
float
,
float
]])
->
None
:
super
().
__init__
(
region_list
,
comp_power
,
link_to_bw
)
super
().
__init__
(
region_list
,
comp_power
,
link_to_bw
)
def
execute
(
self
):
def
execute
(
self
):
...
@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator):
...
@@ -115,8 +108,7 @@ class SynTrainingSimulator(TrainingSimulator):
self
.
runtime_mem
+=
region
.
param_size
self
.
runtime_mem
+=
region
.
param_size
for
node
in
region
.
nodes
:
for
node
in
region
.
nodes
:
self
.
runtime_mem
+=
calculate_fwd_tmp
(
node
)
+
\
self
.
runtime_mem
+=
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
calculate_fwd_out
(
node
)
self
.
fwd_node_mem
[
node
]
=
self
.
runtime_mem
self
.
fwd_node_mem
[
node
]
=
self
.
runtime_mem
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
self
.
total_mem_saving
+=
node
.
node_info
.
runtime_fwd_mem
-
self
.
runtime_mem
self
.
total_mem_saving
+=
node
.
node_info
.
runtime_fwd_mem
-
self
.
runtime_mem
...
@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator):
...
@@ -141,18 +133,15 @@ class SynTrainingSimulator(TrainingSimulator):
self
.
runtime_mem
+=
region
.
param_size
self
.
runtime_mem
+=
region
.
param_size
for
node
in
region
.
nodes
.
__reversed__
():
for
node
in
region
.
nodes
.
__reversed__
():
self
.
runtime_mem
-=
calculate_fwd_out
(
node
)
self
.
runtime_mem
-=
calculate_fwd_out
(
node
)
self
.
runtime_mem
+=
node
.
meta
[
'bwd_mem_tmp'
]
+
\
self
.
runtime_mem
+=
node
.
meta
[
"bwd_mem_tmp"
]
+
node
.
meta
[
"bwd_mem_out"
]
node
.
meta
[
'bwd_mem_out'
]
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
# The memory savings of a node may be negative due to parameter prefetch.
# The memory savings of a node may be negative due to parameter prefetch.
self
.
total_mem_saving
+=
node
.
node_info
.
runtime_bwd_mem
-
self
.
runtime_mem
self
.
total_mem_saving
+=
node
.
node_info
.
runtime_bwd_mem
-
self
.
runtime_mem
self
.
bwd_node_mem
[
node
]
=
self
.
runtime_mem
self
.
bwd_node_mem
[
node
]
=
self
.
runtime_mem
self
.
runtime_mem
-=
(
node
.
meta
[
'bwd_mem_tmp'
]
+
self
.
runtime_mem
-=
node
.
meta
[
"bwd_mem_tmp"
]
+
calculate_fwd_tmp
(
node
)
calculate_fwd_tmp
(
node
))
# free bwd_mem_out
# free bwd_mem_out
self
.
bwd_node_deps
[
node
]
=
len
(
node
.
all_input_nodes
)
self
.
bwd_node_deps
[
node
]
=
len
(
node
.
all_input_nodes
)
...
@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator):
...
@@ -160,12 +149,14 @@ class SynTrainingSimulator(TrainingSimulator):
if
user_node
in
self
.
bwd_node_deps
:
if
user_node
in
self
.
bwd_node_deps
:
self
.
bwd_node_deps
[
user_node
]
-=
1
self
.
bwd_node_deps
[
user_node
]
-=
1
if
self
.
bwd_node_deps
[
user_node
]
<=
0
:
if
self
.
bwd_node_deps
[
user_node
]
<=
0
:
self
.
runtime_mem
-=
user_node
.
meta
[
'
bwd_mem_out
'
]
self
.
runtime_mem
-=
user_node
.
meta
[
"
bwd_mem_out
"
]
if
self
.
runtime_mem
<
0
:
if
self
.
runtime_mem
<
0
:
raise
ValueError
(
f
"region id:
{
region
.
r_id
}
, node name:
{
node
.
name
}
, "
raise
ValueError
(
f
"runtime_mem:
{
self
.
runtime_mem
/
1024
**
2
:.
3
f
}
MB ---"
f
"region id:
{
region
.
r_id
}
, node name:
{
node
.
name
}
, "
f
"runtime memory computed less than 0, which is miscalculated!"
)
f
"runtime_mem:
{
self
.
runtime_mem
/
1024
**
2
:.
3
f
}
MB ---"
f
"runtime memory computed less than 0, which is miscalculated!"
)
# release parameter and offload gradient in region
# release parameter and offload gradient in region
if
region
.
r_id
==
region
.
shared_rid
:
if
region
.
r_id
==
region
.
shared_rid
:
...
@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator):
...
@@ -177,23 +168,16 @@ class SynTrainingSimulator(TrainingSimulator):
class
AsynTrainingSimulator
(
TrainingSimulator
):
class
AsynTrainingSimulator
(
TrainingSimulator
):
def
__init__
(
self
,
region_list
:
List
[
Region
],
comp_power
:
float
,
link_to_bw
:
Dict
[
str
,
Dict
[
float
,
float
]])
->
None
:
def
__init__
(
self
,
region_list
:
List
[
Region
],
comp_power
:
float
,
link_to_bw
:
Dict
[
str
,
Dict
[
float
,
float
]])
->
None
:
super
().
__init__
(
region_list
,
comp_power
,
link_to_bw
)
super
().
__init__
(
region_list
,
comp_power
,
link_to_bw
)
self
.
iter_end_time
:
int
=
0
self
.
iter_end_time
:
int
=
0
# the last computation execution period
# the last computation execution period
self
.
last_comp
:
ExecutionPeriod
=
ExecutionPeriod
(
self
.
last_comp
:
ExecutionPeriod
=
ExecutionPeriod
(
start_time
=
0
,
end_time
=
0
)
start_time
=
0
,
end_time
=
0
)
# the last parameter prefetch execution period
# the last parameter prefetch execution period
self
.
last_h2d
:
ExecutionPeriod
=
ExecutionPeriod
(
self
.
last_h2d
:
ExecutionPeriod
=
ExecutionPeriod
(
start_time
=
0
,
end_time
=
0
)
start_time
=
0
,
end_time
=
0
)
# the last gradient offload execution period
# the last gradient offload execution period
self
.
last_d2h
:
ExecutionPeriod
=
ExecutionPeriod
(
self
.
last_d2h
:
ExecutionPeriod
=
ExecutionPeriod
(
start_time
=
0
,
end_time
=
0
)
start_time
=
0
,
end_time
=
0
)
# the forward computation execution period of the region
# the forward computation execution period of the region
self
.
fwd_reg_to_comp
:
OrderedDict
[
int
,
ExecutionPeriod
]
=
OrderedDict
()
self
.
fwd_reg_to_comp
:
OrderedDict
[
int
,
ExecutionPeriod
]
=
OrderedDict
()
# the forward parameter prefetch execution period of the region
# the forward parameter prefetch execution period of the region
...
@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -204,10 +188,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self
.
bwd_reg_to_pref
:
OrderedDict
[
int
,
ExecutionPeriod
]
=
OrderedDict
()
self
.
bwd_reg_to_pref
:
OrderedDict
[
int
,
ExecutionPeriod
]
=
OrderedDict
()
# the gradient offload execution period of the region
# the gradient offload execution period of the region
# which is divided into those that are waiting and those that have been released
# which is divided into those that are waiting and those that have been released
self
.
bwd_reg_to_offl_waiting
:
OrderedDict
[
int
,
self
.
bwd_reg_to_offl_waiting
:
OrderedDict
[
int
,
ExecutionPeriod
]
=
OrderedDict
()
ExecutionPeriod
]
=
OrderedDict
()
self
.
bwd_reg_to_offl_freed
:
OrderedDict
[
int
,
ExecutionPeriod
]
=
OrderedDict
()
self
.
bwd_reg_to_offl_freed
:
OrderedDict
[
int
,
ExecutionPeriod
]
=
OrderedDict
()
# the region buffer, which records regions that are offloaded but not released
# the region buffer, which records regions that are offloaded but not released
self
.
reg_buffer_to_free
:
List
[
int
]
=
[]
self
.
reg_buffer_to_free
:
List
[
int
]
=
[]
...
@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -217,10 +199,8 @@ class AsynTrainingSimulator(TrainingSimulator):
# the region execution flow,
# the region execution flow,
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
# when the execution reaches the i-th region.
# when the execution reaches the i-th region.
self
.
fwd_reg_flow
=
torch
.
zeros
(
self
.
fwd_reg_flow
=
torch
.
zeros
((
self
.
region_num
,
self
.
region_num
)).
bool
()
(
self
.
region_num
,
self
.
region_num
)).
bool
()
self
.
bwd_reg_flow
=
torch
.
zeros
((
self
.
region_num
,
self
.
region_num
)).
bool
()
self
.
bwd_reg_flow
=
torch
.
zeros
(
(
self
.
region_num
,
self
.
region_num
)).
bool
()
def
execute
(
self
):
def
execute
(
self
):
"""
"""
...
@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -232,7 +212,7 @@ class AsynTrainingSimulator(TrainingSimulator):
for
reg
in
self
.
region_list
:
for
reg
in
self
.
region_list
:
if
reg
.
param_size
and
reg
.
r_id
<
self
.
region_num
-
1
:
if
reg
.
param_size
and
reg
.
r_id
<
self
.
region_num
-
1
:
for
nr
in
self
.
region_list
[
reg
.
r_id
+
1
:]:
for
nr
in
self
.
region_list
[
reg
.
r_id
+
1
:]:
if
nr
.
param_size
and
requires_upload_p_in_fwd
(
self
.
region_list
[
nr
.
shared_rid
]):
if
nr
.
param_size
and
requires_upload_p_in_fwd
(
self
.
region_list
[
nr
.
shared_rid
]):
reg
.
fwd_prefetch_region
=
nr
reg
.
fwd_prefetch_region
=
nr
break
break
...
@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -249,8 +229,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self
.
runtime_mem
-=
self
.
region_list
[
reg_id
].
param_size
self
.
runtime_mem
-=
self
.
region_list
[
reg_id
].
param_size
self
.
bwd_reg_to_offl_waiting
.
clear
()
self
.
bwd_reg_to_offl_waiting
.
clear
()
self
.
iter_end_time
=
max
(
self
.
iter_end_time
=
max
(
self
.
last_comp
.
end_time
,
self
.
last_d2h
.
end_time
)
self
.
last_comp
.
end_time
,
self
.
last_d2h
.
end_time
)
def
_insert_h2d_exec
(
self
,
region
:
Region
,
is_fwd
:
bool
=
True
):
def
_insert_h2d_exec
(
self
,
region
:
Region
,
is_fwd
:
bool
=
True
):
"""
"""
...
@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -258,10 +237,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
"""
pref_start_time
=
max
(
self
.
last_h2d
.
end_time
,
self
.
last_comp
.
end_time
)
pref_start_time
=
max
(
self
.
last_h2d
.
end_time
,
self
.
last_comp
.
end_time
)
pref_end_time
=
pref_start_time
+
\
pref_end_time
=
pref_start_time
+
2.0
*
self
.
_get_communication_overhead
(
"h2d"
,
region
.
param_size
)
2.0
*
self
.
_get_communication_overhead
(
'h2d'
,
region
.
param_size
)
pref_ep
=
ExecutionPeriod
(
start_time
=
pref_start_time
,
end_time
=
pref_end_time
)
pref_ep
=
ExecutionPeriod
(
start_time
=
pref_start_time
,
end_time
=
pref_end_time
)
if
is_fwd
:
if
is_fwd
:
self
.
fwd_reg_to_pref
[
region
.
r_id
]
=
pref_ep
self
.
fwd_reg_to_pref
[
region
.
r_id
]
=
pref_ep
else
:
else
:
...
@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -276,18 +253,16 @@ class AsynTrainingSimulator(TrainingSimulator):
if
is_fwd
:
if
is_fwd
:
reg_to_comp
=
self
.
fwd_reg_to_comp
reg_to_comp
=
self
.
fwd_reg_to_comp
reg_to_pref
=
self
.
fwd_reg_to_pref
reg_to_pref
=
self
.
fwd_reg_to_pref
flop_key
=
'
fwd_flop
'
flop_key
=
"
fwd_flop
"
else
:
else
:
reg_to_comp
=
self
.
bwd_reg_to_comp
reg_to_comp
=
self
.
bwd_reg_to_comp
reg_to_pref
=
self
.
bwd_reg_to_pref
reg_to_pref
=
self
.
bwd_reg_to_pref
flop_key
=
'bwd_flop'
flop_key
=
"bwd_flop"
comp_start_time
=
max
(
self
.
last_comp
.
end_time
,
reg_to_pref
.
get
(
comp_start_time
=
max
(
self
.
last_comp
.
end_time
,
reg_to_pref
.
get
(
region
.
r_id
,
ExecutionPeriod
(
0
,
0
)).
end_time
)
region
.
r_id
,
ExecutionPeriod
(
0
,
0
)).
end_time
)
comp_end_time
=
comp_start_time
+
sum
(
comp_end_time
=
comp_start_time
+
\
[
self
.
_get_computing_overhead
(
node
.
meta
.
get
(
flop_key
,
0
))
for
node
in
region
.
nodes
]
sum
([
self
.
_get_computing_overhead
(
node
.
meta
.
get
(
flop_key
,
0
))
)
for
node
in
region
.
nodes
])
comp_ep
=
ExecutionPeriod
(
start_time
=
comp_start_time
,
end_time
=
comp_end_time
)
comp_ep
=
ExecutionPeriod
(
start_time
=
comp_start_time
,
end_time
=
comp_end_time
)
reg_to_comp
[
region
.
r_id
]
=
comp_ep
reg_to_comp
[
region
.
r_id
]
=
comp_ep
self
.
last_comp
=
comp_ep
self
.
last_comp
=
comp_ep
...
@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -297,10 +272,8 @@ class AsynTrainingSimulator(TrainingSimulator):
"""
"""
offl_start_time
=
max
(
self
.
last_d2h
.
end_time
,
self
.
last_comp
.
end_time
)
offl_start_time
=
max
(
self
.
last_d2h
.
end_time
,
self
.
last_comp
.
end_time
)
offl_end_time
=
offl_start_time
+
\
offl_end_time
=
offl_start_time
+
self
.
_get_communication_overhead
(
"d2h"
,
region
.
param_size
)
self
.
_get_communication_overhead
(
'd2h'
,
region
.
param_size
)
offl_ep
=
ExecutionPeriod
(
start_time
=
offl_start_time
,
end_time
=
offl_end_time
)
offl_ep
=
ExecutionPeriod
(
start_time
=
offl_start_time
,
end_time
=
offl_end_time
)
self
.
bwd_reg_to_offl_waiting
[
region
.
r_id
]
=
offl_ep
self
.
bwd_reg_to_offl_waiting
[
region
.
r_id
]
=
offl_ep
self
.
last_d2h
=
offl_ep
self
.
last_d2h
=
offl_ep
...
@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -332,20 +305,17 @@ class AsynTrainingSimulator(TrainingSimulator):
self
.
fwd_reg_flow
[
region
.
r_id
,
region
.
r_id
]
=
True
self
.
fwd_reg_flow
[
region
.
r_id
,
region
.
r_id
]
=
True
else
:
else
:
self
.
fwd_reg_flow
[
region
.
r_id
]
=
self
.
fwd_reg_flow
[
region
.
r_id
-
1
]
self
.
fwd_reg_flow
[
region
.
r_id
]
=
self
.
fwd_reg_flow
[
region
.
r_id
-
1
]
self
.
fwd_reg_flow
[
region
.
r_id
,
self
.
fwd_reg_flow
[
region
.
r_id
,
self
.
reg_buffer_to_free
]
=
False
self
.
reg_buffer_to_free
]
=
False
self
.
reg_buffer_to_free
.
clear
()
self
.
reg_buffer_to_free
.
clear
()
# prefetch parameters of the next region
# prefetch parameters of the next region
fwd_prefetch_region
=
region
.
fwd_prefetch_region
fwd_prefetch_region
=
region
.
fwd_prefetch_region
if
fwd_prefetch_region
and
requires_upload_p_in_fwd
(
self
.
region_list
[
fwd_prefetch_region
.
shared_rid
]):
if
fwd_prefetch_region
and
requires_upload_p_in_fwd
(
self
.
region_list
[
fwd_prefetch_region
.
shared_rid
]):
self
.
runtime_mem
+=
fwd_prefetch_region
.
param_size
self
.
runtime_mem
+=
fwd_prefetch_region
.
param_size
self
.
fwd_reg_flow
[
region
.
r_id
,
self
.
fwd_reg_flow
[
region
.
r_id
,
fwd_prefetch_region
.
r_id
]
=
True
fwd_prefetch_region
.
r_id
]
=
True
for
node
in
region
.
nodes
:
for
node
in
region
.
nodes
:
self
.
runtime_mem
+=
calculate_fwd_tmp
(
node
)
+
\
self
.
runtime_mem
+=
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
calculate_fwd_out
(
node
)
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
self
.
total_mem_saving
+=
node
.
node_info
.
runtime_fwd_mem
-
self
.
runtime_mem
self
.
total_mem_saving
+=
node
.
node_info
.
runtime_fwd_mem
-
self
.
runtime_mem
...
@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -354,8 +324,7 @@ class AsynTrainingSimulator(TrainingSimulator):
if
region
.
need_offload
:
if
region
.
need_offload
:
self
.
runtime_mem
-=
region
.
param_size
self
.
runtime_mem
-=
region
.
param_size
assert
len
(
assert
len
(
self
.
reg_buffer_to_free
)
<=
1
,
f
"
{
len
(
self
.
reg_buffer_to_free
)
}
"
self
.
reg_buffer_to_free
)
<=
1
,
f
'
{
len
(
self
.
reg_buffer_to_free
)
}
'
self
.
reg_buffer_to_free
.
append
(
region
.
r_id
)
self
.
reg_buffer_to_free
.
append
(
region
.
r_id
)
def
_eval_bwd_cost_per_region
(
self
,
region
:
Region
):
def
_eval_bwd_cost_per_region
(
self
,
region
:
Region
):
...
@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -398,8 +367,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self
.
bwd_reg_flow
[
region
.
r_id
]
=
self
.
bwd_reg_flow
[
region
.
r_id
+
1
]
self
.
bwd_reg_flow
[
region
.
r_id
]
=
self
.
bwd_reg_flow
[
region
.
r_id
+
1
]
else
:
else
:
self
.
bwd_reg_flow
[
region
.
r_id
]
=
self
.
fwd_reg_flow
[
-
1
]
self
.
bwd_reg_flow
[
region
.
r_id
]
=
self
.
fwd_reg_flow
[
-
1
]
self
.
bwd_reg_flow
[
region
.
r_id
,
self
.
bwd_reg_flow
[
region
.
r_id
,
self
.
reg_buffer_to_free
]
=
False
self
.
reg_buffer_to_free
]
=
False
# free gradients in the buffer
# free gradients in the buffer
while
len
(
self
.
reg_buffer_to_free
):
while
len
(
self
.
reg_buffer_to_free
):
...
@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -415,8 +383,7 @@ class AsynTrainingSimulator(TrainingSimulator):
bwd_prefetch_region
=
region
.
bwd_prefetch_region
bwd_prefetch_region
=
region
.
bwd_prefetch_region
if
bwd_prefetch_region
:
if
bwd_prefetch_region
:
self
.
runtime_mem
+=
bwd_prefetch_region
.
param_size
self
.
runtime_mem
+=
bwd_prefetch_region
.
param_size
self
.
bwd_reg_flow
[
region
.
r_id
,
self
.
bwd_reg_flow
[
region
.
r_id
,
bwd_prefetch_region
.
r_id
]
=
True
bwd_prefetch_region
.
r_id
]
=
True
# add the gradient of the parameter
# add the gradient of the parameter
if
region
.
r_id
<
region
.
shared_rid
:
if
region
.
r_id
<
region
.
shared_rid
:
...
@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -426,10 +393,8 @@ class AsynTrainingSimulator(TrainingSimulator):
self
.
runtime_mem
+=
region
.
param_size
self
.
runtime_mem
+=
region
.
param_size
for
node
in
region
.
nodes
.
__reversed__
():
for
node
in
region
.
nodes
.
__reversed__
():
self
.
runtime_mem
-=
calculate_fwd_out
(
node
)
self
.
runtime_mem
-=
calculate_fwd_out
(
node
)
self
.
runtime_mem
+=
node
.
meta
[
'bwd_mem_tmp'
]
+
\
self
.
runtime_mem
+=
node
.
meta
[
"bwd_mem_tmp"
]
+
node
.
meta
[
"bwd_mem_out"
]
node
.
meta
[
'bwd_mem_out'
]
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
self
.
peak_mem
=
max
(
self
.
runtime_mem
,
self
.
peak_mem
)
# The memory savings of a node may be negative due to parameter prefetch.
# The memory savings of a node may be negative due to parameter prefetch.
...
@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -437,8 +402,7 @@ class AsynTrainingSimulator(TrainingSimulator):
self
.
bwd_node_mem
[
node
]
=
self
.
runtime_mem
self
.
bwd_node_mem
[
node
]
=
self
.
runtime_mem
self
.
runtime_mem
-=
(
node
.
meta
[
'bwd_mem_tmp'
]
+
self
.
runtime_mem
-=
node
.
meta
[
"bwd_mem_tmp"
]
+
calculate_fwd_tmp
(
node
)
calculate_fwd_tmp
(
node
))
# free bwd_mem_out
# free bwd_mem_out
self
.
bwd_node_deps
[
node
]
=
len
(
node
.
all_input_nodes
)
self
.
bwd_node_deps
[
node
]
=
len
(
node
.
all_input_nodes
)
...
@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator):
...
@@ -446,12 +410,14 @@ class AsynTrainingSimulator(TrainingSimulator):
if
user_node
in
self
.
bwd_node_deps
:
if
user_node
in
self
.
bwd_node_deps
:
self
.
bwd_node_deps
[
user_node
]
-=
1
self
.
bwd_node_deps
[
user_node
]
-=
1
if
self
.
bwd_node_deps
[
user_node
]
<=
0
:
if
self
.
bwd_node_deps
[
user_node
]
<=
0
:
self
.
runtime_mem
-=
user_node
.
meta
[
'
bwd_mem_out
'
]
self
.
runtime_mem
-=
user_node
.
meta
[
"
bwd_mem_out
"
]
if
self
.
runtime_mem
<
0
:
if
self
.
runtime_mem
<
0
:
raise
ValueError
(
f
"region id:
{
region
.
r_id
}
, node name:
{
node
.
name
}
, "
raise
ValueError
(
f
"runtime_mem:
{
self
.
runtime_mem
/
1024
**
2
:.
3
f
}
MB ---"
f
"region id:
{
region
.
r_id
}
, node name:
{
node
.
name
}
, "
f
"runtime memory computed less than 0, which is miscalculated!"
)
f
"runtime_mem:
{
self
.
runtime_mem
/
1024
**
2
:.
3
f
}
MB ---"
f
"runtime memory computed less than 0, which is miscalculated!"
)
# release parameters of the region
# release parameters of the region
if
requires_release_p_in_bwd
(
self
.
region_list
[
region
.
shared_rid
]):
if
requires_release_p_in_bwd
(
self
.
region_list
[
region
.
shared_rid
]):
...
...
colossalai/auto_parallel/offload/util.py
View file @
9e768b59
...
@@ -35,7 +35,6 @@ class NvDevicePower:
...
@@ -35,7 +35,6 @@ class NvDevicePower:
class
GlobalRuntimeInfo
(
metaclass
=
SingletonMeta
):
class
GlobalRuntimeInfo
(
metaclass
=
SingletonMeta
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
h2d_stream
=
torch
.
cuda
.
Stream
()
self
.
h2d_stream
=
torch
.
cuda
.
Stream
()
self
.
d2h_stream
=
torch
.
cuda
.
Stream
()
self
.
d2h_stream
=
torch
.
cuda
.
Stream
()
...
@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
...
@@ -50,21 +49,18 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
# forward
# forward
for
region
in
region_list
:
for
region
in
region_list
:
for
node
in
region
.
nodes
:
for
node
in
region
.
nodes
:
runtime_mem
=
runtime_mem
+
\
runtime_mem
=
runtime_mem
+
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
calculate_fwd_tmp
(
node
)
+
calculate_fwd_out
(
node
)
act_peak_mem
=
max
(
runtime_mem
,
act_peak_mem
)
act_peak_mem
=
max
(
runtime_mem
,
act_peak_mem
)
# backward
# backward
bwd_deps
=
{}
bwd_deps
=
{}
for
region
in
region_list
.
__reversed__
():
for
region
in
region_list
.
__reversed__
():
for
node
in
region
.
nodes
.
__reversed__
():
for
node
in
region
.
nodes
.
__reversed__
():
runtime_mem
-=
calculate_fwd_out
(
node
)
runtime_mem
-=
calculate_fwd_out
(
node
)
runtime_mem
=
runtime_mem
+
\
runtime_mem
=
runtime_mem
+
node
.
meta
[
"bwd_mem_tmp"
]
+
node
.
meta
[
"bwd_mem_out"
]
node
.
meta
[
'bwd_mem_tmp'
]
+
node
.
meta
[
'bwd_mem_out'
]
act_peak_mem
=
max
(
runtime_mem
,
act_peak_mem
)
act_peak_mem
=
max
(
runtime_mem
,
act_peak_mem
)
runtime_mem
=
runtime_mem
-
\
runtime_mem
=
runtime_mem
-
node
.
meta
[
"bwd_mem_tmp"
]
-
calculate_fwd_tmp
(
node
)
node
.
meta
[
'bwd_mem_tmp'
]
-
calculate_fwd_tmp
(
node
)
# free bwd_mem_out
# free bwd_mem_out
bwd_deps
[
node
]
=
len
(
node
.
all_input_nodes
)
bwd_deps
[
node
]
=
len
(
node
.
all_input_nodes
)
...
@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
...
@@ -72,7 +68,7 @@ def compute_act_peak_mem(region_list: List[Region]) -> float:
if
user_node
in
bwd_deps
:
if
user_node
in
bwd_deps
:
bwd_deps
[
user_node
]
-=
1
bwd_deps
[
user_node
]
-=
1
if
bwd_deps
[
user_node
]
<=
0
:
if
bwd_deps
[
user_node
]
<=
0
:
runtime_mem
-=
user_node
.
meta
[
'
bwd_mem_out
'
]
runtime_mem
-=
user_node
.
meta
[
"
bwd_mem_out
"
]
return
act_peak_mem
return
act_peak_mem
...
@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
...
@@ -86,13 +82,15 @@ def compute_total_param_mem(region_list: List[Region]) -> float:
def
requires_upload_p_in_fwd
(
shared_reg
:
Region
):
def
requires_upload_p_in_fwd
(
shared_reg
:
Region
):
return
(
shared_reg
.
r_id
>=
shared_reg
.
shared_rid
)
or
(
shared_reg
.
r_id
<
shared_reg
.
shared_rid
return
(
shared_reg
.
r_id
>=
shared_reg
.
shared_rid
)
or
(
and
shared_reg
.
need_offload
)
shared_reg
.
r_id
<
shared_reg
.
shared_rid
and
shared_reg
.
need_offload
)
def
requires_release_p_in_bwd
(
shared_reg
:
Region
):
def
requires_release_p_in_bwd
(
shared_reg
:
Region
):
return
(
shared_reg
.
r_id
>=
shared_reg
.
shared_rid
)
or
(
shared_reg
.
r_id
<
shared_reg
.
shared_rid
return
(
shared_reg
.
r_id
>=
shared_reg
.
shared_rid
)
or
(
and
shared_reg
.
need_offload
)
shared_reg
.
r_id
<
shared_reg
.
shared_rid
and
shared_reg
.
need_offload
)
def
requires_offload_g_in_bwd
(
region
:
Region
):
def
requires_offload_g_in_bwd
(
region
:
Region
):
...
...
colossalai/auto_parallel/passes/comm_metainfo_pass.py
View file @
9e768b59
...
@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
...
@@ -14,18 +14,20 @@ from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager
=
ShapeConsistencyManager
()
shape_consistency_manager
=
ShapeConsistencyManager
()
def
_construct_shard_meta_info
(
node
:
Node
,
origin_sharding_spec
:
ShardingSpec
,
def
_construct_shard_meta_info
(
target_sharding_spec
:
ShardingSpec
)
->
ShardMetaInfo
:
node
:
Node
,
origin_sharding_spec
:
ShardingSpec
,
target_sharding_spec
:
ShardingSpec
)
->
ShardMetaInfo
:
# get comm_action_sequence and total_cost from shape_consistency_manager
# get comm_action_sequence and total_cost from shape_consistency_manager
_
,
comm_action_sequence
,
total_cost
=
shape_consistency_manager
.
shape_consistency
(
_
,
comm_action_sequence
,
total_cost
=
shape_consistency_manager
.
shape_consistency
(
origin_sharding_spec
,
target_sharding_spec
)
origin_sharding_spec
,
target_sharding_spec
)
meta_info
=
ShardMetaInfo
()
meta_info
=
ShardMetaInfo
()
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel
# get mem cost for ShardMetaInfo
# get mem cost for ShardMetaInfo
mem_cost
=
shape_consistency_manager
.
mem_cost
(
comm_action_sequence
)
mem_cost
=
shape_consistency_manager
.
mem_cost
(
comm_action_sequence
)
# extract user that has _meta_data and extract element length
# extract user that has _meta_data and extract element length
input_node
=
next
(
n
for
n
in
node
.
_input_nodes
if
hasattr
(
n
,
'
_meta_data
'
))
input_node
=
next
(
n
for
n
in
node
.
_input_nodes
if
hasattr
(
n
,
"
_meta_data
"
))
element_length
=
input_node
.
_meta_data
.
element_size
()
element_length
=
input_node
.
_meta_data
.
element_size
()
mem_cost
.
fwd
.
activation
*=
element_length
mem_cost
.
fwd
.
activation
*=
element_length
...
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
...
@@ -37,9 +39,11 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
meta_info
.
memory_cost
=
mem_cost
meta_info
.
memory_cost
=
mem_cost
# get computation cost for ShardMetaInfo
# get computation cost for ShardMetaInfo
meta_info
.
compute_cost
=
TrainCycleItem
(
total_cost
[
'forward'
]
*
element_length
,
meta_info
.
compute_cost
=
TrainCycleItem
(
total_cost
[
'backward'
]
*
element_length
,
total_cost
[
"forward"
]
*
element_length
,
total_cost
[
'total'
]
*
element_length
)
total_cost
[
"backward"
]
*
element_length
,
total_cost
[
"total"
]
*
element_length
,
)
# get tensor shape for ShardMetaInfo
# get tensor shape for ShardMetaInfo
origin_sharding_spec
:
ShardingSpec
origin_sharding_spec
:
ShardingSpec
...
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
...
@@ -47,9 +51,9 @@ def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec,
input_shape
=
origin_sharding_spec
.
get_sharded_shape_per_device
()
input_shape
=
origin_sharding_spec
.
get_sharded_shape_per_device
()
output_shape
=
target_sharding_spec
.
get_sharded_shape_per_device
()
output_shape
=
target_sharding_spec
.
get_sharded_shape_per_device
()
meta_info
.
fwd_in
=
[
torch
.
rand
(
input_shape
,
device
=
'
meta
'
)]
meta_info
.
fwd_in
=
[
torch
.
rand
(
input_shape
,
device
=
"
meta
"
)]
meta_info
.
fwd_buffer
=
[]
meta_info
.
fwd_buffer
=
[]
meta_info
.
fwd_out
=
[
torch
.
rand
(
output_shape
,
device
=
'
meta
'
)]
meta_info
.
fwd_out
=
[
torch
.
rand
(
output_shape
,
device
=
"
meta
"
)]
return
meta_info
return
meta_info
...
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
...
@@ -62,8 +66,10 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -
# extract node index and user node index
# extract node index and user node index
args
=
node
.
args
args
=
node
.
args
node_index
,
user_node_index
=
args
[
3
],
args
[
4
]
node_index
,
user_node_index
=
args
[
3
],
args
[
4
]
origin_sharding_spec
,
target_sharding_spec
=
origin_spec_dict
[
node_index
],
sharding_spec_dict
[
node_index
][
origin_sharding_spec
,
target_sharding_spec
=
(
user_node_index
]
origin_spec_dict
[
node_index
],
sharding_spec_dict
[
node_index
][
user_node_index
],
)
return
_construct_shard_meta_info
(
node
,
origin_sharding_spec
,
target_sharding_spec
)
return
_construct_shard_meta_info
(
node
,
origin_sharding_spec
,
target_sharding_spec
)
...
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
...
@@ -77,37 +83,42 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> S
# this case is for all_reduce, there will be no memory cost
# this case is for all_reduce, there will be no memory cost
meta_info
=
ShardMetaInfo
()
meta_info
=
ShardMetaInfo
()
meta_info
.
memory_cost
=
TrainCycleItem
(
MemoryCost
(),
MemoryCost
(),
MemoryCost
)
meta_info
.
memory_cost
=
TrainCycleItem
(
MemoryCost
(),
MemoryCost
(),
MemoryCost
)
output_node
=
next
(
n
for
n
in
node
.
users
if
hasattr
(
n
,
'
_meta_data
'
))
output_node
=
next
(
n
for
n
in
node
.
users
if
hasattr
(
n
,
"
_meta_data
"
))
element_length
=
output_node
.
_meta_data
.
element_size
()
element_length
=
output_node
.
_meta_data
.
element_size
()
total_cost
=
comm_action
.
comm_spec
.
get_comm_cost
()
total_cost
=
comm_action
.
comm_spec
.
get_comm_cost
()
meta_info
.
compute_cost
=
TrainCycleItem
(
total_cost
[
'forward'
]
*
element_length
,
meta_info
.
compute_cost
=
TrainCycleItem
(
total_cost
[
'backward'
]
*
element_length
,
total_cost
[
"forward"
]
*
element_length
,
total_cost
[
'total'
]
*
element_length
)
total_cost
[
"backward"
]
*
element_length
,
total_cost
[
"total"
]
*
element_length
,
)
input_shape
=
output_shape
=
comm_action
.
comm_spec
.
sharding_spec
.
get_sharded_shape_per_device
()
input_shape
=
output_shape
=
comm_action
.
comm_spec
.
sharding_spec
.
get_sharded_shape_per_device
()
meta_info
.
fwd_in
=
[
torch
.
rand
(
input_shape
,
device
=
'
meta
'
)]
meta_info
.
fwd_in
=
[
torch
.
rand
(
input_shape
,
device
=
"
meta
"
)]
meta_info
.
fwd_buffer
=
[]
meta_info
.
fwd_buffer
=
[]
meta_info
.
fwd_out
=
[
torch
.
rand
(
output_shape
,
device
=
'
meta
'
)]
meta_info
.
fwd_out
=
[
torch
.
rand
(
output_shape
,
device
=
"
meta
"
)]
else
:
else
:
# this case will be handled by shape consistency manager
# this case will be handled by shape consistency manager
origin_sharding_spec
,
target_sharding_spec
=
comm_action
.
comm_spec
[
'src_spec'
],
comm_action
.
comm_spec
[
origin_sharding_spec
,
target_sharding_spec
=
(
'tgt_spec'
]
comm_action
.
comm_spec
[
"src_spec"
],
comm_action
.
comm_spec
[
"tgt_spec"
],
)
meta_info
=
_construct_shard_meta_info
(
node
,
origin_sharding_spec
,
target_sharding_spec
)
meta_info
=
_construct_shard_meta_info
(
node
,
origin_sharding_spec
,
target_sharding_spec
)
return
meta_info
return
meta_info
def
comm_metainfo_pass
(
gm
:
GraphModule
,
sharding_spec_dict
:
Dict
,
origin_spec_dict
:
Dict
,
def
comm_metainfo_pass
(
comm_actions_dict
:
Dict
)
->
GraphModule
:
gm
:
GraphModule
,
sharding_spec_dict
:
Dict
,
origin_spec_dict
:
Dict
,
comm_actions_dict
:
Dict
)
->
GraphModule
:
"""
"""
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
The method manages all the metainfo of the communication node (run_time_apply, runtime_comm_spec_apply) in the graph.
"""
"""
for
node
in
gm
.
graph
.
nodes
:
for
node
in
gm
.
graph
.
nodes
:
if
node
.
target
==
runtime_apply
:
if
node
.
target
==
runtime_apply
:
setattr
(
node
,
'
best_strategy_info
'
,
_runtime_apply_meta_info
(
node
,
origin_spec_dict
,
sharding_spec_dict
))
setattr
(
node
,
"
best_strategy_info
"
,
_runtime_apply_meta_info
(
node
,
origin_spec_dict
,
sharding_spec_dict
))
elif
node
.
target
==
runtime_comm_spec_apply
:
elif
node
.
target
==
runtime_comm_spec_apply
:
setattr
(
node
,
'
best_strategy_info
'
,
_runtime_comm_spec_apply_meta_info
(
node
,
comm_actions_dict
))
setattr
(
node
,
"
best_strategy_info
"
,
_runtime_comm_spec_apply_meta_info
(
node
,
comm_actions_dict
))
else
:
else
:
pass
pass
return
gm
return
gm
colossalai/auto_parallel/passes/meta_info_prop.py
View file @
9e768b59
...
@@ -21,16 +21,15 @@ def _normalize_tuple(x):
...
@@ -21,16 +21,15 @@ def _normalize_tuple(x):
@
compatibility
(
is_backward_compatible
=
False
)
@
compatibility
(
is_backward_compatible
=
False
)
class
MetaInfoProp
:
class
MetaInfoProp
:
def
__init__
(
self
,
module
:
GraphModule
)
->
None
:
def
__init__
(
self
,
module
:
GraphModule
)
->
None
:
self
.
module
=
module
self
.
module
=
module
self
.
func_dict
=
{
self
.
func_dict
=
{
'
placeholder
'
:
self
.
placeholder_handler
,
"
placeholder
"
:
self
.
placeholder_handler
,
'
get_attr
'
:
self
.
get_attr_handler
,
"
get_attr
"
:
self
.
get_attr_handler
,
'
output
'
:
self
.
output_handler
,
"
output
"
:
self
.
output_handler
,
'
call_function
'
:
self
.
node_handler
,
"
call_function
"
:
self
.
node_handler
,
'
call_module
'
:
self
.
node_handler
,
"
call_module
"
:
self
.
node_handler
,
'
call_method
'
:
self
.
node_handler
,
"
call_method
"
:
self
.
node_handler
,
}
}
def
_set_data_ptr
(
self
,
x
):
def
_set_data_ptr
(
self
,
x
):
...
@@ -46,7 +45,7 @@ class MetaInfoProp:
...
@@ -46,7 +45,7 @@ class MetaInfoProp:
"""
"""
Check if the node is inplace operation.
Check if the node is inplace operation.
"""
"""
if
node
.
op
==
'
call_module
'
:
if
node
.
op
==
"
call_module
"
:
return
node
.
graph
.
owning_module
.
get_submodule
(
node
.
target
).
__class__
in
OUTPUT_SAVED_MOD
return
node
.
graph
.
owning_module
.
get_submodule
(
node
.
target
).
__class__
in
OUTPUT_SAVED_MOD
elif
node
.
op
==
"call_function"
:
elif
node
.
op
==
"call_function"
:
return
node
.
target
in
OUTPUT_SAVED_OPS
return
node
.
target
in
OUTPUT_SAVED_OPS
...
@@ -66,7 +65,7 @@ class MetaInfoProp:
...
@@ -66,7 +65,7 @@ class MetaInfoProp:
Handle the placeholder node.
Handle the placeholder node.
"""
"""
graph_info
=
GraphInfo
()
graph_info
=
GraphInfo
()
out
=
_normalize_tuple
(
getattr
(
node
,
'
_meta_data
'
,
None
))
out
=
_normalize_tuple
(
getattr
(
node
,
"
_meta_data
"
,
None
))
graph_info
.
fwd_out
=
list
(
out
)
if
out
[
0
]
is
not
None
else
[]
graph_info
.
fwd_out
=
list
(
out
)
if
out
[
0
]
is
not
None
else
[]
node
.
meta
=
{
**
asdict
(
graph_info
)}
node
.
meta
=
{
**
asdict
(
graph_info
)}
...
@@ -96,7 +95,7 @@ class MetaInfoProp:
...
@@ -96,7 +95,7 @@ class MetaInfoProp:
"""
"""
Handle other kind of nodes
Handle other kind of nodes
"""
"""
assert
hasattr
(
node
,
'
best_strategy_info
'
),
f
"Cannot find best_strategy_info in node
{
node
}
,
{
node
.
op
}
"
assert
hasattr
(
node
,
"
best_strategy_info
"
),
f
"Cannot find best_strategy_info in node
{
node
}
,
{
node
.
op
}
"
graph_info
=
GraphInfo
()
graph_info
=
GraphInfo
()
meta_info
=
node
.
best_strategy_info
meta_info
=
node
.
best_strategy_info
meta_info
:
ShardMetaInfo
meta_info
:
ShardMetaInfo
...
@@ -126,7 +125,8 @@ class MetaInfoProp:
...
@@ -126,7 +125,8 @@ class MetaInfoProp:
for
tensor
in
par
.
meta
.
get
(
"fwd_out"
,
[]):
for
tensor
in
par
.
meta
.
get
(
"fwd_out"
,
[]):
tensor
:
torch
.
Tensor
tensor
:
torch
.
Tensor
target_input_tensor
=
next
(
target_input_tensor
=
next
(
(
x
for
x
in
input_tensors
if
not
x
.
data_ptr
()
and
x
.
shape
==
tensor
.
shape
),
None
)
(
x
for
x
in
input_tensors
if
not
x
.
data_ptr
()
and
x
.
shape
==
tensor
.
shape
),
None
)
if
target_input_tensor
is
not
None
:
if
target_input_tensor
is
not
None
:
target_input_tensor
.
data_ptr
=
tensor
.
data_ptr
target_input_tensor
.
data_ptr
=
tensor
.
data_ptr
...
@@ -148,7 +148,7 @@ class MetaInfoProp:
...
@@ -148,7 +148,7 @@ class MetaInfoProp:
graph_info
.
fwd_tmp
=
buffer_tensors
graph_info
.
fwd_tmp
=
buffer_tensors
graph_info
.
fwd_out
=
output_tensors
graph_info
.
fwd_out
=
output_tensors
# fetch other memory information
s
# fetch other memory information
memory_cost
=
meta_info
.
memory_cost
memory_cost
=
meta_info
.
memory_cost
graph_info
.
fwd_mem_tmp
=
memory_cost
.
fwd
.
temp
graph_info
.
fwd_mem_tmp
=
memory_cost
.
fwd
.
temp
graph_info
.
fwd_mem_out
=
memory_cost
.
fwd
.
activation
graph_info
.
fwd_mem_out
=
memory_cost
.
fwd
.
activation
...
...
colossalai/auto_parallel/passes/runtime_apply_pass.py
View file @
9e768b59
from
copy
import
deepcopy
from
typing
import
Dict
,
List
from
typing
import
Dict
,
List
import
torch
import
torch
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
from
colossalai._analyzer.fx.node_util
import
MetaInfo
from
colossalai._analyzer.fx.node_util
import
MetaInfo
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
CommType
,
OperationDataType
CommAction
,
CommType
,
OperationData
,
OperationDataType
,
TrainCycleItem
,
)
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.tensor.comm_spec
import
CommSpec
from
colossalai.tensor.comm_spec
import
CommSpec
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
colossalai.tensor.sharding_spec
import
ShardingSpec
...
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
...
@@ -30,19 +22,22 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
node
,
origin_sharding_spec
,
target_sharding_spec
)
return
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
node
,
origin_sharding_spec
,
target_sharding_spec
)
def
runtime_apply_for_iterable_object
(
node
:
Node
,
origin_dict
:
Dict
,
input_dict
:
Dict
,
node_index
:
int
,
def
runtime_apply_for_iterable_object
(
user_node_index
:
int
):
node
:
Node
,
origin_dict
:
Dict
,
input_dict
:
Dict
,
node_index
:
int
,
user_node_index
:
int
):
"""
"""
This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
is converted into the user node expected form.
is converted into the user node expected form.
"""
"""
rst
=
[]
rst
=
[]
for
index
,
(
origin_sharding_spec
,
for
index
,
(
origin_sharding_spec
,
target_sharding_spec
)
in
enumerate
(
target_sharding_spec
)
in
enumerate
(
zip
(
origin_dict
[
node_index
],
zip
(
origin_dict
[
node_index
],
input_dict
[
node_index
][
user_node_index
])
input_dict
[
node_index
][
user_node_index
])
):
):
rst
.
append
(
rst
.
append
(
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
node
[
index
],
origin_sharding_spec
,
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
target_sharding_spec
))
node
[
index
],
origin_sharding_spec
,
target_sharding_spec
)
)
rst
=
type
(
node
)(
rst
)
rst
=
type
(
node
)(
rst
)
return
rst
return
rst
...
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
...
@@ -55,8 +50,8 @@ def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_
if
isinstance
(
comm_action
.
comm_spec
,
CommSpec
):
if
isinstance
(
comm_action
.
comm_spec
,
CommSpec
):
rst
=
comm_action
.
comm_spec
.
covert_spec_to_action
(
tensor
)
rst
=
comm_action
.
comm_spec
.
covert_spec_to_action
(
tensor
)
else
:
else
:
origin_sharding_spec
=
comm_action
.
comm_spec
[
'
src_spec
'
]
origin_sharding_spec
=
comm_action
.
comm_spec
[
"
src_spec
"
]
tgt_sharding_spec
=
comm_action
.
comm_spec
[
'
tgt_spec
'
]
tgt_sharding_spec
=
comm_action
.
comm_spec
[
"
tgt_spec
"
]
rst
=
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
tensor
,
origin_sharding_spec
,
tgt_sharding_spec
)
rst
=
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
tensor
,
origin_sharding_spec
,
tgt_sharding_spec
)
return
rst
return
rst
...
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
...
@@ -70,16 +65,16 @@ def _preprocess_graph(nodes: List[Node]):
node_to_index_dict
=
{}
node_to_index_dict
=
{}
index
=
0
index
=
0
for
node
in
nodes
:
for
node
in
nodes
:
if
node
.
target
==
'
sharding_spec_convert_dict
'
:
if
node
.
target
==
"
sharding_spec_convert_dict
"
:
input_dict_node
=
node
input_dict_node
=
node
continue
continue
if
node
.
target
==
'
origin_node_sharding_spec_dict
'
:
if
node
.
target
==
"
origin_node_sharding_spec_dict
"
:
origin_dict_node
=
node
origin_dict_node
=
node
continue
continue
if
node
.
target
==
'
comm_actions_dict
'
:
if
node
.
target
==
"
comm_actions_dict
"
:
comm_actions_dict_node
=
node
comm_actions_dict_node
=
node
continue
continue
if
not
hasattr
(
node
,
'
best_strategy
'
):
if
not
hasattr
(
node
,
"
best_strategy
"
):
continue
continue
node_to_index_dict
[
node
]
=
index
node_to_index_dict
[
node
]
=
index
index
+=
1
index
+=
1
...
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
...
@@ -97,41 +92,46 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
input_dict_node
,
origin_dict_node
,
_
,
node_to_index_dict
=
_preprocess_graph
(
nodes
)
input_dict_node
,
origin_dict_node
,
_
,
node_to_index_dict
=
_preprocess_graph
(
nodes
)
for
node
in
nodes
:
for
node
in
nodes
:
if
not
hasattr
(
node
,
'
best_strategy
'
)
or
node
.
op
==
'
output
'
:
if
not
hasattr
(
node
,
"
best_strategy
"
)
or
node
.
op
==
"
output
"
:
continue
continue
for
user_node_index
,
user_node
in
enumerate
(
node
.
strategies_vector
.
successor_nodes
):
for
user_node_index
,
user_node
in
enumerate
(
node
.
strategies_vector
.
successor_nodes
):
if
isinstance
(
node
.
sharding_spec
,
(
list
,
tuple
)):
if
isinstance
(
node
.
sharding_spec
,
(
list
,
tuple
)):
assert
isinstance
(
assert
isinstance
(
node
.
target_sharding_specs
,
node
.
target_sharding_specs
,
(
list
,
tuple
)
(
list
,
),
"target sharding specs should be tuple or list when node.sharding_spec is tuple or list"
tuple
)),
'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
total_difference
=
0
total_difference
=
0
for
sharding_spec
,
target_sharding_spec
in
zip
(
node
.
sharding_spec
,
for
sharding_spec
,
target_sharding_spec
in
zip
(
node
.
target_sharding_specs
[
user_node_index
]):
node
.
sharding_spec
,
node
.
target_sharding_specs
[
user_node_index
]
):
total_difference
+=
sharding_spec
.
sharding_sequence_difference
(
target_sharding_spec
)
total_difference
+=
sharding_spec
.
sharding_sequence_difference
(
target_sharding_spec
)
if
total_difference
==
0
:
if
total_difference
==
0
:
continue
continue
with
mod_graph
.
inserting_before
(
user_node
):
with
mod_graph
.
inserting_before
(
user_node
):
shape_consistency_node
=
mod_graph
.
create_node
(
'call_function'
,
shape_consistency_node
=
mod_graph
.
create_node
(
runtime_apply_for_iterable_object
,
"call_function"
,
args
=
(
node
,
origin_dict_node
,
input_dict_node
,
runtime_apply_for_iterable_object
,
node_to_index_dict
[
node
],
user_node_index
))
args
=
(
node
,
origin_dict_node
,
input_dict_node
,
node_to_index_dict
[
node
],
user_node_index
),
)
else
:
else
:
assert
isinstance
(
node
.
sharding_spec
,
assert
isinstance
(
ShardingSpec
),
'node.sharding_spec should be type of ShardingSpec, tuple or list.'
node
.
sharding_spec
,
ShardingSpec
),
"node.sharding_spec should be type of ShardingSpec, tuple or list."
if
node
.
sharding_spec
.
sharding_sequence_difference
(
node
.
target_sharding_specs
[
user_node_index
])
==
0
:
if
node
.
sharding_spec
.
sharding_sequence_difference
(
node
.
target_sharding_specs
[
user_node_index
])
==
0
:
continue
continue
with
mod_graph
.
inserting_before
(
user_node
):
with
mod_graph
.
inserting_before
(
user_node
):
shape_consistency_node
=
mod_graph
.
create_node
(
'call_function'
,
shape_consistency_node
=
mod_graph
.
create_node
(
runtime_apply
,
"call_function"
,
args
=
(
node
,
origin_dict_node
,
input_dict_node
,
runtime_apply
,
node_to_index_dict
[
node
],
user_node_index
))
args
=
(
node
,
origin_dict_node
,
input_dict_node
,
node_to_index_dict
[
node
],
user_node_index
),
if
hasattr
(
user_node
.
meta
[
'info'
],
'activation_checkpoint'
):
)
MetaInfo
(
shape_consistency_node
,
if
hasattr
(
user_node
.
meta
[
"info"
],
"activation_checkpoint"
):
mod_dir
=
user_node
.
meta
[
'info'
].
mod_dir
,
MetaInfo
(
activation_checkpoint
=
tuple
(
user_node
.
meta
[
'info'
].
activation_checkpoint
))
shape_consistency_node
,
mod_dir
=
user_node
.
meta
[
"info"
].
mod_dir
,
activation_checkpoint
=
tuple
(
user_node
.
meta
[
"info"
].
activation_checkpoint
),
)
new_args
=
list
(
user_node
.
args
)
new_args
=
list
(
user_node
.
args
)
new_kwargs
=
dict
(
user_node
.
kwargs
)
new_kwargs
=
dict
(
user_node
.
kwargs
)
# the origin node may be a positional argument or key word argument of user node
# the origin node may be a positional argument or key word argument of user node
...
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
...
@@ -158,12 +158,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
_
,
_
,
comm_actions_dict_node
,
node_to_index_dict
=
_preprocess_graph
(
nodes
)
_
,
_
,
comm_actions_dict_node
,
node_to_index_dict
=
_preprocess_graph
(
nodes
)
for
node
in
nodes
:
for
node
in
nodes
:
if
not
hasattr
(
node
,
'
best_strategy
'
)
or
node
.
op
==
'
output
'
:
if
not
hasattr
(
node
,
"
best_strategy
"
)
or
node
.
op
==
"
output
"
:
continue
continue
comm_actions
=
node
.
best_strategy
.
communication_actions
comm_actions
=
node
.
best_strategy
.
communication_actions
for
op_data
,
comm_action
in
comm_actions
.
items
():
for
op_data
,
comm_action
in
comm_actions
.
items
():
if
comm_action
.
comm_type
==
CommType
.
HOOK
:
if
comm_action
.
comm_type
==
CommType
.
HOOK
:
continue
continue
if
comm_action
.
comm_type
==
CommType
.
BEFORE
:
if
comm_action
.
comm_type
==
CommType
.
BEFORE
:
...
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
...
@@ -174,10 +173,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
else
:
else
:
comm_object
=
node
.
args
[
comm_action
.
arg_index
]
comm_object
=
node
.
args
[
comm_action
.
arg_index
]
with
mod_graph
.
inserting_before
(
node
):
with
mod_graph
.
inserting_before
(
node
):
comm_spec_apply_node
=
mod_graph
.
create_node
(
'call_function'
,
comm_spec_apply_node
=
mod_graph
.
create_node
(
runtime_comm_spec_apply
,
"call_function"
,
args
=
(
comm_object
,
comm_actions_dict_node
,
runtime_comm_spec_apply
,
node_to_index_dict
[
node
],
op_data
.
name
))
args
=
(
comm_object
,
comm_actions_dict_node
,
node_to_index_dict
[
node
],
op_data
.
name
),
)
# the origin node may be a positional argument or key word argument of user node
# the origin node may be a positional argument or key word argument of user node
if
comm_action
.
key_for_kwarg
is
not
None
:
if
comm_action
.
key_for_kwarg
is
not
None
:
# substitute the origin node with comm_spec_apply_node
# substitute the origin node with comm_spec_apply_node
...
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
...
@@ -192,10 +192,11 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
elif
comm_action
.
comm_type
==
CommType
.
AFTER
:
elif
comm_action
.
comm_type
==
CommType
.
AFTER
:
with
mod_graph
.
inserting_after
(
node
):
with
mod_graph
.
inserting_after
(
node
):
comm_spec_apply_node
=
mod_graph
.
create_node
(
'call_function'
,
comm_spec_apply_node
=
mod_graph
.
create_node
(
runtime_comm_spec_apply
,
"call_function"
,
args
=
(
node
,
comm_actions_dict_node
,
runtime_comm_spec_apply
,
node_to_index_dict
[
node
],
op_data
.
name
))
args
=
(
node
,
comm_actions_dict_node
,
node_to_index_dict
[
node
],
op_data
.
name
),
)
user_list
=
list
(
node
.
users
.
keys
())
user_list
=
list
(
node
.
users
.
keys
())
for
user
in
user_list
:
for
user
in
user_list
:
if
user
==
comm_spec_apply_node
:
if
user
==
comm_spec_apply_node
:
...
@@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
...
@@ -211,15 +212,17 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
# substitute the origin node with comm_spec_apply_node
new_kwargs
[
str
(
node
)]
=
comm_spec_apply_node
new_kwargs
[
str
(
node
)]
=
comm_spec_apply_node
user
.
kwargs
=
new_kwargs
user
.
kwargs
=
new_kwargs
if
hasattr
(
node
.
meta
[
'info'
],
'activation_checkpoint'
):
if
hasattr
(
node
.
meta
[
"info"
],
"activation_checkpoint"
):
MetaInfo
(
comm_spec_apply_node
,
MetaInfo
(
mod_dir
=
node
.
meta
[
'info'
].
mod_dir
,
comm_spec_apply_node
,
activation_checkpoint
=
tuple
(
node
.
meta
[
'info'
].
activation_checkpoint
))
mod_dir
=
node
.
meta
[
"info"
].
mod_dir
,
activation_checkpoint
=
tuple
(
node
.
meta
[
"info"
].
activation_checkpoint
),
)
return
gm
return
gm
def
_act_annotat
a
ion_pass
(
gm
:
torch
.
fx
.
GraphModule
):
def
_act_annotation_pass
(
gm
:
torch
.
fx
.
GraphModule
):
"""
"""
This pass is used to add the act annotation to the new inserted nodes.
This pass is used to add the act annotation to the new inserted nodes.
"""
"""
...
@@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule):
...
@@ -227,21 +230,21 @@ def _act_annotataion_pass(gm: torch.fx.GraphModule):
nodes
=
tuple
(
mod_graph
.
nodes
)
nodes
=
tuple
(
mod_graph
.
nodes
)
for
node
in
nodes
:
for
node
in
nodes
:
if
not
hasattr
(
node
.
meta
,
'
activation_checkpoint
'
):
if
not
hasattr
(
node
.
meta
,
"
activation_checkpoint
"
):
from
.runtime_preparation_pass
import
size_processing
pass
user_act_annotation
=
-
1
user_act_annotation
=
-
1
input_act_annotation
=
-
1
input_act_annotation
=
-
1
for
user_node
in
node
.
users
.
keys
():
for
user_node
in
node
.
users
.
keys
():
if
'
activation_checkpoint
'
in
user_node
.
meta
:
if
"
activation_checkpoint
"
in
user_node
.
meta
:
user_act_annotation
=
user_node
.
meta
[
'
activation_checkpoint
'
]
user_act_annotation
=
user_node
.
meta
[
"
activation_checkpoint
"
]
break
break
for
input_node
in
node
.
_input_nodes
.
keys
():
for
input_node
in
node
.
_input_nodes
.
keys
():
if
'
activation_checkpoint
'
in
input_node
.
meta
:
if
"
activation_checkpoint
"
in
input_node
.
meta
:
input_act_annotation
=
input_node
.
meta
[
'
activation_checkpoint
'
]
input_act_annotation
=
input_node
.
meta
[
"
activation_checkpoint
"
]
break
break
if
user_act_annotation
==
input_act_annotation
and
user_act_annotation
!=
-
1
:
if
user_act_annotation
==
input_act_annotation
and
user_act_annotation
!=
-
1
:
node
.
meta
[
'
activation_checkpoint
'
]
=
user_act_annotation
node
.
meta
[
"
activation_checkpoint
"
]
=
user_act_annotation
return
gm
return
gm
...
...
colossalai/auto_parallel/passes/runtime_preparation_pass.py
View file @
9e768b59
import
operator
import
operator
from
copy
import
deepcopy
from
typing
import
Dict
,
List
,
Union
from
typing
import
Dict
,
List
,
Union
import
torch
import
torch
from
torch.fx
import
symbolic_trace
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
from
colossalai._analyzer.fx.node_util
import
MetaInfo
from
colossalai._analyzer.fx.node_util
import
MetaInfo
from
colossalai.auto_parallel.tensor_shard.constants
import
RESHAPE_FUNC_OP
from
colossalai.auto_parallel.tensor_shard.constants
import
RESHAPE_FUNC_OP
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
CommType
,
OperationDataType
CommAction
,
CommType
,
OperationDataType
,
ShardingStrategy
,
)
from
colossalai.auto_parallel.tensor_shard.solver.strategies_constructor
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.solver.strategies_constructor
import
StrategiesConstructor
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.tensor.comm_spec
import
_all_reduce
from
colossalai.tensor.comm_spec
import
_all_reduce
...
@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
...
@@ -25,11 +18,13 @@ from .constants import SHAPE_ARGUMENT_OPS
shape_consistency_manager
=
ShapeConsistencyManager
()
shape_consistency_manager
=
ShapeConsistencyManager
()
def
size_processing
(
size
:
Union
[
int
,
torch
.
Size
],
def
size_processing
(
dim_partition_dict
:
Dict
[
int
,
List
[
int
]],
size
:
Union
[
int
,
torch
.
Size
],
device_mesh_info
:
Dict
[
int
,
int
],
dim_partition_dict
:
Dict
[
int
,
List
[
int
]],
target_dim
:
int
=
None
,
device_mesh_info
:
Dict
[
int
,
int
],
node_name
:
str
=
None
):
target_dim
:
int
=
None
,
node_name
:
str
=
None
,
):
"""
"""
This method will be invoked during runtime to convert size node value depending on distributed information.
This method will be invoked during runtime to convert size node value depending on distributed information.
"""
"""
...
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
...
@@ -54,8 +49,9 @@ def size_processing(size: Union[int, torch.Size],
return
size
return
size
def
solution_annotatation_pass
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
def
solution_annotation_pass
(
strategies_constructor
:
StrategiesConstructor
):
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
strategies_constructor
:
StrategiesConstructor
):
"""
"""
This method is used to stick the solution strategy to the nodes and add the information
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
required in runtime into graph as placeholder nodes.
...
@@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
...
@@ -70,14 +66,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
for
node_index
,
(
node
,
strategy_index
)
in
enumerate
(
zip
(
nodes
,
solution
)):
for
node_index
,
(
node
,
strategy_index
)
in
enumerate
(
zip
(
nodes
,
solution
)):
strategies_vector
=
node
.
strategies_vector
strategies_vector
=
node
.
strategies_vector
# stick the solution strategy to the corresponding node
# stick the solution strategy to the corresponding node
setattr
(
node
,
'
best_strategy
'
,
strategies_vector
[
strategy_index
])
setattr
(
node
,
"
best_strategy
"
,
strategies_vector
[
strategy_index
])
setattr
(
node
,
'
sharding_spec
'
,
strategies_vector
[
strategy_index
].
get_sharding_spec_by_name
(
str
(
node
)))
setattr
(
node
,
"
sharding_spec
"
,
strategies_vector
[
strategy_index
].
get_sharding_spec_by_name
(
str
(
node
)))
origin_node_sharding_spec_dict
[
node_index
]
=
strategies_vector
[
strategy_index
].
get_sharding_spec_by_name
(
origin_node_sharding_spec_dict
[
node_index
]
=
strategies_vector
[
strategy_index
].
get_sharding_spec_by_name
(
str
(
node
))
str
(
node
)
)
# attach the corresponding metainfo if node has the attribute `strategies_info`
# attach the corresponding metainfo if node has the attribute `strategies_info`
if
hasattr
(
node
,
'
strategies_info
'
):
if
hasattr
(
node
,
"
strategies_info
"
):
setattr
(
node
,
'
best_strategy_info
'
,
node
.
strategies_info
[
strategy_index
])
setattr
(
node
,
"
best_strategy_info
"
,
node
.
strategies_info
[
strategy_index
])
# the dict to get input sharding specs of user node
# the dict to get input sharding specs of user node
sharding_spec_convert_dict
=
{}
sharding_spec_convert_dict
=
{}
...
@@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
...
@@ -92,15 +89,15 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
target_sharding_spec
=
user_node
.
best_strategy
.
get_sharding_spec_by_name
(
str
(
node
.
name
))
target_sharding_spec
=
user_node
.
best_strategy
.
get_sharding_spec_by_name
(
str
(
node
.
name
))
target_sharding_specs
.
append
(
target_sharding_spec
)
target_sharding_specs
.
append
(
target_sharding_spec
)
sharding_spec_convert_dict
[
index
]
=
target_sharding_specs
sharding_spec_convert_dict
[
index
]
=
target_sharding_specs
setattr
(
node
,
'
target_sharding_specs
'
,
target_sharding_specs
)
setattr
(
node
,
"
target_sharding_specs
"
,
target_sharding_specs
)
# the get_attr node strategy is kind of pending strategy, which means we will change it
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
# to the same strategy of the user node.
if
node
.
op
==
'
get_attr
'
:
if
node
.
op
==
"
get_attr
"
:
assert
len
(
target_sharding_specs
)
==
1
,
f
'
sharing weight is not supported in current version.
'
assert
len
(
target_sharding_specs
)
==
1
,
f
"
sharing weight is not supported in current version.
"
target_node
=
node
.
strategies_vector
.
successor_nodes
[
0
]
target_node
=
node
.
strategies_vector
.
successor_nodes
[
0
]
node_name
=
str
(
node
)
node_name
=
str
(
node
)
if
target_node
.
op
==
'
call_function
'
and
target_node
.
target
in
RESHAPE_FUNC_OP
:
if
target_node
.
op
==
"
call_function
"
and
target_node
.
target
in
RESHAPE_FUNC_OP
:
node_name
=
str
(
target_node
)
node_name
=
str
(
target_node
)
target_node
=
target_node
.
strategies_vector
.
successor_nodes
[
0
]
target_node
=
target_node
.
strategies_vector
.
successor_nodes
[
0
]
user_strategy
=
target_node
.
best_strategy
user_strategy
=
target_node
.
best_strategy
...
@@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
...
@@ -122,11 +119,11 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
# add above dicts into graph
# add above dicts into graph
for
node
in
nodes
:
for
node
in
nodes
:
if
node
.
op
!=
'
placeholder
'
:
if
node
.
op
!=
"
placeholder
"
:
with
mod_graph
.
inserting_before
(
node
):
with
mod_graph
.
inserting_before
(
node
):
input_specs_node
=
mod_graph
.
create_node
(
'
placeholder
'
,
target
=
'
sharding_spec_convert_dict
'
)
input_specs_node
=
mod_graph
.
create_node
(
"
placeholder
"
,
target
=
"
sharding_spec_convert_dict
"
)
origin_specs_node
=
mod_graph
.
create_node
(
'
placeholder
'
,
target
=
'
origin_node_sharding_spec_dict
'
)
origin_specs_node
=
mod_graph
.
create_node
(
"
placeholder
"
,
target
=
"
origin_node_sharding_spec_dict
"
)
comm_actions_dict_node
=
mod_graph
.
create_node
(
'
placeholder
'
,
target
=
'
comm_actions_dict
'
)
comm_actions_dict_node
=
mod_graph
.
create_node
(
"
placeholder
"
,
target
=
"
comm_actions_dict
"
)
break
break
return
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
return
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
...
@@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
...
@@ -144,11 +141,11 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# DeviceMesh information instructs the scaling of the size value
# DeviceMesh information instructs the scaling of the size value
device_mesh_info
=
{}
device_mesh_info
=
{}
for
dim
,
dim_size
in
enumerate
(
device_mesh
.
mesh_
shape
):
for
dim
,
dim_size
in
enumerate
(
device_mesh
.
shape
):
device_mesh_info
[
dim
]
=
dim_size
device_mesh_info
[
dim
]
=
dim_size
def
_extract_target_dim
(
node
):
def
_extract_target_dim
(
node
):
'''
"""
A helper function to extract the target dimension from size node.
A helper function to extract the target dimension from size node.
There are two usages of torch.Tensor.size:
There are two usages of torch.Tensor.size:
1. tensor.size()
1. tensor.size()
...
@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
...
@@ -156,7 +153,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
Otherwise, the output will be in type of torch.Size and this function will return None.
Otherwise, the output will be in type of torch.Size and this function will return None.
'''
"""
target_dim
=
None
target_dim
=
None
if
len
(
node
.
args
)
>
1
:
if
len
(
node
.
args
)
>
1
:
target_dim
=
node
.
args
[
1
]
target_dim
=
node
.
args
[
1
]
...
@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
...
@@ -165,19 +162,21 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
return
target_dim
return
target_dim
def
_post_processing
(
node
,
size_processing_node
):
def
_post_processing
(
node
,
size_processing_node
):
'''
"""
This function is used to process the dependency between the size node and its users after
This function is used to process the dependency between the size node and its users after
inserting the size_process_node.
inserting the size_process_node.
'''
"""
# store original node and processing node pair in node_pairs dictio
a
nry
# store original node and processing node pair in node_pairs diction
a
ry
# It will be used to replace the original node with processing node in slice object
# It will be used to replace the original node with processing node in slice object
node_pairs
[
node
]
=
size_processing_node
node_pairs
[
node
]
=
size_processing_node
size_processing_node
.
_meta_data
=
node
.
_meta_data
size_processing_node
.
_meta_data
=
node
.
_meta_data
if
hasattr
(
node
.
meta
[
'info'
],
'activation_checkpoint'
):
if
hasattr
(
node
.
meta
[
"info"
],
"activation_checkpoint"
):
MetaInfo
(
size_processing_node
,
MetaInfo
(
mod_dir
=
node
.
meta
[
'info'
].
mod_dir
,
size_processing_node
,
activation_checkpoint
=
tuple
(
node
.
meta
[
'info'
].
activation_checkpoint
))
mod_dir
=
node
.
meta
[
"info"
].
mod_dir
,
activation_checkpoint
=
tuple
(
node
.
meta
[
"info"
].
activation_checkpoint
),
)
user_list
=
list
(
node
.
users
.
keys
())
user_list
=
list
(
node
.
users
.
keys
())
for
user
in
user_list
:
for
user
in
user_list
:
...
@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
...
@@ -196,10 +195,10 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
user
.
kwargs
=
new_kwargs
user
.
kwargs
=
new_kwargs
def
_update_slice_object_args
(
slice_object
):
def
_update_slice_object_args
(
slice_object
):
'''
"""
This function is used to update the slice object argument list.
This function is used to update the slice object argument list.
If the slice object contains the Node argument, then the size node will be replaced with
If the slice object contains the Node argument, then the size node will be replaced with
'''
"""
if
isinstance
(
slice_object
,
slice
):
if
isinstance
(
slice_object
,
slice
):
start
=
slice_object
.
start
start
=
slice_object
.
start
stop
=
slice_object
.
stop
stop
=
slice_object
.
stop
...
@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
...
@@ -220,8 +219,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
raise
RuntimeError
(
f
"Unsupported slice object type:
{
type
(
slice_object
)
}
"
)
raise
RuntimeError
(
f
"Unsupported slice object type:
{
type
(
slice_object
)
}
"
)
for
node
in
nodes
:
for
node
in
nodes
:
if
node
.
op
==
"call_method"
and
node
.
target
==
"size"
:
if
node
.
op
==
'call_method'
and
node
.
target
==
'size'
:
# extract useful information from size node
# extract useful information from size node
# dim_partition_dict will instruct the size value on which
# dim_partition_dict will instruct the size value on which
# dimension should be enlarged.
# dimension should be enlarged.
...
@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
...
@@ -232,14 +230,14 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# insert size_processing node
# insert size_processing node
with
mod_graph
.
inserting_after
(
node
):
with
mod_graph
.
inserting_after
(
node
):
size_processing_node
=
mod_graph
.
create_node
(
'call_function'
,
size_processing_node
=
mod_graph
.
create_node
(
size_processing
,
"call_function"
,
args
=
(
node
,
dim_partition_dict
,
device_mesh_info
,
size_processing
,
target_dim
,
node
.
name
))
args
=
(
node
,
dim_partition_dict
,
device_mesh_info
,
target_dim
,
node
.
name
),
)
_post_processing
(
node
,
size_processing_node
)
_post_processing
(
node
,
size_processing_node
)
if
node
.
op
==
'call_function'
and
node
.
target
==
operator
.
getitem
:
if
node
.
op
==
"call_function"
and
node
.
target
==
operator
.
getitem
:
getitem_index
=
node
.
args
[
1
]
getitem_index
=
node
.
args
[
1
]
# slice object is quite special in torch.fx graph,
# slice object is quite special in torch.fx graph,
# On one side, we treat slice object same as type of int,
# On one side, we treat slice object same as type of int,
...
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
...
@@ -287,18 +285,19 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
nodes
=
tuple
(
mod_graph
.
nodes
)
nodes
=
tuple
(
mod_graph
.
nodes
)
def
_extract_info_from_sharding_spec
(
sharding_spec
):
def
_extract_info_from_sharding_spec
(
sharding_spec
):
'''
"""
This function is used to extract the dim_partition_dict and device_mesh from
This function is used to extract the dim_partition_dict and device_mesh from
sharding spec instance or a list of sharding spec.
sharding spec instance or a list of sharding spec.
'''
"""
if
isinstance
(
sharding_spec
,
ShardingSpec
):
if
isinstance
(
sharding_spec
,
ShardingSpec
):
dim_partition_dict
=
sharding_spec
.
dim_partition_dict
dim_partition_dict
=
sharding_spec
.
dim_partition_dict
device_mesh
=
sharding_spec
.
device_mesh
device_mesh
=
sharding_spec
.
device_mesh
return
dim_partition_dict
,
device_mesh
return
dim_partition_dict
,
device_mesh
if
sharding_spec
is
None
:
if
sharding_spec
is
None
:
return
None
,
None
return
None
,
None
assert
isinstance
(
sharding_spec
,
assert
isinstance
(
(
tuple
,
list
)),
'sharding_spec should be type of ShardingSpec, tuple, list or None'
sharding_spec
,
(
tuple
,
list
)
),
"sharding_spec should be type of ShardingSpec, tuple, list or None"
device_mesh
=
sharding_spec
[
0
].
device_mesh
device_mesh
=
sharding_spec
[
0
].
device_mesh
dim_partition_dict
=
[]
dim_partition_dict
=
[]
...
@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
...
@@ -322,8 +321,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
else
:
else
:
new_args
.
append
(
arg
)
new_args
.
append
(
arg
)
else
:
else
:
assert
isinstance
(
arg
,
assert
isinstance
(
(
int
,
tuple
,
list
)),
'The argument in view node should be either type of Node or int.'
arg
,
(
int
,
tuple
,
list
)
),
"The argument in view node should be either type of Node or int."
if
isinstance
(
arg
,
(
tuple
,
list
)):
if
isinstance
(
arg
,
(
tuple
,
list
)):
new_args
.
extend
(
arg
)
new_args
.
extend
(
arg
)
else
:
else
:
...
@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
...
@@ -332,7 +332,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
def
_scale_args_adapt_sharding_spec
(
dim_partition_dict
,
device_mesh
,
node
):
def
_scale_args_adapt_sharding_spec
(
dim_partition_dict
,
device_mesh
,
node
):
new_args
=
_process_node_arguments
(
node
)
new_args
=
_process_node_arguments
(
node
)
if
node
.
op
==
'
call_method
'
:
if
node
.
op
==
"
call_method
"
:
args_to_process
=
list
(
new_args
[
1
:])
args_to_process
=
list
(
new_args
[
1
:])
else
:
else
:
args_to_process
=
list
(
new_args
)
args_to_process
=
list
(
new_args
)
...
@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
...
@@ -350,7 +350,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
args_to_process
=
tuple
(
args_to_process
)
args_to_process
=
tuple
(
args_to_process
)
if
node
.
op
==
'
call_method
'
:
if
node
.
op
==
"
call_method
"
:
new_args
=
(
new_args
[
0
],)
+
args_to_process
new_args
=
(
new_args
[
0
],)
+
args_to_process
else
:
else
:
new_args
=
args_to_process
new_args
=
args_to_process
...
@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
...
@@ -358,9 +358,9 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
node
.
args
=
new_args
node
.
args
=
new_args
def
_filter_node_with_shape_args
(
node
):
def
_filter_node_with_shape_args
(
node
):
if
node
.
op
==
'
call_method
'
:
if
node
.
op
==
"
call_method
"
:
target
=
getattr
(
node
.
args
[
0
].
_meta_data
.
__class__
,
node
.
target
)
target
=
getattr
(
node
.
args
[
0
].
_meta_data
.
__class__
,
node
.
target
)
elif
node
.
op
==
'
call_function
'
:
elif
node
.
op
==
"
call_function
"
:
target
=
node
.
target
target
=
node
.
target
else
:
else
:
target
=
None
target
=
None
...
@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
...
@@ -371,7 +371,7 @@ def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh)
for
node
in
nodes
:
for
node
in
nodes
:
# skip the placeholder node added in _solution_annotation pass
# skip the placeholder node added in _solution_annotation pass
if
not
hasattr
(
node
,
'
sharding_spec
'
):
if
not
hasattr
(
node
,
"
sharding_spec
"
):
continue
continue
output_dim_partition_dict
,
device_mesh
=
_extract_info_from_sharding_spec
(
node
.
sharding_spec
)
output_dim_partition_dict
,
device_mesh
=
_extract_info_from_sharding_spec
(
node
.
sharding_spec
)
...
@@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
...
@@ -388,19 +388,25 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
"""
"""
mod_graph
=
gm
.
graph
mod_graph
=
gm
.
graph
nodes
=
tuple
(
mod_graph
.
nodes
)
nodes
=
tuple
(
mod_graph
.
nodes
)
# This stream is created for overlaping the communication and computation.
# This stream is created for overlap
p
ing the communication and computation.
reduction_stream
=
torch
.
cuda
.
Stream
()
reduction_stream
=
torch
.
cuda
.
Stream
()
def
_add_hook_for_grad_communication
(
node
,
param
,
name
=
None
):
def
_add_hook_for_grad_communication
(
node
,
param
,
name
=
None
):
comm_actions
=
node
.
best_strategy
.
communication_actions
comm_actions
=
node
.
best_strategy
.
communication_actions
def
_filter_param_to_hook
(
node
,
op_data
,
comm_action
,
name
):
def
_filter_param_to_hook
(
node
,
op_data
,
comm_action
,
name
):
if
(
if
node
.
op
==
'call_module'
and
op_data
.
type
==
OperationDataType
.
PARAM
and
op_data
.
name
==
name
and
comm_action
.
comm_type
==
CommType
.
HOOK
:
node
.
op
==
"call_module"
and
op_data
.
type
==
OperationDataType
.
PARAM
and
op_data
.
name
==
name
and
comm_action
.
comm_type
==
CommType
.
HOOK
):
return
True
return
True
if
node
.
op
==
'get_attr'
and
isinstance
(
if
(
node
.
_meta_data
,
torch
.
nn
.
parameter
.
Parameter
)
and
comm_action
.
comm_type
==
CommType
.
HOOK
:
node
.
op
==
"get_attr"
and
isinstance
(
node
.
_meta_data
,
torch
.
nn
.
parameter
.
Parameter
)
and
comm_action
.
comm_type
==
CommType
.
HOOK
):
return
True
return
True
return
False
return
False
...
@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
...
@@ -410,7 +416,6 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
if
_filter_param_to_hook
(
node
,
operation_data
,
comm_action
,
name
=
name
):
if
_filter_param_to_hook
(
node
,
operation_data
,
comm_action
,
name
=
name
):
def
wrapper
(
param
,
comm_spec
,
stream
,
overlap
):
def
wrapper
(
param
,
comm_spec
,
stream
,
overlap
):
def
hook_fn
(
grad
):
def
hook_fn
(
grad
):
if
overlap
:
if
overlap
:
with
torch
.
cuda
.
stream
(
stream
):
with
torch
.
cuda
.
stream
(
stream
):
...
@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
...
@@ -426,22 +431,26 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of parameters
# apply the sharding spec of parameters
if
target_sharding_spec
.
dim_partition_dict
!=
{}:
if
target_sharding_spec
.
dim_partition_dict
!=
{}:
origin_sharding_spec
=
ShardingSpec
(
device_mesh
,
param
.
shape
,
{})
origin_sharding_spec
=
ShardingSpec
(
device_mesh
,
param
.
shape
,
{})
setattr
(
param
,
'
sharding_spec
'
,
origin_sharding_spec
)
setattr
(
param
,
"
sharding_spec
"
,
origin_sharding_spec
)
# TODO: build a ColoParameter class to manager the distributed parameters
# TODO: build a ColoParameter class to manager the distributed parameters
# we could use .data here, because all the operations just happen before the real training
# we could use .data here, because all the operations just happen before the real training
# loop, so we don't need to track these operations in the autograd graph.
# loop, so we don't need to track these operations in the autograd graph.
param
=
torch
.
nn
.
Parameter
(
param
=
torch
.
nn
.
Parameter
(
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
param
.
data
,
param
.
sharding_spec
,
shape_consistency_manager
.
apply_for_autoparallel_runtime
(
target_sharding_spec
).
detach
().
clone
())
param
.
data
,
param
.
sharding_spec
,
target_sharding_spec
)
.
detach
()
.
clone
()
)
return
param
return
param
for
node
in
nodes
:
for
node
in
nodes
:
if
node
.
op
==
'
call_module
'
:
if
node
.
op
==
"
call_module
"
:
target_module
=
node
.
graph
.
owning_module
.
get_submodule
(
node
.
target
)
target_module
=
node
.
graph
.
owning_module
.
get_submodule
(
node
.
target
)
# TODO: we need to do more actions to take care of the shared parameters.
# TODO: we need to do more actions to take care of the shared parameters.
if
hasattr
(
target_module
,
'
processed
'
)
and
target_module
.
processed
:
if
hasattr
(
target_module
,
"
processed
"
)
and
target_module
.
processed
:
continue
continue
setattr
(
target_module
,
'
processed
'
,
True
)
setattr
(
target_module
,
"
processed
"
,
True
)
for
name
,
param
in
target_module
.
named_parameters
():
for
name
,
param
in
target_module
.
named_parameters
():
target_sharding_spec
=
node
.
best_strategy
.
get_sharding_spec_by_name
(
name
)
target_sharding_spec
=
node
.
best_strategy
.
get_sharding_spec_by_name
(
name
)
param
=
_shard_param
(
param
,
target_sharding_spec
)
param
=
_shard_param
(
param
,
target_sharding_spec
)
...
@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
...
@@ -453,7 +462,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
# apply the sharding spec of buffers
# apply the sharding spec of buffers
for
name
,
buffer
in
target_module
.
named_buffers
():
for
name
,
buffer
in
target_module
.
named_buffers
():
origin_sharding_spec
=
ShardingSpec
(
device_mesh
,
buffer
.
shape
,
{})
origin_sharding_spec
=
ShardingSpec
(
device_mesh
,
buffer
.
shape
,
{})
setattr
(
buffer
,
'
sharding_spec
'
,
origin_sharding_spec
)
setattr
(
buffer
,
"
sharding_spec
"
,
origin_sharding_spec
)
target_sharding_spec
=
node
.
best_strategy
.
get_sharding_spec_by_name
(
name
)
target_sharding_spec
=
node
.
best_strategy
.
get_sharding_spec_by_name
(
name
)
buffer_sharded
=
shape_consistency_manager
.
apply
(
buffer
,
target_sharding_spec
)
buffer_sharded
=
shape_consistency_manager
.
apply
(
buffer
,
target_sharding_spec
)
sharded_buffer_dict
[
name
]
=
buffer_sharded
sharded_buffer_dict
[
name
]
=
buffer_sharded
...
@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
...
@@ -461,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
for
name
,
buffer_sharded
in
sharded_buffer_dict
.
items
():
for
name
,
buffer_sharded
in
sharded_buffer_dict
.
items
():
setattr
(
target_module
,
name
,
buffer_sharded
.
detach
().
clone
())
setattr
(
target_module
,
name
,
buffer_sharded
.
detach
().
clone
())
if
node
.
op
==
'
get_attr
'
:
if
node
.
op
==
"
get_attr
"
:
root
=
node
.
graph
.
owning_module
root
=
node
.
graph
.
owning_module
atoms
=
node
.
target
.
split
(
"."
)
atoms
=
node
.
target
.
split
(
"."
)
attr_len
=
len
(
atoms
)
attr_len
=
len
(
atoms
)
...
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
...
@@ -488,16 +497,18 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
"""
"""
replace the origin kernel into kernel with implicit communication inside.
replace the origin kernel into kernel with implicit communication inside.
"""
"""
pass
def
runtime_preparation_pass
(
gm
:
torch
.
fx
.
GraphModule
,
def
runtime_preparation_pass
(
solution
:
List
[
int
],
gm
:
torch
.
fx
.
GraphModule
,
device_mesh
:
DeviceMesh
,
solution
:
List
[
int
],
strategies_constructor
:
StrategiesConstructor
,
device_mesh
:
DeviceMesh
,
overlap
=
False
):
strategies_constructor
:
StrategiesConstructor
,
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
=
solution_annotatation_pass
(
overlap
=
False
,
gm
,
solution
,
strategies_constructor
)
):
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
=
solution_annotation_pass
(
gm
,
solution
,
strategies_constructor
)
gm
=
size_value_converting_pass
(
gm
,
device_mesh
)
gm
=
size_value_converting_pass
(
gm
,
device_mesh
)
gm
=
node_args_converting_pass
(
gm
,
device_mesh
)
gm
=
node_args_converting_pass
(
gm
,
device_mesh
)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment