Unverified Commit a0721405 authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] don’t import dgl in the package. (#1382)



* fix dgl data.

* remove more.

* fix.

* fix.
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-16-150.us-west-2.compute.internal>
parent 10253a5c
from scipy import io from scipy import io
import numpy as np import numpy as np
from dgl import DGLGraph
import os import os
import datetime import datetime
from .utils import get_download_dir, download, extract_archive from .utils import get_download_dir, download, extract_archive
from ..graph import DGLGraph
class BitcoinOTC(object): class BitcoinOTC(object):
...@@ -63,4 +63,4 @@ class BitcoinOTC(object): ...@@ -63,4 +63,4 @@ class BitcoinOTC(object):
@property @property
def is_temporal(self): def is_temporal(self):
return True return True
\ No newline at end of file
...@@ -10,10 +10,10 @@ import pickle as pkl ...@@ -10,10 +10,10 @@ import pickle as pkl
import networkx as nx import networkx as nx
import scipy.sparse as sp import scipy.sparse as sp
import os, sys import os, sys
from dgl import DGLGraph
import dgl
from .utils import download, extract_archive, get_download_dir, _get_dgl_url from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
from ..graph import batch as graph_batch
_urls = { _urls = {
'cora' : 'dataset/cora_raw.zip', 'cora' : 'dataset/cora_raw.zip',
...@@ -314,13 +314,13 @@ class CoraBinary(object): ...@@ -314,13 +314,13 @@ class CoraBinary(object):
for line in f.readlines(): for line in f.readlines():
if line.startswith('graph'): if line.startswith('graph'):
if len(elist) != 0: if len(elist) != 0:
self.graphs.append(dgl.DGLGraph(elist)) self.graphs.append(DGLGraph(elist))
elist = [] elist = []
else: else:
u, v = line.strip().split(' ') u, v = line.strip().split(' ')
elist.append((int(u), int(v))) elist.append((int(u), int(v)))
if len(elist) != 0: if len(elist) != 0:
self.graphs.append(dgl.DGLGraph(elist)) self.graphs.append(DGLGraph(elist))
with open("{}/pmpds.pkl".format(root), 'rb') as f: with open("{}/pmpds.pkl".format(root), 'rb') as f:
self.pmpds = _pickle_load(f) self.pmpds = _pickle_load(f)
self.labels = [] self.labels = []
...@@ -348,7 +348,7 @@ class CoraBinary(object): ...@@ -348,7 +348,7 @@ class CoraBinary(object):
@staticmethod @staticmethod
def collate_fn(batch): def collate_fn(batch):
graphs, pmpds, labels = zip(*batch) graphs, pmpds, labels = zip(*batch)
batched_graphs = dgl.batch(graphs) batched_graphs = graph_batch(graphs)
batched_pmpds = sp.block_diag(pmpds) batched_pmpds = sp.block_diag(pmpds)
batched_labels = np.concatenate(labels, axis=0) batched_labels = np.concatenate(labels, axis=0)
return batched_graphs, batched_pmpds, batched_labels return batched_graphs, batched_pmpds, batched_labels
......
from scipy import io from scipy import io
import numpy as np import numpy as np
from dgl import DGLGraph
import os import os
import datetime import datetime
from .utils import get_download_dir, download, extract_archive, loadtxt from .utils import get_download_dir, download, extract_archive, loadtxt
from ..graph import DGLGraph
class GDELT(object): class GDELT(object):
......
import scipy.sparse as sp import scipy.sparse as sp
import numpy as np import numpy as np
from dgl import graph_index, DGLGraph, transform
import os import os
from .utils import download, extract_archive, get_download_dir, _get_dgl_url from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
__all__=["AmazonCoBuy", "Coauthor", 'CoraFull'] __all__=["AmazonCoBuy", "Coauthor", 'CoraFull']
......
from scipy import io from scipy import io
import numpy as np import numpy as np
from dgl import DGLGraph
import os import os
import datetime import datetime
import warnings import warnings
from .utils import get_download_dir, download, extract_archive, loadtxt from .utils import get_download_dir, download, extract_archive, loadtxt
from ..graph import DGLGraph
class ICEWS18(object): class ICEWS18(object):
...@@ -88,4 +88,4 @@ class ICEWS18(object): ...@@ -88,4 +88,4 @@ class ICEWS18(object):
@property @property
def is_temporal(self): def is_temporal(self):
return True return True
\ No newline at end of file
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
""" """
import numpy as np import numpy as np
import networkx as nx import networkx as nx
from dgl import DGLGraph from ..graph import DGLGraph
class KarateClub(object): class KarateClub(object):
......
from scipy import io from scipy import io
import numpy as np import numpy as np
from dgl import DGLGraph
import os import os
from .utils import get_download_dir, download from .utils import get_download_dir, download
from ..graph import DGLGraph
class QM7b(object): class QM7b(object):
""" """
......
...@@ -2,9 +2,9 @@ from __future__ import absolute_import ...@@ -2,9 +2,9 @@ from __future__ import absolute_import
import scipy.sparse as sp import scipy.sparse as sp
import numpy as np import numpy as np
import dgl
import os, sys import os, sys
from .utils import download, extract_archive, get_download_dir, _get_dgl_url from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
class RedditDataset(object): class RedditDataset(object):
...@@ -19,7 +19,7 @@ class RedditDataset(object): ...@@ -19,7 +19,7 @@ class RedditDataset(object):
extract_archive(zip_file_path, extract_dir) extract_archive(zip_file_path, extract_dir)
# graph # graph
coo_adj = sp.load_npz(os.path.join(extract_dir, "reddit{}_graph.npz".format(self_loop_str))) coo_adj = sp.load_npz(os.path.join(extract_dir, "reddit{}_graph.npz".format(self_loop_str)))
self.graph = dgl.DGLGraph(coo_adj, readonly=True) self.graph = DGLGraph(coo_adj, readonly=True)
# features and labels # features and labels
reddit_data = np.load(os.path.join(extract_dir, "reddit_data.npz")) reddit_data = np.load(os.path.join(extract_dir, "reddit_data.npz"))
self.features = reddit_data["feature"] self.features = reddit_data["feature"]
......
...@@ -10,9 +10,10 @@ import networkx as nx ...@@ -10,9 +10,10 @@ import networkx as nx
import numpy as np import numpy as np
import os import os
import dgl
import dgl.backend as F from .. import backend as F
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url from ..graph import DGLGraph
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
__all__ = ['SSTBatch', 'SST'] __all__ = ['SSTBatch', 'SST']
...@@ -116,7 +117,7 @@ class SST(object): ...@@ -116,7 +117,7 @@ class SST(object):
# add root # add root
g.add_node(0, x=SST.PAD_WORD, y=int(root.label()), mask=0) g.add_node(0, x=SST.PAD_WORD, y=int(root.label()), mask=0)
_rec_build(0, root) _rec_build(0, root)
ret = dgl.DGLGraph() ret = DGLGraph()
ret.from_networkx(g, node_attrs=['x', 'y', 'mask']) ret.from_networkx(g, node_attrs=['x', 'y', 'mask'])
return ret return ret
......
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np import numpy as np
import dgl
import os import os
import random import random
from dgl.data.utils import download, extract_archive, get_download_dir, loadtxt from .utils import download, extract_archive, get_download_dir, loadtxt
from ..graph import DGLGraph
class LegacyTUDataset(object): class LegacyTUDataset(object):
""" """
...@@ -46,7 +45,7 @@ class LegacyTUDataset(object): ...@@ -46,7 +45,7 @@ class LegacyTUDataset(object):
DS_graph_labels = self._idx_from_zero( DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int)) np.genfromtxt(self._file_path("graph_labels"), dtype=int))
g = dgl.DGLGraph() g = DGLGraph()
g.add_nodes(int(DS_edge_list.max()) + 1) g.add_nodes(int(DS_edge_list.max()) + 1)
g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1]) g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])
...@@ -180,7 +179,7 @@ class TUDataset(object): ...@@ -180,7 +179,7 @@ class TUDataset(object):
DS_graph_labels = self._idx_from_zero( DS_graph_labels = self._idx_from_zero(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int)) loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
g = dgl.DGLGraph() g = DGLGraph()
g.add_nodes(int(DS_edge_list.max()) + 1) g.add_nodes(int(DS_edge_list.max()) + 1)
g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1]) g.add_edges(DS_edge_list[:, 0], DS_edge_list[:, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment