tensordispatch.cc 1.52 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*!
 *  Copyright (c) 2019 by Contributors
 * \file runtime/tensordispatch.cc
 * \brief Adapter library caller
 */

#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h>
#include <dgl/packed_func_ext.h>
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
#else   // !WIN32
#include <dlfcn.h>
#endif  // WIN32
#include <cstring>

namespace dgl {
namespace runtime {

// Out-of-class definition of the static constexpr member array names_.
// NOTE(review): presumably kept so that names_ can be ODR-used (it is indexed
// as names_[i] in Load) under pre-C++17 rules, where in-class constexpr static
// members are not implicitly inline and need a namespace-scope definition.
constexpr const char *TensorDispatcher::names_[];

22
bool TensorDispatcher::Load(const char *path) {
23
24
25
  CHECK(!available_) << "The tensor adapter can only load once.";

  if (path == nullptr || strlen(path) == 0)
26
    // does not have dispatcher library; all operators fall back to DGL's implementation
27
    return false;
28
29

#if defined(WIN32) || defined(_WIN32)
30
  handle_ = LoadLibrary(path);
31
32

  if (!handle_)
33
    return false;
34

35
  for (int i = 0; i < num_entries_; ++i) {
36
    entrypoints_[i] = reinterpret_cast<void*>(GetProcAddress(handle_, names_[i]));
37
38
    CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
  }
39
#else   // !WIN32
40
41
  handle_ = dlopen(path, RTLD_LAZY);

42
  if (!handle_)
43
    return false;
44
45

  for (int i = 0; i < num_entries_; ++i) {
46
    entrypoints_[i] = dlsym(handle_, names_[i]);
47
48
    CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
  }
49
50
51
#endif  // WIN32

  available_ = true;
52
  return true;
53
54
55
56
57
58
59
60
61
62
63
64
65
66
}

/*!
 * \brief Release the adapter library handle, if Load() managed to open one.
 */
TensorDispatcher::~TensorDispatcher() {
  // No library was opened, so there is nothing to release.
  if (!handle_)
    return;

#if defined(WIN32) || defined(_WIN32)
  FreeLibrary(handle_);
#else   // !WIN32
  dlclose(handle_);
#endif  // WIN32
}

};  // namespace runtime
};  // namespace dgl