rt_mod_hip.cc 3.82 KB
Newer Older
1
2
#if defined(__linux__)
#include <sys/stat.h>
#endif

#include <string>
#include <unordered_map>
#include <utility>

#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>

#include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"

// Fallback type code used to tag "grid_constant" kernel arguments when no
// included header already defines kTVMGridConstant.
#ifndef kTVMGridConstant
#define kTVMGridConstant 130
#endif

namespace tvm {
namespace codegen {

20
21
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
22
23
24
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;

  for (auto kv : mod->functions) {
25
26
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
27
28
29
30
31
32
33
34
35
36
37
    auto f = Downcast<tir::PrimFunc>(kv.second);

    runtime::FunctionInfo info;
    for (size_t i = 0; i < f->params.size(); ++i) {
      if (f->params[i]->dtype.is_handle()) {
        auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
        if (ptr && ptr->storage_scope == "grid_constant") {
          info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1));
          continue;
        }
      }
38
39
40
41
42
      DataType dtype = f->params[i].dtype();
      // Device runtime cannot directly take bool arguments, map to int32.
      if (dtype.is_bool())
        dtype = DataType::Int(32);
      info.arg_types.push_back(dtype);
43
    }
44
45
    if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
            tir::attr::kKernelLaunchParams)) {
46
      for (const auto &tag : opt.value()) {
47
48
49
        info.launch_param_tags.push_back(tag);
      }
    }
50
    auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
51
52
53
54
55
    fmap[static_cast<std::string>(global_symbol.value())] = info;
  }
  return fmap;
}

56
ffi::Module BuildTileLangHIP(IRModule mod, Target target) {
57
58
59
60
61
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
62
63
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangHIP: Can only take PrimFunc";
64
65
66
67
68
69
70
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
    cg.AddFunction(f);
  }

  std::string code = cg.Finish();
71
72
73
74
75

  // Use the new FFI API to get registered functions
  using ffi::Function;
  if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) {
    code = (*f)(code, target).cast<std::string>();
76
  }
77

78
79
  std::string fmt = "ptx";
  std::string ptx;
80
81
82

  if (auto f = Function::GetGlobal("tilelang_callback_hip_compile")) {
    ptx = (*f)(code, target).cast<std::string>();
83
84
    if (ptx[0] != '/')
      fmt = "hsaco";
85
  } else {
86
    ICHECK(false) << "tilelang_callback_hip_compile is not set";
87
  }
88

89
90
91
  return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}

92
ffi::Module BuildTileLangHIPWithoutCompile(IRModule mod, Target target) {
93
94
95
96
97
98
99
100
101
102
103
104
105
106
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangHIP: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
    cg.AddFunction(f);
  }

  std::string code = cg.Finish();
107
108
109
110
111

  // Use the new FFI API to get registered functions
  using ffi::Function;
  if (auto f = Function::GetGlobal("tilelang_callback_hip_postproc")) {
    code = (*f)(code, target).cast<std::string>();
112
  }
113

114
115
  return ROCMModuleCreate("ptx", "fmt", ExtractFuncInfo(mod), code,
                          std::string());
116
}
117

118
TVM_FFI_STATIC_INIT_BLOCK() {
119
120
121
122
123
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("target.build.tilelang_hip", BuildTileLangHIP)
      .def("target.build.tilelang_hip_without_compile",
           BuildTileLangHIPWithoutCompile);
124
}
125

126
} // namespace codegen
127
} // namespace tvm