#include "codegen_cuda.h"

#include "runtime/cuda/cuda_module.h"
#include "runtime/pack_args.h"
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>

namespace tvm {
namespace codegen {

10
11
static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
12
13
14
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;

  for (auto kv : mod->functions) {
15
16
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
17
18
19
20
21
22
23
    auto f = Downcast<tir::PrimFunc>(kv.second);

    runtime::FunctionInfo info;
    for (size_t i = 0; i < f->params.size(); ++i) {
      if (f->params[i]->dtype.is_handle()) {
        auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
        if (ptr && ptr->storage_scope == "grid_constant") {
24
          info.arg_types.push_back(DataType(runtime::kDLGridConstant, 64, 1));
25
26
27
          continue;
        }
      }
28
29
30
31
32
      DataType dtype = f->params[i].dtype();
      // Device runtime cannot directly take bool arguments, map to int32.
      if (dtype.is_bool())
        dtype = DataType::Int(32);
      info.arg_types.push_back(dtype);
33
    }
34
35
    if (auto opt = f->GetAttr<ffi::Array<ffi::String>>(
            tir::attr::kKernelLaunchParams)) {
36
      for (const auto &tag : opt.value()) {
37
38
39
        info.launch_param_tags.push_back(tag);
      }
    }
40
    auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
41
42
43
44
45
    fmap[static_cast<std::string>(global_symbol.value())] = info;
  }
  return fmap;
}

46
ffi::Module BuildTileLangCUDA(IRModule mod, Target target) {
47
48
49
50
51
  bool output_ssa = false;
  CodeGenTileLangCUDA cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
52
53
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
54
    auto gvar = Downcast<GlobalVar>(kv.first);
55
56
57
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
58
    cg.AddFunction(gvar, f);
59
60
61
  }

  std::string code = cg.Finish();
62
63
64
  if (const auto f =
          ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) {
    code = (*f)(code, target).cast<std::string>();
65
66
67
  }
  std::string fmt = "ptx";
  std::string ptx;
68
69
  if (const auto f =
          ffi::Function::GetGlobal("tilelang_callback_cuda_compile")) {
70
71
72
73
    // Fetch current pass context config and pass into the compile callback
    tvm::transform::PassContext pass_ctx =
        tvm::transform::PassContext::Current();
    ptx = (*f)(code, target, pass_ctx->config).cast<std::string>();
74
75
    if (ptx[0] != '/')
      fmt = "cubin";
76
77
78
79
80
81
  } else {
    ICHECK(0);
  }
  return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}

82
ffi::Module BuildTileLangCUDAWithoutCompile(IRModule mod, Target target) {
83
84
85
86
87
  bool output_ssa = false;
  CodeGenTileLangCUDA cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
88
89
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
90
    auto gvar = Downcast<GlobalVar>(kv.first);
91
92
93
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch);
94
    cg.AddFunction(gvar, f);
95
96
97
  }

  std::string code = cg.Finish();
98
99
100
  if (const auto f =
          ffi::Function::GetGlobal("tilelang_callback_cuda_postproc")) {
    code = (*f)(code, target).cast<std::string>();
101
  }
102
  return runtime::CUDAModuleCreate("ptx", "ptx", ExtractFuncInfo(mod), code);
103
104
}

105
TVM_FFI_STATIC_INIT_BLOCK() {
106
107
108
109
110
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("target.build.tilelang_cuda", BuildTileLangCUDA)
      .def("target.build.tilelang_cuda_without_compile",
           BuildTileLangCUDAWithoutCompile);
111
}
} // namespace codegen
} // namespace tvm