Unverified Commit 1552090a authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

set default backend. (#1104)

parent adea4ba1
...@@ -8,7 +8,7 @@ import logging ...@@ -8,7 +8,7 @@ import logging
import time import time
import pickle import pickle
backend = os.environ.get('DGLBACKEND') backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet': if backend.lower() == 'mxnet':
from train_mxnet import load_model_from_checkpoint from train_mxnet import load_model_from_checkpoint
from train_mxnet import test from train_mxnet import test
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import numpy as np import numpy as np
import dgl.backend as F import dgl.backend as F
backend = os.environ.get('DGLBACKEND') backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet': if backend.lower() == 'mxnet':
from .mxnet.tensor_models import logsigmoid from .mxnet.tensor_models import logsigmoid
from .mxnet.tensor_models import get_device from .mxnet.tensor_models import get_device
......
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import dgl.backend as F import dgl.backend as F
import dgl import dgl
backend = os.environ.get('DGLBACKEND') backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet': if backend.lower() == 'mxnet':
import mxnet as mx import mxnet as mx
mx.random.seed(42) mx.random.seed(42)
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
import logging import logging
import time import time
backend = os.environ.get('DGLBACKEND') backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet': if backend.lower() == 'mxnet':
import multiprocessing as mp import multiprocessing as mp
from train_mxnet import load_model from train_mxnet import load_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment