"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "17aab8128a43c624695478b777ae50744d6b18d6"
Unverified Commit e8054701 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

addressing post-merge comments (#2455)

parent 0018e90c
...@@ -541,9 +541,9 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type, ...@@ -541,9 +541,9 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
DGLStreamHandle dst); DGLStreamHandle dst);
/*! /*!
* \brief Sets the path to the tensoradapter library * \brief Load tensor adapter.
*/ */
DGL_DLL void DGLSetTAPath(const char *path_cstr); DGL_DLL void DGLLoadTensorAdapter(const char *path);
/*! /*!
* \brief Bug report macro. * \brief Bug report macro.
......
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/env.h
* \brief Structure for holding DGL global environment variables
*/
#ifndef DGL_RUNTIME_ENV_H_
#define DGL_RUNTIME_ENV_H_
#include <string>
/*!
* \brief Global environment variables.
*/
struct Env {
static Env* Global() {
static Env inst;
return &inst;
}
/*! \brief the path to the tensoradapter library */
std::string ta_path;
};
#endif // DGL_RUNTIME_ENV_H_
...@@ -48,6 +48,8 @@ namespace runtime { ...@@ -48,6 +48,8 @@ namespace runtime {
/*! /*!
* \brief Dispatcher that delegates the function calls to framework-specific C++ APIs. * \brief Dispatcher that delegates the function calls to framework-specific C++ APIs.
*
* This class is not thread-safe.
*/ */
class TensorDispatcher { class TensorDispatcher {
public: public:
...@@ -62,6 +64,9 @@ class TensorDispatcher { ...@@ -62,6 +64,9 @@ class TensorDispatcher {
return available_; return available_;
} }
/*! \brief Load symbols from the given tensor adapter library path */
void Load(const char *path_cstr);
/*! /*!
* \brief Allocate an empty tensor. * \brief Allocate an empty tensor.
* *
...@@ -75,7 +80,7 @@ class TensorDispatcher { ...@@ -75,7 +80,7 @@ class TensorDispatcher {
private: private:
/*! \brief ctor */ /*! \brief ctor */
TensorDispatcher(); TensorDispatcher() = default;
/*! \brief dtor */ /*! \brief dtor */
~TensorDispatcher(); ~TensorDispatcher();
...@@ -111,4 +116,6 @@ class TensorDispatcher { ...@@ -111,4 +116,6 @@ class TensorDispatcher {
}; // namespace runtime }; // namespace runtime
}; // namespace dgl }; // namespace dgl
#undef FUNCCAST
#endif // DGL_RUNTIME_TENSORDISPATCH_H_ #endif // DGL_RUNTIME_TENSORDISPATCH_H_
...@@ -113,8 +113,8 @@ def decorate(func, fwrapped): ...@@ -113,8 +113,8 @@ def decorate(func, fwrapped):
return decorator.decorate(func, fwrapped) return decorator.decorate(func, fwrapped)
def set_ta_path(backend, version): def load_tensor_adapter(backend, version):
"""Tell DGL which tensoradapter library to look for symbols. """Tell DGL to load a tensoradapter library for given backend and version.
Parameters Parameters
---------- ----------
...@@ -133,4 +133,4 @@ def set_ta_path(backend, version): ...@@ -133,4 +133,4 @@ def set_ta_path(backend, version):
else: else:
raise NotImplementedError('Unsupported system: %s' % sys.platform) raise NotImplementedError('Unsupported system: %s' % sys.platform)
path = os.path.join(_DIR_NAME, 'tensoradapter', backend, basename) path = os.path.join(_DIR_NAME, 'tensoradapter', backend, basename)
_LIB.DGLSetTAPath(path.encode('utf-8')) _LIB.DGLLoadTensorAdapter(path.encode('utf-8'))
...@@ -38,9 +38,9 @@ def load_backend(mod_name): ...@@ -38,9 +38,9 @@ def load_backend(mod_name):
else: else:
raise NotImplementedError('Unsupported backend: %s' % mod_name) raise NotImplementedError('Unsupported backend: %s' % mod_name)
from .._ffi.base import set_ta_path # imports DGL C library from .._ffi.base import load_tensor_adapter # imports DGL C library
version = mod.__version__ version = mod.__version__
set_ta_path(mod_name, version) load_tensor_adapter(mod_name, version)
print('Using backend: %s' % mod_name, file=sys.stderr) print('Using backend: %s' % mod_name, file=sys.stderr)
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module('.%s' % mod_name, __name__)
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/env.h> #include <dgl/runtime/tensordispatch.h>
#include <array> #include <array>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
...@@ -379,8 +379,8 @@ int DGLCbArgToReturn(DGLValue* value, int code) { ...@@ -379,8 +379,8 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
API_END(); API_END();
} }
void DGLSetTAPath(const char *path_cstr) { void DGLLoadTensorAdapter(const char *path) {
Env::Global()->ta_path = std::string(path_cstr); TensorDispatcher::Global()->Load(path);
} }
// set device api // set device api
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include <dgl/runtime/tensordispatch.h> #include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <dgl/runtime/env.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#if defined(WIN32) || defined(_WIN32) #if defined(WIN32) || defined(_WIN32)
#include <windows.h> #include <windows.h>
...@@ -20,26 +19,33 @@ namespace runtime { ...@@ -20,26 +19,33 @@ namespace runtime {
constexpr const char *TensorDispatcher::names_[]; constexpr const char *TensorDispatcher::names_[];
TensorDispatcher::TensorDispatcher() { void TensorDispatcher::Load(const char *path) {
const std::string& path = Env::Global()->ta_path; CHECK(!available_) << "The tensor adapter can only load once.";
if (path == "")
if (path == nullptr || strlen(path) == 0)
// does not have dispatcher library; all operators fall back to DGL's implementation // does not have dispatcher library; all operators fall back to DGL's implementation
return; return;
#if defined(WIN32) || defined(_WIN32) #if defined(WIN32) || defined(_WIN32)
handle_ = LoadLibrary(path.c_str()); handle_ = LoadLibrary(path);
if (!handle_) if (!handle_)
return; return;
for (int i = 0; i < num_entries_; ++i) for (int i = 0; i < num_entries_; ++i) {
entrypoints_[i] = reinterpret_cast<void*>(GetProcAddress(handle_, names_[i])); entrypoints_[i] = reinterpret_cast<void*>(GetProcAddress(handle_, names_[i]));
CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
}
#else // !WIN32 #else // !WIN32
handle_ = dlopen(path.c_str(), RTLD_LAZY); handle_ = dlopen(path, RTLD_LAZY);
if (!handle_) if (!handle_)
return; return;
for (int i = 0; i < num_entries_; ++i)
for (int i = 0; i < num_entries_; ++i) {
entrypoints_[i] = dlsym(handle_, names_[i]); entrypoints_[i] = dlsym(handle_, names_[i]);
CHECK(entrypoints_[i]) << "cannot locate symbol " << names_[i];
}
#endif // WIN32 #endif // WIN32
available_ = true; available_ = true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment