Unverified Commit 1552090a authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

set default backend. (#1104)

parent adea4ba1
......@@ -8,7 +8,7 @@ import logging
import time
import pickle
backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
from train_mxnet import load_model_from_checkpoint
from train_mxnet import test
......
......@@ -2,7 +2,7 @@ import os
import numpy as np
import dgl.backend as F
backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
from .mxnet.tensor_models import logsigmoid
from .mxnet.tensor_models import get_device
......
......@@ -5,7 +5,7 @@ import numpy as np
import dgl.backend as F
import dgl
backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
import mxnet as mx
mx.random.seed(42)
......
......@@ -6,7 +6,7 @@ import os
import logging
import time
backend = os.environ.get('DGLBACKEND')
backend = os.environ.get('DGLBACKEND', 'pytorch')
if backend.lower() == 'mxnet':
import multiprocessing as mp
from train_mxnet import load_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment