Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e8054701
"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "17aab8128a43c624695478b777ae50744d6b18d6"
Unverified
Commit
e8054701
authored
Dec 26, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Dec 26, 2020
Browse files
addressing post-merge comments (#2455)
parent
0018e90c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
32 additions
and
43 deletions
+32
-43
include/dgl/runtime/c_runtime_api.h
include/dgl/runtime/c_runtime_api.h
+2
-2
include/dgl/runtime/env.h
include/dgl/runtime/env.h
+0
-24
include/dgl/runtime/tensordispatch.h
include/dgl/runtime/tensordispatch.h
+8
-1
python/dgl/_ffi/base.py
python/dgl/_ffi/base.py
+3
-3
python/dgl/backend/__init__.py
python/dgl/backend/__init__.py
+2
-2
src/runtime/c_runtime_api.cc
src/runtime/c_runtime_api.cc
+3
-3
src/runtime/tensordispatch.cc
src/runtime/tensordispatch.cc
+14
-8
No files found.
include/dgl/runtime/c_runtime_api.h
View file @
e8054701
...
@@ -541,9 +541,9 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
...
@@ -541,9 +541,9 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
DGLStreamHandle
dst
);
DGLStreamHandle
dst
);
/*!
/*!
* \brief
Sets the path to the
tensoradapter
library
* \brief
Load
tensor
adapter
.
*/
*/
DGL_DLL
void
DGL
SetTAPath
(
const
char
*
path
_cstr
);
DGL_DLL
void
DGL
LoadTensorAdapter
(
const
char
*
path
);
/*!
/*!
* \brief Bug report macro.
* \brief Bug report macro.
...
...
include/dgl/runtime/env.h
deleted
100644 → 0
View file @
0018e90c
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/env.h
* \brief Structure for holding DGL global environment variables
*/
#ifndef DGL_RUNTIME_ENV_H_
#define DGL_RUNTIME_ENV_H_
#include <string>
/*!
* \brief Global environment variables.
*/
struct
Env
{
static
Env
*
Global
()
{
static
Env
inst
;
return
&
inst
;
}
/*! \brief the path to the tensoradapter library */
std
::
string
ta_path
;
};
#endif // DGL_RUNTIME_ENV_H_
include/dgl/runtime/tensordispatch.h
View file @
e8054701
...
@@ -48,6 +48,8 @@ namespace runtime {
...
@@ -48,6 +48,8 @@ namespace runtime {
/*!
/*!
* \brief Dispatcher that delegates the function calls to framework-specific C++ APIs.
* \brief Dispatcher that delegates the function calls to framework-specific C++ APIs.
*
* This class is not thread-safe.
*/
*/
class
TensorDispatcher
{
class
TensorDispatcher
{
public:
public:
...
@@ -62,6 +64,9 @@ class TensorDispatcher {
...
@@ -62,6 +64,9 @@ class TensorDispatcher {
return
available_
;
return
available_
;
}
}
/*! \brief Load symbols from the given tensor adapter library path */
void
Load
(
const
char
*
path_cstr
);
/*!
/*!
* \brief Allocate an empty tensor.
* \brief Allocate an empty tensor.
*
*
...
@@ -75,7 +80,7 @@ class TensorDispatcher {
...
@@ -75,7 +80,7 @@ class TensorDispatcher {
private:
private:
/*! \brief ctor */
/*! \brief ctor */
TensorDispatcher
();
TensorDispatcher
()
=
default
;
/*! \brief dtor */
/*! \brief dtor */
~
TensorDispatcher
();
~
TensorDispatcher
();
...
@@ -111,4 +116,6 @@ class TensorDispatcher {
...
@@ -111,4 +116,6 @@ class TensorDispatcher {
};
// namespace runtime
};
// namespace runtime
};
// namespace dgl
};
// namespace dgl
#undef FUNCCAST
#endif // DGL_RUNTIME_TENSORDISPATCH_H_
#endif // DGL_RUNTIME_TENSORDISPATCH_H_
python/dgl/_ffi/base.py
View file @
e8054701
...
@@ -113,8 +113,8 @@ def decorate(func, fwrapped):
...
@@ -113,8 +113,8 @@ def decorate(func, fwrapped):
return
decorator
.
decorate
(
func
,
fwrapped
)
return
decorator
.
decorate
(
func
,
fwrapped
)
def
set_ta_path
(
backend
,
version
):
def
load_tensor_adapter
(
backend
,
version
):
"""Tell DGL
which
tensoradapter library
to look for symbols
.
"""Tell DGL
to load a
tensoradapter library
for given backend and version
.
Parameters
Parameters
----------
----------
...
@@ -133,4 +133,4 @@ def set_ta_path(backend, version):
...
@@ -133,4 +133,4 @@ def set_ta_path(backend, version):
else
:
else
:
raise
NotImplementedError
(
'Unsupported system: %s'
%
sys
.
platform
)
raise
NotImplementedError
(
'Unsupported system: %s'
%
sys
.
platform
)
path
=
os
.
path
.
join
(
_DIR_NAME
,
'tensoradapter'
,
backend
,
basename
)
path
=
os
.
path
.
join
(
_DIR_NAME
,
'tensoradapter'
,
backend
,
basename
)
_LIB
.
DGL
SetTAPath
(
path
.
encode
(
'utf-8'
))
_LIB
.
DGL
LoadTensorAdapter
(
path
.
encode
(
'utf-8'
))
python/dgl/backend/__init__.py
View file @
e8054701
...
@@ -38,9 +38,9 @@ def load_backend(mod_name):
...
@@ -38,9 +38,9 @@ def load_backend(mod_name):
else
:
else
:
raise
NotImplementedError
(
'Unsupported backend: %s'
%
mod_name
)
raise
NotImplementedError
(
'Unsupported backend: %s'
%
mod_name
)
from
.._ffi.base
import
set_ta_path
# imports DGL C library
from
.._ffi.base
import
load_tensor_adapter
# imports DGL C library
version
=
mod
.
__version__
version
=
mod
.
__version__
set_ta_path
(
mod_name
,
version
)
load_tensor_adapter
(
mod_name
,
version
)
print
(
'Using backend: %s'
%
mod_name
,
file
=
sys
.
stderr
)
print
(
'Using backend: %s'
%
mod_name
,
file
=
sys
.
stderr
)
mod
=
importlib
.
import_module
(
'.%s'
%
mod_name
,
__name__
)
mod
=
importlib
.
import_module
(
'.%s'
%
mod_name
,
__name__
)
...
...
src/runtime/c_runtime_api.cc
View file @
e8054701
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#include <dgl/runtime/module.h>
#include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/en
v
.h>
#include <dgl/runtime/
t
en
sordispatch
.h>
#include <array>
#include <array>
#include <algorithm>
#include <algorithm>
#include <string>
#include <string>
...
@@ -379,8 +379,8 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
...
@@ -379,8 +379,8 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
API_END
();
API_END
();
}
}
void
DGL
SetTAPath
(
const
char
*
path
_cstr
)
{
void
DGL
LoadTensorAdapter
(
const
char
*
path
)
{
Env
::
Global
()
->
ta_path
=
std
::
string
(
path_cstr
);
TensorDispatcher
::
Global
()
->
Load
(
path
);
}
}
// set device api
// set device api
...
...
src/runtime/tensordispatch.cc
View file @
e8054701
...
@@ -6,7 +6,6 @@
...
@@ -6,7 +6,6 @@
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/tensordispatch.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/env.h>
#include <dgl/packed_func_ext.h>
#include <dgl/packed_func_ext.h>
#if defined(WIN32) || defined(_WIN32)
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
#include <windows.h>
...
@@ -20,26 +19,33 @@ namespace runtime {
...
@@ -20,26 +19,33 @@ namespace runtime {
constexpr
const
char
*
TensorDispatcher
::
names_
[];
constexpr
const
char
*
TensorDispatcher
::
names_
[];
TensorDispatcher
::
TensorDispatcher
()
{
void
TensorDispatcher
::
Load
(
const
char
*
path
)
{
const
std
::
string
&
path
=
Env
::
Global
()
->
ta_path
;
CHECK
(
!
available_
)
<<
"The tensor adapter can only load once."
;
if
(
path
==
""
)
if
(
path
==
nullptr
||
strlen
(
path
)
==
0
)
// does not have dispatcher library; all operators fall back to DGL's implementation
// does not have dispatcher library; all operators fall back to DGL's implementation
return
;
return
;
#if defined(WIN32) || defined(_WIN32)
#if defined(WIN32) || defined(_WIN32)
handle_
=
LoadLibrary
(
path
.
c_str
()
);
handle_
=
LoadLibrary
(
path
);
if
(
!
handle_
)
if
(
!
handle_
)
return
;
return
;
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
{
entrypoints_
[
i
]
=
reinterpret_cast
<
void
*>
(
GetProcAddress
(
handle_
,
names_
[
i
]));
entrypoints_
[
i
]
=
reinterpret_cast
<
void
*>
(
GetProcAddress
(
handle_
,
names_
[
i
]));
CHECK
(
entrypoints_
[
i
])
<<
"cannot locate symbol "
<<
names_
[
i
];
}
#else // !WIN32
#else // !WIN32
handle_
=
dlopen
(
path
.
c_str
(),
RTLD_LAZY
);
handle_
=
dlopen
(
path
,
RTLD_LAZY
);
if
(
!
handle_
)
if
(
!
handle_
)
return
;
return
;
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
for
(
int
i
=
0
;
i
<
num_entries_
;
++
i
)
{
entrypoints_
[
i
]
=
dlsym
(
handle_
,
names_
[
i
]);
entrypoints_
[
i
]
=
dlsym
(
handle_
,
names_
[
i
]);
CHECK
(
entrypoints_
[
i
])
<<
"cannot locate symbol "
<<
names_
[
i
];
}
#endif // WIN32
#endif // WIN32
available_
=
true
;
available_
=
true
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment