Unverified Commit 43d9f43a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

backend selection (#1424)

parent 91f7ff8e
...@@ -20,6 +20,7 @@ def _gen_missing_api(api, mod_name): ...@@ -20,6 +20,7 @@ def _gen_missing_api(api, mod_name):
def load_backend(mod_name): def load_backend(mod_name):
print('Using backend: %s' % mod_name, file=sys.stderr)
mod = importlib.import_module('.%s' % mod_name, __name__) mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__] thismod = sys.modules[__name__]
for api in backend.__dict__.keys(): for api in backend.__dict__.keys():
...@@ -62,13 +63,12 @@ def get_preferred_backend(): ...@@ -62,13 +63,12 @@ def get_preferred_backend():
backend_name = config_dict.get('backend', '').lower() backend_name = config_dict.get('backend', '').lower()
if (backend_name in ['tensorflow', 'mxnet', 'pytorch']): if (backend_name in ['tensorflow', 'mxnet', 'pytorch']):
return backend_name
else:
while not(backend_name in ['tensorflow', 'mxnet', 'pytorch']):
print("DGL does not detect a valid backend option. Which backend would you like to work with?")
backend_name = input("Backend choice (pytorch, mxnet or tensorflow): ").lower()
set_default_backend(backend_name)
return backend_name return backend_name
else:
print("DGL backend not selected or invalid. "
"Assuming PyTorch for now.", file=sys.stderr)
set_default_backend('pytorch')
return 'pytorch'
load_backend(get_preferred_backend()) load_backend(get_preferred_backend())
......
...@@ -9,8 +9,9 @@ def set_default_backend(backend_name): ...@@ -9,8 +9,9 @@ def set_default_backend(backend_name):
config_path = os.path.join(default_dir, 'config.json') config_path = os.path.join(default_dir, 'config.json')
with open(config_path, "w") as config_file: with open(config_path, "w") as config_file:
json.dump({'backend': backend_name.lower()}, config_file) json.dump({'backend': backend_name.lower()}, config_file)
print('Set the default backend to "{}". You can change it in the ' print('Setting the default backend to "{}". You can change it in the '
'~/.dgl/config.json file or export the DGLBACKEND environment variable.'.format( '~/.dgl/config.json file or export the DGLBACKEND environment variable. '
'Valid options are: pytorch, mxnet, tensorflow (all lowercase)'.format(
backend_name)) backend_name))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment