"tests/python/pytorch/nn/test_nn.py" did not exist on "14bffe97286030a9efd1cc1a0832c7fc21413fbe"
featgraph.cc 2.2 KB
Newer Older
1
/**
 *  Copyright (c) 2020 by Contributors
 * @file featgraph/src/featgraph.cc
 * @brief FeatGraph kernels.
 */
6
7
#include <dmlc/logging.h>
#include <featgraph.h>
Zhi Lin's avatar
Zhi Lin committed
8
9
10
11
12
13
14
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

namespace dgl {
namespace featgraph {

15
/* @brief Singleton that loads the featgraph module. */
Zhi Lin's avatar
Zhi Lin committed
16
class FeatGraphModule {
17
 public:
Zhi Lin's avatar
Zhi Lin committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
  static FeatGraphModule* Global() {
    static FeatGraphModule inst;
    return &inst;
  }

  void Load(const std::string& path) {
    mod = tvm::runtime::Module::LoadFromFile(path);
  }

  inline tvm::runtime::ModuleNode* Get() {
    auto ret = mod.operator->();
    if (!ret) {
      LOG(FATAL) << "FeatGraph module have not been loaded. "
                 << "Please set path of featgraph shared library.";
    }
    return ret;
  }
35
36

 private:
Zhi Lin's avatar
Zhi Lin committed
37
38
39
40
  tvm::runtime::Module mod;
  FeatGraphModule() {}
};

41
/* @brief Load Featgraph module from given path. */
Zhi Lin's avatar
Zhi Lin committed
42
43
44
45
void LoadFeatGraphModule(const std::string& path) {
  FeatGraphModule::Global()->Load(path);
}

46
/* @brief Convert DLDataType to string. */
Zhi Lin's avatar
Zhi Lin committed
47
inline std::string DTypeAsStr(const DLDataType& t) {
48
49
50
51
52
53
54
55
56
57
58
  switch (t.code) {
    case 0U:
      return "int" + std::to_string(t.bits);
    case 1U:
      return "uint" + std::to_string(t.bits);
    case 2U:
      return "float" + std::to_string(t.bits);
    case 3U:
      return "bfloat" + std::to_string(t.bits);
    default:
      LOG(FATAL) << "Type code " << t.code << " not recognized";
Zhi Lin's avatar
Zhi Lin committed
59
60
61
  }
}

62
/* @brief Get operator filename. */
Zhi Lin's avatar
Zhi Lin committed
63
inline std::string GetOperatorName(
64
    const std::string& base_name, const DLDataType& dtype,
Zhi Lin's avatar
Zhi Lin committed
65
66
67
68
    const DLDataType& idtype) {
  return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype);
}

69
/* @brief Call FeatGraph's SDDMM kernel. */
70
71
72
void SDDMMTreeReduction(
    DLManagedTensor* row, DLManagedTensor* col, DLManagedTensor* lhs,
    DLManagedTensor* rhs, DLManagedTensor* out) {
Zhi Lin's avatar
Zhi Lin committed
73
  tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get();
74
75
  std::string f_name = GetOperatorName(
      "SDDMMTreeReduction", (row->dl_tensor).dtype, (lhs->dl_tensor).dtype);
Zhi Lin's avatar
Zhi Lin committed
76
  tvm::runtime::PackedFunc f = mod->GetFunction(f_name);
77
  if (f != nullptr) f(row, col, lhs, rhs, out);
Zhi Lin's avatar
Zhi Lin committed
78
79
80
81
}

}  // namespace featgraph
}  // namespace dgl