Commit 3abbaf8b authored by oahzxl's avatar oahzxl
Browse files

update codegen test

parent 74b81395
from functools import partial
import pytest import pytest
import torch import torch
import torch.fx import torch.fx
...@@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair): ...@@ -46,7 +48,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
) )
def _test_autochunk_codegen(rank): def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch( colossalai.launch(
config={}, config={},
...@@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank): ...@@ -59,8 +61,6 @@ def _test_autochunk_codegen(rank):
# build model and input # build model and input
model = evoformer_base().cuda() model = evoformer_base().cuda()
msa_len = 32
pair_len = 64
node = torch.randn(1, msa_len, pair_len, 256).cuda() node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda() pair = torch.randn(1, pair_len, pair_len, 128).cuda()
...@@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank): ...@@ -85,7 +85,7 @@ def _test_autochunk_codegen(rank):
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
) )
codegen = AutoChunkCodeGen(gm_prop) codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
graph.set_codegen(codegen) graph.set_codegen(codegen)
gm = ColoGraphModule(model, graph) gm = ColoGraphModule(model, graph)
gm.recompile() gm.recompile()
...@@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank): ...@@ -99,9 +99,18 @@ def _test_autochunk_codegen(rank):
gpc.destroy() gpc.destroy()
def test_autochunk_codegen(): @pytest.mark.parametrize("max_memory", [None, 20, 24, 28, 32])
mp.spawn(_test_autochunk_codegen, nprocs=1) @pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_autochunk_codegen(msa_len, pair_len, max_memory):
run_func = partial(
_test_autochunk_codegen,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__": if __name__ == "__main__":
_test_autochunk_codegen(0) _test_autochunk_codegen(0, 32, 64, None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment