Unverified Commit 79305862 authored by Konstantinos Vandikas's avatar Konstantinos Vandikas Committed by GitHub
Browse files

allow for configuring default_dir (#3277)



* allow for configuring default_dir

* allow for using DGLDEFAULTDIR environment variable

* Update env_var.rst
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 9cc85aed
Environment Variables
=====================
Global Configurations
---------------------
* ``DGLDEFAULTDIR``:
* Values: String (default=``"${HOME}/.dgl"``)
* The directory to save the DGL configuration files.
Backend Options
---------------
* ``DGLBACKEND``:
......@@ -20,7 +26,7 @@ Data Repository
* 'https://data.dgl.ai/': DGL repo for Global Region.
* 'https://dgl-data.s3.cn-north-1.amazonaws.com.cn/': DGL repo for Mainland China
* ``DGL_DOWNLOAD_DIR``:
* Values: String (default="${HOME}/.dgl")
* Values: String (default=``"${HOME}/.dgl"``)
* The local directory to cache the downloaded data.
Intel CPU Performance Options
......
......@@ -74,7 +74,12 @@ def load_backend(mod_name):
setattr(thismod, api, _gen_missing_api(api, mod_name))
def get_preferred_backend():
config_path = os.path.join(os.path.expanduser('~'), '.dgl', 'config.json')
default_dir = None
if "DGLDEFAULTDIR" in os.environ:
default_dir = os.getenv('DGLDEFAULTDIR')
else:
default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
config_path = os.path.join(default_dir, 'config.json')
backend_name = None
if "DGLBACKEND" in os.environ:
backend_name = os.getenv('DGLBACKEND')
......@@ -88,7 +93,7 @@ def get_preferred_backend():
else:
print("DGL backend not selected or invalid. "
"Assuming PyTorch for now.", file=sys.stderr)
set_default_backend('pytorch')
set_default_backend(default_dir, 'pytorch')
return 'pytorch'
......
......@@ -2,8 +2,7 @@ import argparse
import os
import json
def set_default_backend(backend_name):
default_dir = os.path.join(os.path.expanduser('~'), '.dgl')
def set_default_backend(default_dir, backend_name):
if not os.path.exists(default_dir):
os.makedirs(default_dir)
config_path = os.path.join(default_dir, 'config.json')
......@@ -16,7 +15,8 @@ def set_default_backend(backend_name):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("default_dir", type=str, default=os.path.join(os.path.expanduser('~'), '.dgl'))
parser.add_argument("backend", nargs=1, type=str, choices=[
'pytorch', 'tensorflow', 'mxnet'], help="Set default backend")
args = parser.parse_args()
set_default_backend(args.backend[0])
set_default_backend(args.default_dir, args.backend[0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment