# test_alphafold_utils.py

from typing import Any, Dict, List, Optional

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.autochunk.utils import flat_list
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

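# AutoChunkCodeGen and the experimental tracer are only importable when the
# installed torch build supports the required fx codegen features, hence the
# AUTOCHUNK_AVAILABLE guard.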
if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def assert_codegen_run(
    model: Any,
    meta_args: List,
    concrete_args: Optional[List] = None,
    max_memory: Optional[int] = None,
    print_mem: bool = False,
    print_est_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
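    """Trace ``model`` with AutoChunk codegen, recompile it, and check the
    chunked forward pass against the eager model on CUDA.

    ``meta_args`` and ``concrete_args`` are lists of ``(name, value)`` pairs;
    tensor values in ``meta_args`` are moved to the meta device for tracing.
    Returns the chunk regions found by ``AutoChunkCodeGen``.
    """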
    if concrete_args is None:
        concrete_args = []

    # trace the meta graph and set up codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_est_mem,
        print_progress=print_progress,
    )
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args={k: v for k, v in concrete_args},
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # assert chunk in code
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_size = None;  " in code

    # assert result
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
    model.cuda()
    with torch.no_grad():
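        # optionally measure the real peak CUDA memory used by the chunked
        # forward, reported as a delta over the pre-run allocation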
        if print_mem:
            torch.cuda.reset_peak_memory_stats()
            now_mem = torch.cuda.memory_allocated() / 1024**2
        out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
        if print_mem:
            new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
            print("mem: %.2fMB" % (new_max_mem - now_mem))
        out_model = model(*inputs)
    out_gm = flat_list(out_gm)
    out_model = flat_list(out_model)
    for out_gm_i, out_model_i in zip(out_gm, out_model):
        assert torch.allclose(
            out_gm_i, out_model_i, atol=1e-4
        ), "fx output doesn't match the original output, mean abs diff is %.2e" % torch.mean(
            torch.abs(out_gm_i - out_model_i))

    return chunks


def run_test(
    rank: int,
    data_args: tuple,
    max_memory: Optional[int],
    get_model: Any,
    get_data: Any,
    print_code: bool = False,
    print_mem: bool = False,
    print_est_mem: bool = False,
    print_progress: bool = False,
    get_chunk_target: Any = None,
) -> None:
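    """Launch a single-process colossalai context on ``rank`` and run
    ``assert_codegen_run`` on the model and data built by ``get_model`` and
    ``get_data``. When ``get_chunk_target`` is given, the discovered chunk
    regions must match ``get_chunk_target()[max_memory]``.
    """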
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = get_model()
    meta_args, concrete_args = get_data(*data_args)
    chunks = assert_codegen_run(
        model,
        meta_args=meta_args,
        concrete_args=concrete_args,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_est_mem=print_est_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
            str(chunk_found),
            str(chunk_target),
        )
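

# A minimal usage sketch, not part of the original test suite: a test can
# drive ``run_test`` by spawning one worker with torch.multiprocessing and
# passing in model/data builders. ``_example_get_model`` and
# ``_example_get_data`` below are hypothetical placeholders (real tests
# supply AlphaFold modules and inputs); a CUDA device and the NCCL backend
# are assumed, as required by ``run_test`` itself.
def _example_get_model() -> torch.nn.Module:
    # toy stand-in for an AlphaFold block
    return torch.nn.Linear(64, 64)


def _example_get_data(batch: int):
    # meta_args and concrete_args are lists of (name, value) pairs,
    # matching what assert_codegen_run expects
    meta_args = [("input", torch.rand(batch, 64))]
    concrete_args = []
    return meta_args, concrete_args


if __name__ == "__main__":
    from functools import partial

    import torch.multiprocessing as mp

    # mp.spawn calls the partial with the worker rank as its first argument
    mp.spawn(
        partial(
            run_test,
            data_args=(2,),
            max_memory=None,
            get_model=_example_get_model,
            get_data=_example_get_data,
        ),
        nprocs=1,
    )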