Commit b24daa66 authored by Minjie Wang

change to rel import within dgl

parent 842d3768
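The same pattern repeats in every file below: imports of the package's own modules are rewritten from absolute (`dgl.xxx`) form to explicit relative form, and `from __future__ import absolute_import` is added where missing so Python 2 no longer resolves bare names as implicit relative imports. A minimal sketch using a hypothetical package `mypkg` (not part of this commit) illustrates the before/after:

# mypkg/graph.py -- hypothetical module, used only to illustrate the pattern.

# Before: absolute imports of the package's own submodules. They resolve
# through sys.path, so they break if the package is renamed, vendored, or
# shadowed by another copy on the path.
#   import mypkg.backend as F
#   from mypkg.utils import some_helper

# After: explicit relative imports, resolved against this module's package.
from __future__ import absolute_import   # Python 2: disable implicit relative imports

from . import backend as F               # sibling module mypkg/backend
from .utils import some_helper           # sibling module mypkg/utils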
+from __future__ import absolute_import
 import os
 __backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
 if __backend__ == 'numpy':
-    from dgl.backend.numpy import *
+    from .numpy import *
 elif __backend__ == 'pytorch':
-    from dgl.backend.pytorch import *
+    from .pytorch import *
 else:
     raise Exception("Unsupported backend %s" % __backend__)

...
@@ -4,7 +4,7 @@ import torch as th
 from .._ffi.runtime_ctypes import TVMType, TVMContext, TVMArray
 from .._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t
-from ..context as cpu, gpu
+from ..context import cpu, gpu

 # Tensor types
 Tensor = th.Tensor
...

...
@@ -3,9 +3,8 @@ from __future__ import absolute_import
 import numpy as np
-from dgl.graph import DGLGraph
-import dgl.backend as F
-import dgl
+from .graph import DGLGraph
+from . import backend as F

 class BatchedDGLGraph(DGLGraph):
     def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr):
...

...
@@ -7,9 +7,9 @@ from __future__ import absolute_import
 import igraph
-import dgl.backend as F
-from dgl.backend import Tensor
-import dgl.utils as utils
+from . import backend as F
+from .backend import Tensor
+from . import utils

 class CachedGraph:
     def __init__(self):
...

 from __future__ import absolute_import
 from ._ffi.function import _init_api
-import .backend as F
+from . import backend as F

 class DGLGraph(object):
     def __init__(self):
...

...
@@ -11,7 +11,7 @@ import networkx as nx
 import scipy.sparse as sp
 import os, sys
-from dgl.data.utils import download, extract_archive, get_download_dir
+from .utils import download, extract_archive, get_download_dir

 _urls = {
     'cora' : 'https://www.dropbox.com/s/3ggdpkj7ou8svoc/cora.zip?dl=1',
...

...
@@ -10,9 +10,9 @@ from nltk.tree import Tree
 from nltk.corpus.reader import BracketParseCorpusReader
 import networkx as nx
-import dgl
-import dgl.backend as F
-from dgl.data.utils import download, extract_archive, get_download_dir
+from .. import backend as F
+from ..graph import DGLGraph
+from .utils import download, extract_archive, get_download_dir

 _urls = {
     'sst' : 'https://www.dropbox.com/s/dw8kr2vuq7k4dqi/sst.zip?dl=1',
...

...
@@ -4,9 +4,9 @@ from __future__ import absolute_import
 from collections import MutableMapping
 import numpy as np
-import dgl.backend as F
-from dgl.backend import Tensor
-import dgl.utils as utils
+from . import backend as F
+from .backend import Tensor
+from . import utils

 class Frame(MutableMapping):
     def __init__(self, data=None):
...

+"""DGL builtin functors"""
+from __future__ import absolute_import
 from .message import *
 from .reducer import *

"""Built-in reducer function.""" """Built-in reducer function."""
from __future__ import absolute_import from __future__ import absolute_import
import dgl.backend as F from .. import backend as F
__all__ = ["ReduceFunction", "sum", "max"] __all__ = ["ReduceFunction", "sum", "max"]
......
"""Package for graph generators"""
from __future__ import absolute_import
from .line import * from .line import *
...
@@ -4,9 +4,9 @@ from __future__ import absolute_import
 import networkx as nx
 import numpy as np
-import dgl.backend as F
-from dgl.graph import DGLGraph
-from dgl.frame import FrameRef
+from .. import backend as F
+from ..graph import DGLGraph
+from ..frame import FrameRef

 def line_graph(G, no_backtracking=False):
     """Create the line graph that shares the underlying features.
...

...
@@ -6,15 +6,15 @@ import networkx as nx
 from networkx.classes.digraph import DiGraph
 import dgl
-from dgl.base import ALL, is_all, __MSG__, __REPR__
-import dgl.backend as F
-from dgl.backend import Tensor
-from dgl.cached_graph import CachedGraph, create_cached_graph
-import dgl.context as context
-from dgl.frame import FrameRef, merge_frames
-from dgl.nx_adapt import nx_init
-import dgl.scheduler as scheduler
-import dgl.utils as utils
+from .base import ALL, is_all, __MSG__, __REPR__
+from . import backend as F
+from .backend import Tensor
+from .cached_graph import CachedGraph, create_cached_graph
+from . import context
+from .frame import FrameRef, merge_frames
+from .nx_adapt import nx_init
+from . import scheduler
+from . import utils

 class DGLGraph(DiGraph):
     """Base graph class specialized for neural networks on graphs.
...

"""Package nn modules"""
from __future__ import absolute_import
import os import os
__backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower() __backend__ = os.environ.get('DGLBACKEND', 'pytorch').lower()
if __backend__ == 'numpy': if __backend__ == 'numpy':
......
...
@@ -7,9 +7,8 @@ GCN with SPMV specialization.
 """
 import torch.nn as nn
-import dgl
-import dgl.function as fn
-from dgl.base import ALL, is_all
+from ... import function as fn
+from ...base import ALL, is_all

 class NodeUpdateModule(nn.Module):
     def __init__(self, in_feats, out_feats, activation=None):
...

...
@@ -3,10 +3,10 @@ from __future__ import absolute_import
 import numpy as np
-import dgl.backend as F
-import dgl.function.message as fmsg
-import dgl.function.reducer as fred
-import dgl.utils as utils
+from . import backend as F
+from .function import message as fmsg
+from .function import reducer as fred
+from . import utils

 __all__ = ["degree_bucketing", "get_executor"]
...

...
@@ -2,11 +2,12 @@
 from __future__ import absolute_import
 import networkx as nx
-import dgl.backend as F
-from dgl.frame import Frame, FrameRef
-from dgl.graph import DGLGraph
-from dgl.nx_adapt import nx_init
-import dgl.utils as utils
+from . import backend as F
+from .frame import Frame, FrameRef
+from .graph import DGLGraph
+from .nx_adapt import nx_init
+from . import utils

 class DGLSubGraph(DGLGraph):
     # TODO(gaiyu): ReadOnlyGraph
...

...
@@ -5,8 +5,8 @@ from collections import Mapping
 from functools import wraps
 import numpy as np
-import dgl.backend as F
-from dgl.backend import Tensor, SparseTensor
+from . import backend as F
+from .backend import Tensor, SparseTensor

 def is_id_tensor(u):
     """Return whether the input is a supported id tensor."""
...