// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#if defined(__linux__)
#include <sys/stat.h>
#endif

#include <cstdlib>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <hip/hip_runtime.h>
#include <hip/hiprtc.h>

#include "codegen_hip.h"
#include "runtime/rocm/rocm_module.h"

namespace tvm {
namespace codegen {

17
#define HIPRTC_CALL(x)                                                         \
18
  \  
19
  {                                                                            \
20
    \  
21
    hiprtcResult result = x;                                                   \
22
    \  
23
    if (result != HIPRTC_SUCCESS) {                                            \
24
      \  
25
26
27
      LOG(FATAL)                                                               \
          << "HiprtcError: " #x " failed with error: "                         \
          << hiprtcGetErrorString(result);                                     \
28
      \  
29

30
    }                                                                          \
31
    \  
32

33
34
35
36
37
38
39
40
41
  }

static std::string FindHIPIncludePath() {
#if defined(_WIN32)
  const std::string delimiter = "\\";
#else
  const std::string delimiter = "/";
#endif
  std::string hip_include_path;
42
  const char *hip_path_env = std::getenv("HIP_PATH");
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
  if (hip_path_env != nullptr) {
    hip_include_path += hip_path_env;
    hip_include_path += delimiter + "include";
    return hip_include_path;
  }

#if defined(__linux__)
  struct stat st;
  hip_include_path = "/opt/rocm/hip/include";
  if (stat(hip_include_path.c_str(), &st) == 0) {
    return hip_include_path;
  }

  if (stat("/usr/include/hip/hip_runtime.h", &st) == 0) {
    return "/usr/include/hip";
  }
#endif
  LOG(FATAL) << "Cannot find HIP include path."
61
62
             << "HIP_PATH is not set or ROCm is not installed in the default "
                "installation path."
63
64
65
66
             << "In other than linux, it is necessary to set HIP_PATH.";
  return hip_include_path;
}

67
68
static std::string HIPRTCCompile(const std::string &code,
                                 bool include_path = false) {
69
  std::vector<std::string> compile_params;
70
  std::vector<const char *> param_cstrings{};
71
  hiprtcProgram prog;
72
73
  std::string cc =
      "gfx900"; // Default target architecture (can be changed as needed)
74
  int major, minor;
75
76
77
78
  hipError_t e1 = hipDeviceGetAttribute(
      &major, hipDeviceAttributeComputeCapabilityMajor, 0);
  hipError_t e2 = hipDeviceGetAttribute(
      &minor, hipDeviceAttributeComputeCapabilityMinor, 0);
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

  if (e1 == hipSuccess && e2 == hipSuccess) {
    cc = "gfx" + std::to_string(major * 100 + minor * 10);
  } else {
    LOG(WARNING) << "cannot detect compute capability from your device, "
                 << "fall back to gfx900.";
  }

  compile_params.push_back("--gpu-architecture=" + cc);

  if (include_path) {
    std::string include_option = "--include-path=" + FindHIPIncludePath();
    compile_params.push_back(include_option);
  }

94
  for (const auto &string : compile_params) {
95
96
    param_cstrings.push_back(string.c_str());
  }
97
98
  HIPRTC_CALL(
      hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
  hiprtcResult compile_res =
      hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());

  size_t log_size;
  HIPRTC_CALL(hiprtcGetProgramLogSize(prog, &log_size));
  std::string log;
  log.resize(log_size);
  HIPRTC_CALL(hiprtcGetProgramLog(prog, &log[0]));
  ICHECK_EQ(compile_res, HIPRTC_SUCCESS) << log;
  size_t code_size;
  HIPRTC_CALL(hiprtcGetCodeSize(prog, &code_size));

  std::string code_out;
  code_out.resize(code_size);
  HIPRTC_CALL(hiprtcGetCode(prog, &code_out[0]));
  HIPRTC_CALL(hiprtcDestroyProgram(&prog));

  return code_out;
}

119
120
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
121
122
123
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;

  for (auto kv : mod->functions) {
124
125
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
126
127
128
129
130
131
132
133
134
135
136
137
138
139
    auto f = Downcast<tir::PrimFunc>(kv.second);

    runtime::FunctionInfo info;
    for (size_t i = 0; i < f->params.size(); ++i) {
      if (f->params[i]->dtype.is_handle()) {
        auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
        if (ptr && ptr->storage_scope == "grid_constant") {
          info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1));
          continue;
        }
      }
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
140
      for (const auto &tag : opt.value()) {
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
        info.launch_param_tags.push_back(tag);
      }
    }
    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
    fmap[static_cast<std::string>(global_symbol.value())] = info;
  }
  return fmap;
}

runtime::Module BuildTileLangHIP(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenTileLangHIP cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
157
158
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangHIP: Can only take PrimFunc";
159
160
161
162
163
164
165
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
    cg.AddFunction(f);
  }

  std::string code = cg.Finish();
166
  if (const auto *f = Registry::Get("tilelang_callback_hip_postproc")) {
167
168
169
170
    code = (*f)(code, target).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
171
  if (const auto *f = Registry::Get("tilelang_callback_hip_compile")) {
172
    ptx = (*f)(code, target).operator std::string();
173
174
    if (ptx[0] != '/')
      fmt = "hsaco";
175
176
177
178
179
180
  } else {
    ptx = HIPRTCCompile(code, false);
  }
  return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string());
}

181
182
TVM_REGISTER_GLOBAL("target.build.tilelang_hip")
    .set_body_typed(BuildTileLangHIP);
183

184
185
} // namespace codegen
} // namespace tvm