Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e8054701
Unverified
Commit
e8054701
authored
Dec 26, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Dec 26, 2020
Browse files
addressing post-merge comments (#2455)
parent
0018e90c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
32 additions
and
43 deletions
+32
-43
include/dgl/runtime/c_runtime_api.h
include/dgl/runtime/c_runtime_api.h
+2
-2
include/dgl/runtime/env.h
include/dgl/runtime/env.h
+0
-24
include/dgl/runtime/tensordispatch.h
include/dgl/runtime/tensordispatch.h
+8
-1
python/dgl/_ffi/base.py
python/dgl/_ffi/base.py
+3
-3
python/dgl/backend/__init__.py
python/dgl/backend/__init__.py
+2
-2
src/runtime/c_runtime_api.cc
src/runtime/c_runtime_api.cc
+3
-3
src/runtime/tensordispatch.cc
src/runtime/tensordispatch.cc
+14
-8
No files found.
include/dgl/runtime/c_runtime_api.h
View file @
e8054701
...
@@ -541,9 +541,9 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
...
@@ -541,9 +541,9 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
DGLStreamHandle
dst
);
DGLStreamHandle
dst
);
/*!
/*!
* \brief
Sets the path to the
tensoradapter
library
* \brief
Load
tensor
adapter
.
*/
*/
DGL_DLL
void
DGL
SetTAPath
(
const
char
*
path
_cstr
);
DGL_DLL
void
DGL
LoadTensorAdapter
(
const
char
*
path
);
/*!
/*!
* \brief Bug report macro.
* \brief Bug report macro.
...
...
include/dgl/runtime/env.h
deleted
100644 → 0
View file @
0018e90c
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/env.h
* \brief Structure for holding DGL global environment variables
*/
#ifndef DGL_RUNTIME_ENV_H_
#define DGL_RUNTIME_ENV_H_
#include <string>
/*!
* \brief Global environment variables.
*/
struct
Env
{
static
Env
*
Global
()
{
static
Env
inst
;
return
&
inst
;
}
/*! \brief the path to the tensoradapter library */
std
::
string
ta_path
;
};
#endif // DGL_RUNTIME_ENV_H_
include/dgl/runtime/tensordispatch.h
View file @
e8054701
...
@@ -48,6 +48,8 @@ namespace runtime {
...
@@ -48,6 +48,8 @@ namespace runtime {
/*!
/*!
* \brief Dispatcher that delegates the function calls to framework-specific C++ APIs.
* \brief Dispatcher that delegates the function calls to framework-specific C++ APIs.
*
* This class is not thread-safe.
*/
*/
class
TensorDispatcher
{
class
TensorDispatcher
{
public:
public:
...
@@ -62,6 +64,9 @@ class TensorDispatcher {
...
@@ -62,6 +64,9 @@ class TensorDispatcher {
return
available_
;
return
available_
;
}
}
/*! \brief Load symbols from the given tensor adapter library path */
void
Load
(
const
char
*
path_cstr
);
/*!
/*!
* \brief Allocate an empty tensor.
* \brief Allocate an empty tensor.
*
*
...
@@ -75,7 +80,7 @@ class TensorDispatcher {
...
@@ -75,7 +80,7 @@ class TensorDispatcher {
private:
private:
/*! \brief ctor */
/*! \brief ctor */
TensorDispatcher
();
TensorDispatcher
()
=
default
;
/*! \brief dtor */
/*! \brief dtor */
~
TensorDispatcher
();
~
TensorDispatcher
();
...
@@ -111,4 +116,6 @@ class TensorDispatcher {
...
@@ -111,4 +116,6 @@ class TensorDispatcher {
};
// namespace runtime
};
// namespace runtime
};
// namespace dgl
};
// namespace dgl
#undef FUNCCAST
#endif // DGL_RUNTIME_TENSORDISPATCH_H_
#endif // DGL_RUNTIME_TENSORDISPATCH_H_
python/dgl/_ffi/base.py
View file @
e8054701
...
@@ -113,8 +113,8 @@ def decorate(func, fwrapped):
...
@@ -113,8 +113,8 @@ def decorate(func, fwrapped):
return
decorator
.
decorate
(
func
,
fwrapped
)
return
decorator
.
decorate
(
func
,
fwrapped
)
def
set_ta_path
(
backend
,
version
):
def
load_tensor_adapter
(
backend
,
version
):
"""Tell DGL
which
tensoradapter library
to look for symbols
.
"""Tell DGL
to load a
tensoradapter library
for given backend and version
.
Parameters
Parameters
----------
----------
...
@@ -133,4 +133,4 @@ def set_ta_path(backend, version):
...
@@ -133,4 +133,4 @@ def set_ta_path(backend, version):
else
:
else
:
raise
NotImplementedError
(
'Unsupported system: %s'
%
sys
.
platform
)
raise
NotImplementedError
(
'Unsupported system: %s'
%
sys
.
platform
)
path
=
os
.
path
.
join
(
_DIR_NAME
,
'tensoradapter'
,
backend
,
basename
)
path
=
os
.
path
.
join
(
_DIR_NAME
,
'tensoradapter'
,
backend
,
basename
)
_LIB
.
DGL
SetTAPath
(
path
.
encode
(
'utf-8'
))
_LIB
.
DGL
LoadTensorAdapter
(
path
.
encode
(
'utf-8'
))
python/dgl/backend/__init__.py
View file @
e8054701
...
@@ -38,9 +38,9 @@ def load_backend(mod_name):
...
@@ -38,9 +38,9 @@ def load_backend(mod_name):
else
:
else
:
raise
NotImplementedError
(
'Unsupported backend: %s'
%
mod_name
)
raise
NotImplementedError
(
'Unsupported backend: %s'
%
mod_name
)
from
.._ffi.base
import
set_ta_path
# imports DGL C library
from
.._ffi.base
import
load_tensor_adapter
# imports DGL C library
version
=
mod
.
__version__
version
=
mod
.
__version__
set_ta_path
(
mod_name
,
version
)
load_tensor_adapter
(
mod_name
,
version
)
print
(
'Using backend: %s'
%
mod_name
,
file
=
sys
.
stderr
)
print
(
'Using backend: %s'
%
mod_name
,
file
=
sys
.
stderr
)
mod
=
importlib
.
import_module
(
'.%s'
%
mod_name
,
__name__
)
mod
=
importlib
.
import_module
(
'.%s'
%
mod_name
,
__name__
)
...
...
src/runtime/c_runtime_api.cc
View file @
e8054701
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#include <dgl/runtime/module.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/en
v
.h>
#include <dgl/runtime/
t
en
sordispatch
.h>
#include <array>
#include <array>
#include <algorithm>
#include <algorithm>
#include <string>
#include <string>
...
@@ -379,8 +379,8 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
...
@@ -379,8 +379,8 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
API_END
();
API_END
();
}
}
void
DGL
SetTAPath
(
const
char
*
path
_cstr
)
{
void
DGL
LoadTensorAdapter
(
const
char
*
path
)
{
Env
::
Global
()
->
ta_path
=
std
::
string
(
path_cstr
);
TensorDispatcher
::
Global
()
->
Load
(
path
);
}
}
// set device api
// set device api
...
...
src/runtime/tensordispatch.cc
View file @
e8054701
...
@@ -6,7 +6,6 @@
...
@@ -6,7 +6,6 @@
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/env.h>
#include <dgl/packed_func_ext.h>
#include <dgl/packed_func_ext.h>
#if defined(WIN32) || defined(_WIN32)
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
#include <windows.h>
...
@@ -20,26 +19,33 @@ namespace runtime {
...
@@ -20,26 +19,33 @@ namespace runtime {
constexpr
const
char
*
TensorDispatcher
::
names_
[];
constexpr
const
char
*
TensorDispatcher
::
names_
[];
TensorDispatcher
::
TensorDispatcher
()
{
void
TensorDispatcher
::
Load
(
const
char
*
path
)
{
const
std
::
string
&
path
=
Env
::
Global
()
->
ta_path
;
CHECK
(
!
available_
)
<<
"The tensor adapter can only load once."
;
if
(
path
==
""
)
if
(
path
==
nullptr
||
strlen
(
path
)
==
0
)
// does not have dispatcher library; all operators fall back to DGL's implementation
// does not have dispatcher library; all operators fall back to DGL's implementation
return
;
return
;
#if defined(WIN32) || defined(_WIN32)
#if defined(WIN32) || defined(_WIN32)
handle_
=
LoadLibrary
(
path
.
c_str
()
);
handle_
=
LoadLibrary
(
path
);
if
(
!
handle_
)
if
(
!
handle_
)
return
;
return
;
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
{
entrypoints_
[
i
]
=
reinterpret_cast
<
void
*>
(
GetProcAddress
(
handle_
,
names_
[
i
]));
entrypoints_
[
i
]
=
reinterpret_cast
<
void
*>
(
GetProcAddress
(
handle_
,
names_
[
i
]));
CHECK
(
entrypoints_
[
i
])
<<
"cannot locate symbol "
<<
names_
[
i
];
}
#else // !WIN32
#else // !WIN32
handle_
=
dlopen
(
path
.
c_str
(),
RTLD_LAZY
);
handle_
=
dlopen
(
path
,
RTLD_LAZY
);
if
(
!
handle_
)
if
(
!
handle_
)
return
;
return
;
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
{
entrypoints_
[
i
]
=
dlsym
(
handle_
,
names_
[
i
]);
entrypoints_
[
i
]
=
dlsym
(
handle_
,
names_
[
i
]);
CHECK
(
entrypoints_
[
i
])
<<
"cannot locate symbol "
<<
names_
[
i
];
}
#endif // WIN32
#endif // WIN32
available_
=
true
;
available_
=
true
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment