OpenDAS / ColossalAI · Commits

Commit 7ab2db20
adapt new fx
Authored Jan 10, 2023 by oahzxl
Parent: e532679c
Showing 4 changed files with 12 additions and 14 deletions.
colossalai/autochunk/autochunk_codegen.py        +3 -3
colossalai/autochunk/estimate_memory.py          +1 -6
tests/test_autochunk/test_autochunk_codegen.py   +4 -2
tests/test_autochunk/test_autochunk_search.py    +4 -3
colossalai/autochunk/autochunk_codegen.py (view file @ 7ab2db20)
...
@@ -585,9 +585,9 @@ if CODEGEN_AVAILABLE:
             code = "".join(body)
             code = "\n".join("    " + line for line in code.split("\n"))
             fn_code = f"""
 {wrap_stmts}
 {prologue}
 {code}"""
             # print(fn_code)
             return PythonCode(fn_code, globals_)
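For context, the hunk above is the tail of the AutoChunk code generator: the emitted statement list `body` is joined into one string, every line is indented so it nests inside the generated function, and the result is spliced together with `wrap_stmts` and `prologue` before being returned as a `PythonCode` object. A minimal, self-contained sketch of that string-assembly step (the `body`, `wrap_stmts`, and `prologue` values below are hypothetical stand-ins for what the fx emitter actually produces):

# Hypothetical inputs standing in for what the fx emitter produces.
body = ["x = input + 1\n", "return x\n"]       # emitted statements (assumed)
wrap_stmts = ""                                # wrapper statements (assumed)
prologue = "def forward(self, input):"         # generated signature (assumed)

code = "".join(body)
# Indent every emitted line by four spaces so it nests under the prologue.
code = "\n".join("    " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
print(fn_code)    # the source string that PythonCode would wrap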
colossalai/autochunk/estimate_memory.py (view file @ 7ab2db20)
...
@@ -28,12 +28,7 @@ class EstimateMemory(object):
         return x
 
     def _get_output_node(self, n):
-        fwd_out = {
-            x.uuid: x
-            for x in n.meta["fwd_out"]
-            if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
-        }
-        out_size = activation_size(fwd_out)
+        out_size = activation_size(n.meta["fwd_out"])
         out_node = [n.name] if out_size > 0 else []
         return out_size, out_node
 
...
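This simplification leans on the new fx metadata: previously `_get_output_node` had to rebuild a uuid-keyed dict of the node's output tensors before measuring them, whereas the new profiler lets `activation_size` consume `n.meta["fwd_out"]` directly. A rough, runnable sketch of the resulting logic; `activation_size` and the node object here are assumed stand-ins for illustration, not ColossalAI's actual implementations:

import torch

def activation_size(out) -> int:
    """Stand-in (assumed) for the fx profiler's activation_size():
    total bytes occupied by all tensors nested in `out`."""
    if isinstance(out, torch.Tensor):
        return out.numel() * out.element_size()
    if isinstance(out, dict):
        return sum(activation_size(v) for v in out.values())
    if isinstance(out, (list, tuple)):
        return sum(activation_size(v) for v in out)
    return 0

class FakeNode:
    """Stand-in (assumed) for an fx Node carrying profiler metadata."""
    def __init__(self, name, fwd_out):
        self.name = name
        self.meta = {"fwd_out": fwd_out}

def _get_output_node(n):
    # Post-diff behavior: measure n.meta["fwd_out"] directly, with no
    # uuid-based filtering of the output tensors.
    out_size = activation_size(n.meta["fwd_out"])
    out_node = [n.name] if out_size > 0 else []
    return out_size, out_node

node = FakeNode("linear_1", [torch.empty(4, 8)])    # 32 fp32 values = 128 bytes
print(_get_output_node(node))                       # -> (128, ['linear_1'])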
tests/test_autochunk/test_autochunk_codegen.py (view file @ 7ab2db20)
...
@@ -8,6 +8,7 @@ import torch.multiprocessing as mp
 import colossalai
 from colossalai.core import global_context as gpc
 from colossalai.fx import ColoTracer
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
...
@@ -15,8 +16,9 @@ from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base
 
-if CODEGEN_AVAILABLE:
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
 
 
 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
...
@@ -102,7 +104,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
     gpc.destroy()
 
 
-@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0')
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
...
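Both edits in this test file apply the same rule: the import guard and the `skipif` marker must share one condition. If the imports were gated on `CODEGEN_AVAILABLE and is_compatible_with_meta()` but the tests only skipped on `not CODEGEN_AVAILABLE`, a torch build passing the first check and failing the second would still collect the tests and then crash with `NameError` on the never-imported `AutoChunkCodeGen`. A hedged sketch of the pattern, with illustrative stand-ins in place of the real checks:

import pytest

# Illustrative stand-ins for CODEGEN_AVAILABLE and is_compatible_with_meta().
CODEGEN_AVAILABLE = True

def is_compatible_with_meta() -> bool:
    return False    # pretend this torch build lacks meta-tensor support

if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

# The marker repeats the exact import-guard condition, so the test is
# skipped (not errored) whenever the guarded names were never imported.
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()),
                    reason='torch version is lower than 1.12.0')
def test_autochunk_codegen_sketch():
    assert AutoChunkCodeGen is not None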
tests/test_autochunk/test_autochunk_search.py (view file @ 7ab2db20)
...
@@ -7,14 +7,15 @@ import torch.multiprocessing as mp
 import colossalai
 from colossalai.core import global_context as gpc
+from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import MetaTensor
 from colossalai.utils import free_port
 from tests.test_autochunk.evoformer.evoformer import evoformer_base
 
 
-if CODEGEN_AVAILABLE:
+if CODEGEN_AVAILABLE and is_compatible_with_meta():
     from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+    from colossalai.fx.profiler import MetaTensor
 
 
 def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
...
@@ -89,7 +90,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
     gpc.destroy()
 
 
-@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason="torch version is lower than 1.12.0")
+@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
 @pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
...
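This file additionally moves `from colossalai.fx.profiler import MetaTensor` from module scope to inside the guard, so merely importing the test module cannot fail on a torch build without meta-tensor support. As an aside, a compatibility gate like `is_compatible_with_meta()` is typically just a version check; the sketch below is an assumed implementation for illustration only, since the real one lives in `colossalai.fx._compatibility` and may differ:

import torch
from packaging import version

def is_compatible_with_meta() -> bool:
    # Assumed gate: fx meta-tensor tracing is taken to require torch >= 1.12.0,
    # matching the skip reason used in these tests.
    bare_version = torch.__version__.split("+")[0]    # drop local tags like "+cu117"
    return version.parse(bare_version) >= version.parse("1.12.0")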