from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.fx

import colossalai
from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

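# AutoChunk's tracer and codegen are optional (they need a sufficiently recent
# torch.fx), so import them only when the availability flag says they exist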
if AUTOCHUNK_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def assert_codegen_run(
    model: Any,
    meta_args: List[Tuple[str, Any]],
    concrete_args: Optional[List[Tuple[str, Any]]] = None,
    max_memory: Optional[int] = None,
    print_mem: bool = False,
    print_progress: bool = False,
    print_code: bool = False,
) -> List[Dict]:
    if concrete_args is None:
        concrete_args = []
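    # `model` is passed in as a class / factory; instantiate it here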
    model = model()

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args=dict(concrete_args),
    )
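    # propagate MetaTensors through the graph so each node carries the shape
    # and memory information that AutoChunk's region search relies on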
    interp = MetaInfoProp(meta_graph)
    meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
    interp.propagate(*meta_tensors)
    codegen = AutoChunkCodeGen(
        meta_graph,
        max_memory=max_memory,
        print_mem=print_mem,
        print_progress=print_progress,
    )
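    # one entry per region that the search decided to chunk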
    chunks = codegen.chunk_infos

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model.cuda(),
        meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
        concrete_args=dict(concrete_args),
    )
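    # attach the AutoChunk codegen so recompile() emits the chunked forward code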
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph, ckpt_codegen=False)
    gm.recompile()

    # the generated source must contain AutoChunk's bookkeeping variables,
    # proving that the chunked codegen was actually applied
    code = graph.python_code("self").src
    if print_code:
        print(code)
    assert "chunk_result = None;  chunk_size = None;" in code

    # run the chunked module and the original model and check their outputs match
    inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
    model.cuda().eval()
    gm.eval()
    with torch.no_grad():
        out_gm = gm(*inputs)
        out_model = model(*inputs)
    mean_diff = torch.mean(torch.abs(out_gm["sample"] - out_model["sample"]))
    assert torch.allclose(
        out_gm["sample"], out_model["sample"], atol=1e-4
    ), f"fx output does not match the original output, mean abs diff is {mean_diff:.2e}"

    return chunks


def run_test(
    rank: int,
    model: Any,
    data: Tuple[List, List],
    max_memory: Optional[int],
    print_code: bool,
    print_mem: bool,
    print_progress: bool,
    get_chunk_target: Any = None,
) -> None:
    # launch a single-process colossalai context on a free port
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # unpack the inputs and run the codegen and numerical checks
    meta_args, concrete_args = data
    chunks = assert_codegen_run(
        model,
        meta_args=meta_args,
        concrete_args=concrete_args,
        max_memory=max_memory,
        print_code=print_code,
        print_mem=print_mem,
        print_progress=print_progress,
    )

    if get_chunk_target is not None:
        chunk_found = [i["region"] for i in chunks]
        chunk_target = get_chunk_target()[max_memory]
        assert chunk_found == chunk_target, (
            f"found regions {chunk_found} do not equal target regions {chunk_target}"
        )

    gpc.destroy()
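

# A minimal usage sketch, not part of the original utilities: it shows how a
# pytest entry point would typically drive run_test through
# torch.multiprocessing. `build_model` and `get_data` are hypothetical
# placeholders for a concrete model factory and a callable returning the
# (meta_args, concrete_args) pair of (name, tensor) lists; all names and
# defaults below are illustrative assumptions.
def _example_entry(build_model: Any, get_data: Any, max_memory: Optional[int] = None) -> None:
    from functools import partial

    import torch.multiprocessing as mp

    run_func = partial(
        run_test,
        model=build_model,
        data=get_data(),
        max_memory=max_memory,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    # spawn a single worker; mp.spawn supplies the rank as the first argument
    mp.spawn(run_func, nprocs=1)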