Commit d106b271 authored by oahzxl's avatar oahzxl
Browse files

add chunk search test

parent a005965d
from functools import partial
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import colossalai
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.utils import free_port
from tests.test_autochunk.evoformer.evoformer import evoformer_base
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
found_regions = [i["region"] for i in chunk_infos]
if msa_len == 32 and pair_len == 64:
if max_memory is None:
target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), (161, 166), (198, 203), (6, 69)]
elif max_memory == 20:
target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)]
elif max_memory == 25:
target_regions = [(144, 154), (369, 370)]
elif max_memory == 30:
target_regions = [(144, 154)]
else:
raise NotImplementedError()
else:
raise NotImplementedError()
assert len(found_regions) == len(target_regions), "len of found regions %s doesn't equal len of target regions %s" % (str(found_regions), str(target_regions))
for region in target_regions:
assert region in found_regions, "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
for region in found_regions:
assert region in target_regions, "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (str(region), msa_len, pair_len, max_memory)
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(
config={},
rank=rank,
world_size=1,
host="localhost",
port=free_port(),
backend="nccl",
)
# build model and input
model = evoformer_base().cuda()
node = torch.randn(1, msa_len, pair_len, 256).cuda()
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
interp = MetaInfoProp(gm_prop)
interp.propagate(
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
)
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
chunk_infos = codegen.chunk_infos
assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
gpc.destroy()
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_autochunk_search(msa_len, pair_len, max_memory):
run_func = partial(
_test_autochunk_search,
msa_len=msa_len,
pair_len=pair_len,
max_memory=max_memory,
)
mp.spawn(run_func, nprocs=1)
if __name__ == "__main__":
_test_autochunk_search(0, 32, 64, 20)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment