Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c8c79102
Unverified
Commit
c8c79102
authored
Jan 02, 2023
by
Boyuan Yao
Committed by
GitHub
Jan 02, 2023
Browse files
[autoparallel] patch torch.flatten metainfo for autoparallel (#2247)
* [autoparallel] patch torch.flatten
parent
8897b8f7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
4 deletions
+6
-4
colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
...i/auto_parallel/meta_profiler/meta_registry/activation.py
+2
-2
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
...alai/auto_parallel/meta_profiler/meta_registry/pooling.py
+4
-2
No files found.
colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
View file @
c8c79102
...
@@ -30,7 +30,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
...
@@ -30,7 +30,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
input_tensor
=
args
[
0
].
data
input_tensor
=
args
[
0
].
data
output_tensor
=
next
(
filter
(
lambda
x
:
x
.
type
==
OperationDataType
.
OUTPUT
,
args
)).
data
output_tensor
=
next
(
filter
(
lambda
x
:
x
.
type
==
OperationDataType
.
OUTPUT
,
args
)).
data
inplace
=
kwargs
.
get
(
"inplace"
,
False
)
is_
inplace
=
kwargs
.
get
(
"inplace"
,
False
)
# construct input args for forward
# construct input args for forward
fwd_in_args
=
[
input_tensor
]
fwd_in_args
=
[
input_tensor
]
...
@@ -51,7 +51,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
...
@@ -51,7 +51,7 @@ def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, Lis
# NOTE: the inplace ReLU doesn't have forward memory cost
# NOTE: the inplace ReLU doesn't have forward memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost
=
MemoryCost
(
fwd_memory_cost
=
MemoryCost
(
activation
=
activation_size
(
input_tensor
)
if
inplace
else
activation_size
([
output_tensor
,
input_tensor
]),
activation
=
activation_size
(
input_tensor
)
if
is_
inplace
else
activation_size
([
output_tensor
,
input_tensor
]),
parameter
=
0
,
parameter
=
0
,
temp
=
0
,
temp
=
0
,
buffer
=
0
)
buffer
=
0
)
...
...
colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
View file @
c8c79102
...
@@ -14,6 +14,7 @@ __all__ = ["avgpool_meta_info", "maxpool_meta_info"]
...
@@ -14,6 +14,7 @@ __all__ = ["avgpool_meta_info", "maxpool_meta_info"]
@
meta_register
.
register
(
torch
.
nn
.
AdaptiveAvgPool1d
)
@
meta_register
.
register
(
torch
.
nn
.
AdaptiveAvgPool1d
)
@
meta_register
.
register
(
torch
.
nn
.
AdaptiveAvgPool2d
)
@
meta_register
.
register
(
torch
.
nn
.
AdaptiveAvgPool2d
)
@
meta_register
.
register
(
torch
.
nn
.
AdaptiveAvgPool3d
)
@
meta_register
.
register
(
torch
.
nn
.
AdaptiveAvgPool3d
)
@
meta_register
.
register
(
torch
.
flatten
)
def
avgpool_meta_info
(
*
args
,
**
kwargs
)
->
Tuple
[
TrainCycleItem
,
TrainCycleItem
,
List
[
torch
.
Tensor
]]:
def
avgpool_meta_info
(
*
args
,
**
kwargs
)
->
Tuple
[
TrainCycleItem
,
TrainCycleItem
,
List
[
torch
.
Tensor
]]:
"""Meta info for AdaptiveAvgPool
"""Meta info for AdaptiveAvgPool
The aten graph of AdaptiveAvgPool is
The aten graph of AdaptiveAvgPool is
...
@@ -32,6 +33,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
...
@@ -32,6 +33,7 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
input_tensor
=
args
[
0
].
data
input_tensor
=
args
[
0
].
data
output_tensor
=
next
(
filter
(
lambda
x
:
x
.
type
==
OperationDataType
.
OUTPUT
,
args
)).
data
output_tensor
=
next
(
filter
(
lambda
x
:
x
.
type
==
OperationDataType
.
OUTPUT
,
args
)).
data
is_inplace
=
kwargs
.
get
(
"inplace"
,
False
)
# construct forward args for flop mapping
# construct forward args for flop mapping
fwd_in_args
=
[
input_tensor
]
fwd_in_args
=
[
input_tensor
]
...
@@ -51,8 +53,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
...
@@ -51,8 +53,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem,
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
)
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
)
# calculate memory cost
# calculate memory cost
fwd_mem_cost
=
MemoryCost
(
activation
=
activation_size
(
output_tensor
))
fwd_mem_cost
=
MemoryCost
()
if
is_inplace
else
MemoryCost
(
activation
=
activation_size
(
output_tensor
))
bwd_mem_cost
=
MemoryCost
(
activation
=
activation_size
(
input_tensor
))
bwd_mem_cost
=
MemoryCost
()
if
is_inplace
else
MemoryCost
(
activation
=
activation_size
(
input_tensor
))
# total cost
# total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
)
total_mem_cost
=
MemoryCost
(
activation
=
fwd_mem_cost
.
activation
+
bwd_mem_cost
.
activation
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment