/**
 *  Copyright (c) 2017 by Contributors
 * @file module.cc
 * @brief DGL module system
 */
#include <dgl/runtime/module.h>
#include <dgl/runtime/packed_func.h>
8
9
#include <dgl/runtime/registry.h>

Minjie Wang's avatar
Minjie Wang committed
10
#include <cstring>
11
#include <unordered_set>
Minjie Wang's avatar
Minjie Wang committed
12
13
14
15
#ifndef _LIBCPP_SGX_CONFIG
#include "file_util.h"
#endif

16
namespace dgl {
Minjie Wang's avatar
Minjie Wang committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
namespace runtime {

/**
 * @brief Append `other` to this module's list of imports.
 *
 * A module whose type key is "rpc" hands the request to the globally
 * registered "rpc._ImportRemoteModule" function instead. For any other
 * module, every node reachable from `other` through its imports is collected
 * first, and the import is rejected (CHECK failure) when this module is among
 * them.
 *
 * @param other The module to import.
 */
void Module::Import(Module other) {
  // RPC modules are handled by the registered remote-import function.
  if (std::strcmp((*this)->type_key(), "rpc") == 0) {
    // Resolved on the first call and reused afterwards.
    static const PackedFunc* fimport_ = nullptr;
    if (!fimport_) {
      fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
      CHECK(fimport_ != nullptr);
    }
    (*fimport_)(*this, other);
    return;
  }

  // Cycle detection: depth-first walk over everything `other` imports,
  // directly or transitively.
  const ModuleNode* root = other.node_.get();
  std::unordered_set<const ModuleNode*> visited;
  std::vector<const ModuleNode*> pending;
  visited.insert(root);
  pending.push_back(root);
  while (!pending.empty()) {
    const ModuleNode* current = pending.back();
    pending.pop_back();
    for (const Module& dep : current->imports_) {
      const ModuleNode* child = dep.node_.get();
      // insert() reports whether the node was newly added, so each node is
      // scheduled at most once.
      if (visited.insert(child).second) {
        pending.push_back(child);
      }
    }
  }
  // Importing is only legal when this module is not reachable from `other`.
  CHECK(!visited.count(node_.get()))
      << "Cyclic dependency detected during import";
  node_->imports_.emplace_back(std::move(other));
}

48
49
/**
 * @brief Load a module from a file via a loader found in the global registry.
 *
 * The effective format is obtained from GetFileFormat(file_name, format); the
 * "dll", "dylib" and "dso" formats are all routed to the "so" loader. The
 * loader is looked up under the name "module.loadfile_<format>" and invoked
 * with the original `file_name` and `format` arguments.
 *
 * When _LIBCPP_SGX_CONFIG is defined the function only logs a fatal error.
 *
 * @param file_name Path of the file to load.
 * @param format Format hint forwarded to GetFileFormat and to the loader.
 * @return The module produced by the loader.
 */
Module Module::LoadFromFile(
    const std::string& file_name, const std::string& format) {
#ifndef _LIBCPP_SGX_CONFIG
  std::string fmt = GetFileFormat(file_name, format);
  CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
  // All shared-library spellings share the "so" loader.
  if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
    fmt = "so";
  }
  const std::string load_f_name = "module.loadfile_" + fmt;
  const PackedFunc* f = Registry::Get(load_f_name);
  // Report the resolved format (`fmt`): the caller-supplied `format` is not
  // necessarily the one the loader name was built from.
  CHECK(f != nullptr) << "Loader of " << fmt << "(" << load_f_name
                      << ") is not presented.";
  Module m = (*f)(file_name, format);
  return m;
#else
  // NOTE(review): no return follows; this relies on LOG(FATAL) not returning.
  LOG(FATAL) << "SGX does not support LoadFromFile";
#endif
}

67
68
void ModuleNode::SaveToFile(
    const std::string& file_name, const std::string& format) {
Minjie Wang's avatar
Minjie Wang committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
  LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
}

/**
 * @brief Base implementation of SaveToBinary: unconditionally reports a fatal
 *        error naming this module's type key.
 *
 * @param stream Output stream (unused here).
 */
void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
  const char* key = type_key();
  LOG(FATAL) << "Module[" << key << "] does not support SaveToBinary";
}

/**
 * @brief Base implementation of GetSource: reports a fatal error naming this
 *        module's type key.
 *
 * @param format Requested source format (unused here).
 * @return An empty string, should the fatal log ever return.
 */
std::string ModuleNode::GetSource(const std::string& format) {
  const char* key = type_key();
  LOG(FATAL) << "Module[" << key << "] does not support GetSource";
  return std::string();
}

/**
 * @brief Resolve a function by name from the imported modules, falling back
 *        to the global registry.
 *
 * Lookup order:
 *   1. `import_cache_`, which holds results of earlier import lookups;
 *   2. each module in `imports_`, in order, taking the first non-null
 *      function (the result is copied into `import_cache_`);
 *   3. the global Registry (the result is returned as-is and not cached
 *      here).
 *
 * A CHECK failure is raised when the name is found nowhere.
 *
 * @param name Name of the function to resolve.
 * @return Pointer to the resolved function; for an imported function the
 *         pointee is owned by `import_cache_`.
 */
const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
  auto it = import_cache_.find(name);
  if (it != import_cache_.end()) return it->second.get();

  PackedFunc pf;
  for (Module& m : this->imports_) {
    pf = m.GetFunction(name, false);
    if (pf != nullptr) break;
  }

  if (pf == nullptr) {
    const PackedFunc* f = Registry::Get(name);
    CHECK(f != nullptr) << "Cannot find function " << name
                        << " in the imported modules or global registry";
    return f;
  }

  // Allocate before touching the cache so a failed allocation cannot leave an
  // empty entry behind, and keep the raw pointer so no second lookup is
  // needed after the insert.
  std::unique_ptr<PackedFunc> owned(new PackedFunc(pf));
  const PackedFunc* cached = owned.get();
  import_cache_[name] = std::move(owned);
  return cached;
}

bool RuntimeEnabled(const std::string& target) {
  std::string f_name;
  if (target == "cpu") {
    return true;
  } else if (target == "cuda" || target == "gpu") {
106
    f_name = "device_api.cuda";
Minjie Wang's avatar
Minjie Wang committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
  } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
    f_name = "device_api.opencl";
  } else if (target == "gl" || target == "opengl") {
    f_name = "device_api.opengl";
  } else if (target == "mtl" || target == "metal") {
    f_name = "device_api.metal";
  } else if (target == "vulkan") {
    f_name = "device_api.vulkan";
  } else if (target == "stackvm") {
    f_name = "codegen.build_stackvm";
  } else if (target == "rpc") {
    f_name = "device_api.rpc";
  } else if (target == "vpi" || target == "verilog") {
    f_name = "device_api.vpi";
  } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
122
    f_name = "device_api.cuda";
Minjie Wang's avatar
Minjie Wang committed
123
124
125
  } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
    f_name = "device_api.rocm";
  } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
126
127
    const PackedFunc* pf =
        runtime::Registry::Get("codegen.llvm_target_enabled");
Minjie Wang's avatar
Minjie Wang committed
128
129
130
131
132
133
134
135
    if (pf == nullptr) return false;
    return (*pf)(target);
  } else {
    LOG(FATAL) << "Unknown optional runtime " << target;
  }
  return runtime::Registry::Get(f_name) != nullptr;
}

136
// Global registry entries exposing the module API as packed functions.

// module._Enabled(target) -> result of RuntimeEnabled(target).
DGL_REGISTER_GLOBAL("module._Enabled")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      *ret = RuntimeEnabled(args[0]);
    });

// module._GetSource(module, format) -> module->GetSource(format).
DGL_REGISTER_GLOBAL("module._GetSource")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      *ret = args[0].operator Module()->GetSource(args[1]);
    });

// module._ImportsSize(module) -> number of imports, returned as int64_t.
DGL_REGISTER_GLOBAL("module._ImportsSize")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      *ret = static_cast<int64_t>(args[0].operator Module()->imports().size());
    });

// module._GetImport(module, index) -> the import at `index`; the lookup goes
// through at(), i.e. it is bounds-checked.
DGL_REGISTER_GLOBAL("module._GetImport")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      *ret = args[0].operator Module()->imports().at(args[1].operator int());
    });

// module._GetTypeKey(module) -> the module's type key copied into a string.
DGL_REGISTER_GLOBAL("module._GetTypeKey")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      *ret = std::string(args[0].operator Module()->type_key());
    });

// module._LoadFromFile(file_name, format) -> Module::LoadFromFile(...).
DGL_REGISTER_GLOBAL("module._LoadFromFile")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      *ret = Module::LoadFromFile(args[0], args[1]);
    });

// module._SaveToFile(module, file_name, format); no return value is set.
DGL_REGISTER_GLOBAL("module._SaveToFile")
    .set_body([](DGLArgs args, DGLRetValue* ret) {
      args[0].operator Module()->SaveToFile(args[1], args[2]);
    });
}  // namespace runtime
171
}  // namespace dgl