Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
43d9f43a
Unverified
Commit
43d9f43a
authored
Apr 06, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Apr 06, 2020
Browse files
backend selection (#1424)
parent
91f7ff8e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
8 deletions
+9
-8
python/dgl/backend/__init__.py
python/dgl/backend/__init__.py
+6
-6
python/dgl/backend/set_default_backend.py
python/dgl/backend/set_default_backend.py
+3
-2
No files found.
python/dgl/backend/__init__.py
View file @
43d9f43a
...
@@ -20,6 +20,7 @@ def _gen_missing_api(api, mod_name):
...
@@ -20,6 +20,7 @@ def _gen_missing_api(api, mod_name):
def
load_backend
(
mod_name
):
def
load_backend
(
mod_name
):
print
(
'Using backend: %s'
%
mod_name
,
file
=
sys
.
stderr
)
mod
=
importlib
.
import_module
(
'.%s'
%
mod_name
,
__name__
)
mod
=
importlib
.
import_module
(
'.%s'
%
mod_name
,
__name__
)
thismod
=
sys
.
modules
[
__name__
]
thismod
=
sys
.
modules
[
__name__
]
for
api
in
backend
.
__dict__
.
keys
():
for
api
in
backend
.
__dict__
.
keys
():
...
@@ -62,13 +63,12 @@ def get_preferred_backend():
...
@@ -62,13 +63,12 @@ def get_preferred_backend():
backend_name
=
config_dict
.
get
(
'backend'
,
''
).
lower
()
backend_name
=
config_dict
.
get
(
'backend'
,
''
).
lower
()
if
(
backend_name
in
[
'tensorflow'
,
'mxnet'
,
'pytorch'
]):
if
(
backend_name
in
[
'tensorflow'
,
'mxnet'
,
'pytorch'
]):
return
backend_name
else
:
while
not
(
backend_name
in
[
'tensorflow'
,
'mxnet'
,
'pytorch'
]):
print
(
"DGL does not detect a valid backend option. Which backend would you like to work with?"
)
backend_name
=
input
(
"Backend choice (pytorch, mxnet or tensorflow): "
).
lower
()
set_default_backend
(
backend_name
)
return
backend_name
return
backend_name
else
:
print
(
"DGL backend not selected or invalid. "
"Assuming PyTorch for now."
,
file
=
sys
.
stderr
)
set_default_backend
(
'pytorch'
)
return
'pytorch'
load_backend
(
get_preferred_backend
())
load_backend
(
get_preferred_backend
())
...
...
python/dgl/backend/set_default_backend.py
View file @
43d9f43a
...
@@ -9,8 +9,9 @@ def set_default_backend(backend_name):
...
@@ -9,8 +9,9 @@ def set_default_backend(backend_name):
config_path
=
os
.
path
.
join
(
default_dir
,
'config.json'
)
config_path
=
os
.
path
.
join
(
default_dir
,
'config.json'
)
with
open
(
config_path
,
"w"
)
as
config_file
:
with
open
(
config_path
,
"w"
)
as
config_file
:
json
.
dump
({
'backend'
:
backend_name
.
lower
()},
config_file
)
json
.
dump
({
'backend'
:
backend_name
.
lower
()},
config_file
)
print
(
'Set the default backend to "{}". You can change it in the '
print
(
'Setting the default backend to "{}". You can change it in the '
'~/.dgl/config.json file or export the DGLBACKEND environment variable.'
.
format
(
'~/.dgl/config.json file or export the DGLBACKEND environment variable. '
'Valid options are: pytorch, mxnet, tensorflow (all lowercase)'
.
format
(
backend_name
))
backend_name
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment