test_autochunk_codegen.py 3.02 KB
Newer Older
oahzxl's avatar
init  
oahzxl committed
1
import pytest
oahzxl's avatar
oahzxl committed
2
import torch
oahzxl's avatar
oahzxl committed
3
import torch.fx
oahzxl's avatar
init  
oahzxl committed
4
import torch.multiprocessing as mp
oahzxl's avatar
oahzxl committed
5

oahzxl's avatar
init  
oahzxl committed
6
import colossalai
oahzxl's avatar
oahzxl committed
7
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
oahzxl's avatar
init  
oahzxl committed
8
from colossalai.core import global_context as gpc
oahzxl's avatar
oahzxl committed
9
from colossalai.fx import ColoTracer
oahzxl's avatar
init  
oahzxl committed
10
from colossalai.fx.graph_module import ColoGraphModule
oahzxl's avatar
oahzxl committed
11
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
oahzxl's avatar
oahzxl committed
12
from colossalai.fx.profiler import MetaTensor
oahzxl's avatar
oahzxl committed
13
from colossalai.utils import free_port
oahzxl's avatar
oahzxl committed
14
from tests.test_autochunk.evoformer.evoformer import evoformer_base
oahzxl's avatar
oahzxl committed
15
16
17


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
    """Report autochunk memory usage and check ``gm`` against the eager ``model``.

    First runs ``gm`` once under ``torch.no_grad()`` and prints how much CUDA
    memory that forward pass left allocated ("now mem") and its peak usage
    ("max mem"), both relative to the memory allocated just before the call.
    Then runs ``model`` and ``gm`` on the same inputs and asserts their first
    two outputs match.

    Args:
        model: the original module, used as the reference implementation.
        gm: the graph module whose code was generated with autochunking.
        node: first forward input (presumably a CUDA tensor -- confirm with caller).
        pair: second forward input.

    Raises:
        AssertionError: if ``gm``'s output 0 or 1 is not ``torch.allclose``
            (``atol=1e-4``) to the corresponding output of ``model``.
    """
    # Copy the inputs *before* taking the baseline so the measured peak
    # reflects the forward pass, not the allocation of the input copies.
    node_in, pair_in = node.clone(), pair.clone()

    torch.cuda.reset_peak_memory_stats()
    now_mem = torch.cuda.memory_allocated() / 1024**2
    with torch.no_grad():
        gm(node_in, pair_in)
    new_now_mem = torch.cuda.memory_allocated() / 1024**2
    new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # Release the copies so they do not occupy memory during the checks below.
    del node_in, pair_in

    print(
        "autochunk now mem:%.2f max mem:%.2f"
        % (new_now_mem - now_mem, new_max_mem - now_mem)
    )

    # test forward
    with torch.no_grad():
        non_fx_out = model(node, pair)
        fx_out = gm(node, pair)

    # Only the first two outputs are compared; the message is built lazily,
    # i.e. only when the comparison fails.
    for idx in range(2):
        expected, actual = non_fx_out[idx], fx_out[idx]
        assert torch.allclose(expected, actual, atol=1e-4), (
            "fx_out[%d] doesn't comply with original output, mean abs diff is %.2e"
            % (idx, torch.mean(torch.abs(expected - actual)).item())
        )


def _run_offload_codegen(rank):
    """Build an Evoformer, apply ``AutoChunkCodeGen`` and compare it with eager mode.

    Used as the ``mp.spawn`` worker of ``test_autochunk`` and called directly
    with ``rank=0`` when the file is run as a script.

    Args:
        rank: rank of this process; the launch uses ``world_size=1``.
    """
    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    try:
        # build model and input
        model = evoformer_base().cuda()
        node = torch.randn(1, 100, 300, 256).cuda()
        pair = torch.randn(1, 300, 300, 128).cuda()

        # Trace with meta-device copies of the inputs (no tensor data needed).
        graph = ColoTracer().trace(
            model,
            meta_args={
                "node": node.to(torch.device("meta")),
                "pair": pair.to(torch.device("meta")),
            },
        )
        # AutoChunkCodeGen is built from this symbolic_trace'd module after
        # MetaInfoProp has propagated through it.
        gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
        interp = MetaInfoProp(gm_prop)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
        )

        # Also propagate meta info through `graph` itself. Per the original note
        # this second pass is not necessary (codegen only receives gm_prop).
        gm = torch.fx.GraphModule(model, graph)
        interp = MetaInfoProp(gm)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
        )

        # Swap in the autochunk code generator and regenerate the module code.
        codegen = AutoChunkCodeGen(gm_prop)
        graph.set_codegen(codegen)
        gm = ColoGraphModule(model, graph)
        gm.recompile()

        # To inspect the generated (chunked) code:
        # code = graph.python_code("self").src
        # print(code)

        _test_fwd(model, gm, node, pair)
    finally:
        # Tear down the global context set up by colossalai.launch even when
        # tracing or the forward check raises.
        gpc.destroy()


def test_autochunk():
    """Run the autochunk codegen check in a single spawned process.

    Skipped when CUDA is unavailable: the worker moves the model and inputs
    to the GPU and launches ColossalAI with the ``nccl`` backend.
    """
    if not torch.cuda.is_available():
        pytest.skip("test_autochunk requires CUDA (GPU model and nccl backend)")
    mp.spawn(_run_offload_codegen, nprocs=1)


if __name__ == "__main__":
    # Running the file directly executes the worker in this process as rank 0,
    # without going through mp.spawn (handy for debugging).
    _run_offload_codegen(rank=0)