Unverified Commit 43d9f43a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

backend selection (#1424)

parent 91f7ff8e
...@@ -20,6 +20,7 @@ def _gen_missing_api(api, mod_name): ...@@ -20,6 +20,7 @@ def _gen_missing_api(api, mod_name):
def load_backend(mod_name): def load_backend(mod_name):
print('Using backend: %s' % mod_name, file=sys.stderr)
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__] thismod = sys.modules[__name__]
for api in backend.__dict__.keys(): for api in backend.__dict__.keys():
...@@ -64,11 +65,10 @@ def get_preferred_backend(): ...@@ -64,11 +65,10 @@ def get_preferred_backend():
if (backend_name in ['tensorflow', 'mxnet', 'pytorch']): if (backend_name in ['tensorflow', 'mxnet', 'pytorch']):
return backend_name return backend_name
else: else:
while not(backend_name in ['tensorflow', 'mxnet', 'pytorch']): print("DGL backend not selected or invalid. "
print("DGL does not detect a valid backend option. Which backend would you like to work with?") "Assuming PyTorch for now.", file=sys.stderr)
backend_name = input("Backend choice (pytorch, mxnet or tensorflow): ").lower() set_default_backend('pytorch')
set_default_backend(backend_name) return 'pytorch'
return backend_name
load_backend(get_preferred_backend()) load_backend(get_preferred_backend())
......
...@@ -9,8 +9,9 @@ def set_default_backend(backend_name): ...@@ -9,8 +9,9 @@ def set_default_backend(backend_name):
config_path = os.path.join(default_dir, 'config.json') config_path = os.path.join(default_dir, 'config.json')
with open(config_path, "w") as config_file: with open(config_path, "w") as config_file:
json.dump({'backend': backend_name.lower()}, config_file) json.dump({'backend': backend_name.lower()}, config_file)
print('Set the default backend to "{}". You can change it in the ' print('Setting the default backend to "{}". You can change it in the '
'~/.dgl/config.json file or export the DGLBACKEND environment variable.'.format( '~/.dgl/config.json file or export the DGLBACKEND environment variable. '
'Valid options are: pytorch, mxnet, tensorflow (all lowercase)'.format(
backend_name)) backend_name))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment