# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The compiler for TL programs."""

import os
import os.path as osp
from typing import Union, Optional, Callable
from tilelang import tvm as tvm
from tvm import tir, relay
from tvm.ir import CallingConv
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
from tilelang.utils.target import determine_target
from tilelang.engine.phase import (
    LowerAndLegalize,
    OptimizeForTarget,
)


def is_cpu_device_backend(target: Target):
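    """Check whether the target uses the C (CPU) device backend."""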
    return target.kind.name == "c"


def has_device_kernel_launch(attrs) -> bool:
    """Check if the attributes indicate a device kernel launch."""
    return bool(attrs and "calling_conv" in attrs and
                attrs["calling_conv"] == CallingConv.DEVICE_KERNEL_LAUNCH)


def is_device_call_c_device(func: tir.PrimFunc):
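    """Check if the function is a device call when the C (CPU) device backend is used.

    A function counts as a device call if it is bound to a C target or carries
    the DEVICE_KERNEL_LAUNCH calling convention.
    """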
    attrs = func.attrs

    # Check if it's a C target
    if "target" in attrs and attrs["target"].kind.name == "c":
        return True

    return has_device_kernel_launch(attrs)


def is_device_call(func: tir.PrimFunc):
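    """Check if the function is launched as a device kernel."""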
    return has_device_kernel_launch(func.attrs)


def get_device_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
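    """Return the device-call predicate appropriate for the backend."""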
    return is_device_call_c_device if is_device_c else is_device_call


def get_host_call(is_device_c: bool = False) -> Callable[[tir.PrimFunc], bool]:
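    """Return a predicate that matches host-side functions (i.e., not device calls)."""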
    return lambda func: not get_device_call(is_device_c)(func)


@tvm.register_func("tilelang_callback_cuda_compile", override=True)
def tilelang_callback_cuda_compile(code, target):
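    """Compile TileLang-generated CUDA source to a cubin binary with nvcc.

    Include paths for the TileLang templates and CUTLASS headers are taken from
    the TL_TEMPLATE_PATH and TL_CUTLASS_PATH environment variables when set,
    falling back to the in-repo defaults.
    """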
    project_root = osp.join(osp.dirname(__file__), "../..")
    if "TL_TEMPLATE_PATH" in os.environ:
        tl_template_path = os.environ["TL_TEMPLATE_PATH"]
    else:
        tl_template_path = osp.abspath(osp.join(project_root, "src"))
    # TODO(lei): this should be renamed to
    # TL_CUTLASS_INCLUDE_PATH in the future
    if "TL_CUTLASS_PATH" in os.environ:
        cutlass_path = os.environ["TL_CUTLASS_PATH"]
    else:
        cutlass_path = osp.abspath(osp.join(project_root, "3rdparty/cutlass/include"))
    compute_version = "".join(nvcc.get_target_compute_version(target).split("."))

    # Special handling for Hopper: sm_90 needs the sm_90a arch variant.
    if compute_version == "90":
        arch = ["-arch=sm_90a"]
    else:
        arch = [f"-arch=sm_{compute_version}"]
    format = "cubin"

    # Ask ptxas to report register usage and warn about local-memory usage.
    debug_option = "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
    ptx = nvcc.compile_cuda(
        code,
        format,
        arch,
        options=[
            "-std=c++17",
            debug_option,
            "--use_fast_math",
            "-I" + tl_template_path,
            "-I" + cutlass_path,
        ],
        verbose=False,
    )

    return ptx


@tvm.register_func("tilelang_callback_hip_compile", override=True)
def tilelang_callback_hip_compile(code, target):
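    """Compile TileLang-generated HIP source to an hsaco binary with hipcc.

    The Composable Kernel include path is taken from the
    TL_COMPOSABLE_KERNEL_PATH environment variable when set, falling back to
    the in-repo default.
    """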
    project_root = osp.join(osp.dirname(__file__), "../..")
    tl_template_path = osp.abspath(osp.join(project_root, "src"))

    # TODO(lei): this should be renamed to
    # TL_COMPOSABLE_KERNEL_INCLUDE_PATH in the future
    if "TL_COMPOSABLE_KERNEL_PATH" in os.environ:
        ck_path = os.environ["TL_COMPOSABLE_KERNEL_PATH"]
    else:
        ck_path = osp.abspath(osp.join(project_root, "3rdparty/composable_kernel/include"))

    hsaco = hipcc.compile_hip(
        code,
        target_format="hsaco",
        options=[
            "-std=c++17",
            "-I" + tl_template_path,
            "-I" + ck_path,
        ],
        verbose=False,
    )

    return hsaco


def extract_params(func: tir.PrimFunc):
    """Collect the Relay tensor types of the PrimFunc's buffer parameters."""
    buffers = [func.buffer_map[var] for var in func.params]
    tensor_types = [relay.TensorType(buffer.shape, buffer.dtype) for buffer in buffers]
    return tensor_types


def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]):
    """Canonicalize the target host, defaulting to LLVM when available (else stackvm)."""
    if not target_host:
        target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

    return target_host


def lower(
    func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
    target: Union[str, Target] = "auto",
    target_host: Optional[Union[str, Target]] = None,
    runtime_only=False,
):
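    """Lower a TileLang PrimFunc or IRModule into compiled host and device modules.

    Args:
        func_or_mod: The PrimFunc or IRModule to lower.
        target: The device target, or "auto" to detect one.
        target_host: The host target; defaults to LLVM when available.
        runtime_only: If True, skip parameter extraction and return only the host module.

    Returns:
        The compiled host module (with the device module imported), plus the
        extracted parameter types unless ``runtime_only`` is set.

    Example (illustrative sketch; assumes ``matmul`` is a TileLang PrimFunc
    defined elsewhere)::

        host_mod, params = lower(matmul, target="cuda")
    """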
    mod = func_or_mod
    if isinstance(func_or_mod, tir.PrimFunc):
        func = func_or_mod
        params = extract_params(func) if not runtime_only else None
        mod = tvm.IRModule({func.attrs["global_symbol"]: func})

    if isinstance(target, str):
        target = determine_target(target)

    target_host = canon_target_host(target, target_host)

    target_host = tvm.target.Target.canon_target(target_host)
    target = tvm.target.Target(target, target_host)

    _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
    _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))

    # Phase 1: Lower and legalize the IR
    mod = LowerAndLegalize(mod, target)

    # Phase 2: Optimize the IR for the target
    mod = OptimizeForTarget(mod, target)

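    # Split the module: host-side functions and device kernels are lowered separately.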
    host_mod = tir.transform.Filter(_is_host_call)(mod)
    host_mod = tir.transform.BindTarget(target_host)(host_mod)
    host_mod = tir.transform.FP8StorageLegalize()(host_mod)
    host_mod = tir.transform.BF16StorageLegalize()(host_mod)
    host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
    host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
    host_mod = tir.transform.LowerIntrin()(host_mod)
    host_mod = tir.transform.LowerDeviceStorageAccessInfo()(host_mod)
    host_mod = tir.transform.CombineContextCall()(host_mod)

    if target_host.kind.name == "llvm":
        host_mod = tvm._ffi.get_global_func("target.build.llvm")(host_mod, target_host)
176
    elif target_host.kind.name == "c":
177
178
179
180
        if is_cpu_device_backend(target):
            host_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(host_mod, target_host)
        else:
            host_mod = tvm._ffi.get_global_func("target.build.c")(host_mod, target_host)
181
    else:
182
        raise ValueError(f"Target host {target_host.kind.name} is not supported")

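    # Lower and build the device kernels for the selected target.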
    device_mod = tir.transform.Filter(_is_device_call)(mod)
    device_mod = tir.transform.LowerDeviceStorageAccessInfo()(device_mod)
    device_mod = tir.transform.LowerIntrin()(device_mod)
    device_mod = tir.transform.Simplify()(device_mod)

    if target.kind.name == "cuda":
        # Debug: uncomment the line below to dump the generated CUDA source.
        # code = tvm._ffi.get_global_func("target.build.tl_debug_codegen")(device_mod, target)
        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
    elif target.kind.name == "hip":
        device_mod = tvm._ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
    elif target.kind.name == "c":
        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
    elif target.kind.name == "llvm":
        device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
    elif target.kind.name == "webgpu":
        device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
    else:
        raise ValueError(f"Target {target.kind.name} is not supported")

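    # Attach the compiled device module to the host module so a single module is returned.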
    host_mod.import_module(device_mod)

    if target_host.kind.name == "c":
        # A C-source host module is not directly executable; export it to a
        # shared library and reload it so it can be run.
        # TODO(lei): this is a hack to make the C host backend work
        temp_dir = tvm.contrib.utils.tempdir()
        tmp_lib_path = temp_dir.relpath("tmp.so")
        host_mod.export_library(tmp_lib_path)
        host_mod = tvm.runtime.load_module(tmp_lib_path)

    if runtime_only:
        return host_mod

    return host_mod, params