Unverified Commit e8054701 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

addressing post-merge comments (#2455)

parent 0018e90c
......@@ -541,9 +541,9 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
DGLStreamHandle dst);
/*!
* \brief Sets the path to the tensoradapter library
* \brief Load tensor adapter.
*/
DGL_DLL void DGLSetTAPath(const char *path_cstr);
DGL_DLL void DGLLoadTensorAdapter(const char *path);
/*!
* \brief Bug report macro.
......
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/env.h
* \brief Structure for holding DGL global environment variables
*/
#ifndef DGL_RUNTIME_ENV_H_
#define DGL_RUNTIME_ENV_H_
#include <string>
/*!
* \brief Global environment variables.
*/
struct Env {
static Env* Global() {
static Env inst;
return &inst;
}
/*! \brief the path to the tensoradapter library */
std::string ta_path;
};
#endif // DGL_RUNTIME_ENV_H_
......@@ -48,6 +48,8 @@ namespace runtime {
/*!
* \brief Dispatcher that delegates the function calls to framework-specific C++ APIs.
*
* This class is not thread-safe.
*/
class TensorDispatcher {
public:
......@@ -62,6 +64,9 @@ class TensorDispatcher {
return available_;
}
/*! \brief Load symbols from the given tensor adapter library path */
void Load(const char *path_cstr);
/*!
* \brief Allocate an empty tensor.
*
......@@ -75,7 +80,7 @@ class TensorDispatcher {
private:
/*! \brief ctor */
TensorDispatcher();
TensorDispatcher() = default;
/*! \brief dtor */
~TensorDispatcher();
......@@ -111,4 +116,6 @@ class TensorDispatcher {
}; // namespace runtime
}; // namespace dgl
#undef FUNCCAST
#endif // DGL_RUNTIME_TENSORDISPATCH_H_
......@@ -113,8 +113,8 @@ def decorate(func, fwrapped):
return decorator.decorate(func, fwrapped)
def set_ta_path(backend, version):
"""Tell DGL which tensoradapter library to look for symbols.
def load_tensor_adapter(backend, version):
"""Tell DGL to load a tensoradapter library for given backend and version.
Parameters
----------
......@@ -133,4 +133,4 @@ def set_ta_path(backend, version):
else:
raise NotImplementedError('Unsupported system: %s' % sys.platform)
path = os.path.join(_DIR_NAME, 'tensoradapter', backend, basename)
_LIB.DGLSetTAPath(path.encode('utf-8'))
_LIB.DGLLoadTensorAdapter(path.encode('utf-8'))
......@@ -38,9 +38,9 @@ def load_backend(mod_name):
else:
raise NotImplementedError('Unsupported backend: %s' % mod_name)
from .._ffi.base import set_ta_path # imports DGL C library
from .._ffi.base import load_tensor_adapter # imports DGL C library
version = mod.__version__
set_ta_path(mod_name, version)
load_tensor_adapter(mod_name, version)
print('Using backend: %s' % mod_name, file=sys.stderr)
mod = importlib.import_module('.%s' % mod_name, __name__)
......
......@@ -10,7 +10,7 @@
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/env.h>
#include <dgl/runtime/tensordispatch.h>
#include <array>
#include <algorithm>
#include <string>
......@@ -379,8 +379,8 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
API_END();
}
void DGLSetTAPath(const char *path_cstr) {
Env::Global()->ta_path = std::string(path_cstr);
void DGLLoadTensorAdapter(const char *path) {
TensorDispatcher::Global()->Load(path);
}
// set device api
......
......@@ -6,7 +6,6 @@
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/env.h>
#include <dgl/packed_func_ext.h>
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
......@@ -20,26 +19,33 @@ namespace runtime {
constexpr const char *TensorDispatcher::names_[];
TensorDispatcher::TensorDispatcher() {
const std::string& path = Env::Global()->ta_path;
if (path == "")
void TensorDispatcher::Load(const char *path) {
CHECK(!available_) << "The tensor adapter can only load once.";
if (path == nullptr || strlen(path) == 0)
// does not have dispatcher library; all operators fall back to DGL's implementation
return;
#if defined(WIN32) || defined(_WIN32)
handle_ = LoadLibrary(path.c_str());
handle_ = LoadLibrary(path);
if (!handle_)
return;
for (int i = 0; i < num_entries_; ++i)
for (int i = 0; i < num_entries_; ++i) {
entrypoints_[i] = reinterpret_cast<void*>(GetProcAddress(handle_, names_[i]));
CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
}
#else // !WIN32
handle_ = dlopen(path.c_str(), RTLD_LAZY);
handle_ = dlopen(path, RTLD_LAZY);
if (!handle_)
return;
for (int i = 0; i < num_entries_; ++i)
for (int i = 0; i < num_entries_; ++i) {
entrypoints_[i] = dlsym(handle_, names_[i]);
CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
}
#endif // WIN32
available_ = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment