Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
61fdd346
Commit
61fdd346
authored
Jan 10, 2023
by
oahzxl
Browse files
update doc
parent
36ab2cb7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
35 deletions
+21
-35
tests/test_autochunk/test_autochunk_codegen.py
tests/test_autochunk/test_autochunk_codegen.py
+10
-18
tests/test_autochunk/test_autochunk_search.py
tests/test_autochunk/test_autochunk_search.py
+11
-17
No files found.
tests/test_autochunk/test_autochunk_codegen.py
View file @
61fdd346
...
...
@@ -40,20 +40,16 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
non_fx_out
=
model
(
node
,
pair
)
fx_out
=
gm
(
node
,
pair
)
assert
torch
.
allclose
(
non_fx_out
[
0
],
fx_out
[
0
],
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
torch
.
abs
(
non_fx_out
[
0
]
-
fx_out
[
0
])
)
assert
torch
.
allclose
(
non_fx_out
[
1
],
fx_out
[
1
],
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
torch
.
abs
(
non_fx_out
[
1
]
-
fx_out
[
1
])
)
assert
torch
.
allclose
(
non_fx_out
[
0
],
fx_out
[
0
],
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
torch
.
abs
(
non_fx_out
[
0
]
-
fx_out
[
0
]))
assert
torch
.
allclose
(
non_fx_out
[
1
],
fx_out
[
1
],
atol
=
1e-4
),
"fx_out doesn't comply with original output, diff is %.2e"
%
torch
.
mean
(
torch
.
abs
(
non_fx_out
[
1
]
-
fx_out
[
1
]))
def
_test_autochunk_codegen
(
rank
,
msa_len
,
pair_len
,
max_memory
):
# launch colossalai
to make sure we could execute colossalai.utils.checkpoint currectly
# launch colossalai
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
...
...
@@ -76,18 +72,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
"pair"
:
pair
.
to
(
torch
.
device
(
"meta"
)),
},
)
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
interp
=
MetaInfoProp
(
gm_prop
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
# now run it twice to get meta info in graph module, not necessary
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
)
interp
=
MetaInfoProp
(
gm
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
codegen
=
AutoChunkCodeGen
(
gm_prop
,
max_memory
=
max_memory
)
graph
.
set_codegen
(
codegen
)
...
...
tests/test_autochunk/test_autochunk_search.py
View file @
61fdd346
...
...
@@ -23,7 +23,8 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
if
msa_len
==
32
and
pair_len
==
64
:
if
max_memory
is
None
:
target_regions
=
[(
142
,
154
),
(
366
,
373
),
(
233
,
283
),
(
301
,
351
),
(
127
,
134
),
(
204
,
228
),
(
167
,
191
),
(
161
,
166
),
(
198
,
203
),
(
6
,
69
)]
target_regions
=
[(
142
,
154
),
(
366
,
373
),
(
233
,
283
),
(
301
,
351
),
(
127
,
134
),
(
204
,
228
),
(
167
,
191
),
(
161
,
166
),
(
198
,
203
),
(
6
,
69
)]
elif
max_memory
==
20
:
target_regions
=
[(
142
,
154
),
(
369
,
373
),
(
233
,
269
),
(
301
,
351
)]
elif
max_memory
==
25
:
...
...
@@ -36,24 +37,19 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
raise
NotImplementedError
()
assert
len
(
found_regions
)
==
len
(
target_regions
),
"len of found regions %s doesn't equal len of target regions %s"
%
(
str
(
found_regions
),
str
(
target_regions
),
)
target_regions
),
"len of found regions %s doesn't equal len of target regions %s"
%
(
str
(
found_regions
),
str
(
target_regions
),
)
for
region
in
target_regions
:
assert
(
region
in
found_regions
),
"region:%s not in found regions for msa:%d, pair:%d, maxmem:%d"
%
(
assert
(
region
in
found_regions
),
"region:%s not in found regions for msa:%d, pair:%d, maxmem:%d"
%
(
str
(
region
),
msa_len
,
pair_len
,
max_memory
,
)
for
region
in
found_regions
:
assert
(
region
in
target_regions
),
"region:%s should not be found for msa:%d, pair:%d, maxmem:%d"
%
(
assert
(
region
in
target_regions
),
"region:%s should not be found for msa:%d, pair:%d, maxmem:%d"
%
(
str
(
region
),
msa_len
,
pair_len
,
...
...
@@ -62,7 +58,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
def
_test_autochunk_search
(
rank
,
msa_len
,
pair_len
,
max_memory
):
# launch colossalai
to make sure we could execute colossalai.utils.checkpoint currectly
# launch colossalai
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
...
...
@@ -77,11 +73,9 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
interp
=
MetaInfoProp
(
gm_prop
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
codegen
=
AutoChunkCodeGen
(
gm_prop
,
max_memory
=
max_memory
)
chunk_infos
=
codegen
.
chunk_infos
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment