"vscode:/vscode.git/clone" did not exist on "468b1b70148e3f0a8c12fa399c380707cb33a716"
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include "codegen_cuda.h"
#include "runtime/cuda/cuda_module.h"

#include <string>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace codegen {

static std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule &mod) {
12
13
14
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;

  for (auto kv : mod->functions) {
15
16
    ICHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
17
18
19
20
21
22
23
24
25
26
27
28
29
30
    auto f = Downcast<tir::PrimFunc>(kv.second);

    runtime::FunctionInfo info;
    for (size_t i = 0; i < f->params.size(); ++i) {
      if (f->params[i]->dtype.is_handle()) {
        auto ptr = f->params[i]->type_annotation.as<PointerTypeNode>();
        if (ptr && ptr->storage_scope == "grid_constant") {
          info.arg_types.push_back(DataType(kTVMGridConstant, 64, 1));
          continue;
        }
      }
      info.arg_types.push_back(f->params[i].dtype());
    }
    if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
31
      for (const auto &tag : opt.value()) {
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
        info.launch_param_tags.push_back(tag);
      }
    }
    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
    fmap[static_cast<std::string>(global_symbol.value())] = info;
  }
  return fmap;
}

/*!
 * \brief Build a CUDA runtime module from an IRModule of device kernels.
 *
 * Steps performed:
 *  1. Every function in \p mod is fed to CodeGenTileLangCUDA; each must be a
 *     PrimFunc whose calling convention is CallingConv::kDeviceKernelLaunch.
 *  2. If the global function "tilelang_callback_cuda_postproc" is registered,
 *     the generated source is passed through it.
 *  3. The source is compiled by the global function
 *     "tilelang_callback_cuda_compile", which must be registered.
 *  4. The compiled binary, the function metadata from ExtractFuncInfo and the
 *     source are handed to runtime::CUDAModuleCreate.
 *
 * \param mod    Module containing the kernels to generate code for.
 * \param target Target forwarded unchanged to both callbacks.
 * \return The module produced by runtime::CUDAModuleCreate.
 */
runtime::Module BuildTileLangCUDA(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenTileLangCUDA cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto gvar = Downcast<GlobalVar>(kv.first);
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenTileLangCUDA: expect calling_conv equals "
           "CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(gvar, f);
  }

  std::string code = cg.Finish();
  if (const auto *postproc =
          Registry::Get("tilelang_callback_cuda_postproc")) {
    code = (*postproc)(code, target).operator std::string();
  }

  // There is no built-in compilation fallback here, so a missing callback is
  // reported with an explicit message rather than a bare failed check.
  const auto *compile = Registry::Get("tilelang_callback_cuda_compile");
  ICHECK(compile != nullptr)
      << "BuildTileLangCUDA: global function "
         "\"tilelang_callback_cuda_compile\" is not registered";
  std::string ptx = (*compile)(code, target).operator std::string();

  // Heuristic format detection: output that does not start with '/' is
  // treated as a cubin (presumably PTX text begins with a "//" comment
  // header -- confirm against the compile callback). Empty output also falls
  // into the cubin branch, as before.
  std::string fmt = "ptx";
  if (ptx.empty() || ptx[0] != '/') {
    fmt = "cubin";
  }
  return runtime::CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}

/*!
 * \brief Generate the CUDA source for a module without compiling it.
 *
 * Every function in \p mod is fed to CodeGenTileLangCUDA; each must be a
 * PrimFunc whose calling convention is CallingConv::kDeviceKernelLaunch. If
 * the global function "tilelang_callback_cuda_postproc" is registered, the
 * generated source is passed through it before being returned.
 *
 * \param mod    Module containing the kernels to generate code for.
 * \param target Target forwarded unchanged to the post-processing callback.
 * \return The (optionally post-processed) generated source code.
 */
String BuildTLDebug(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenTileLangCUDA cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>())
        << "CodeGenTileLangCUDA: Can only take PrimFunc";
    auto gvar = Downcast<GlobalVar>(kv.first);
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenTileLangCUDA: expect calling_conv equals "
           "CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(gvar, f);
  }

  std::string code = cg.Finish();
  if (const auto *postproc =
          Registry::Get("tilelang_callback_cuda_postproc")) {
    code = (*postproc)(code, target).operator std::string();
  }
  return String(code);
}

// Register BuildTileLangCUDA in the global function registry under the name
// "target.build.tilelang_cuda".
TVM_REGISTER_GLOBAL("target.build.tilelang_cuda")
    .set_body_typed(BuildTileLangCUDA);
// Register BuildTLDebug, which returns the generated source as a String,
// under the name "target.build.tl_debug_codegen".
TVM_REGISTER_GLOBAL("target.build.tl_debug_codegen")
    .set_body_typed(BuildTLDebug);
} // namespace codegen
} // namespace tvm